Make RestoreStateData.async_get_instance backwards compatible (#93924)
parent
5a8daf06e8
commit
457bc4571d
|
@ -16,6 +16,7 @@ import homeassistant.util.dt as dt_util
|
|||
from . import start
|
||||
from .entity import Entity
|
||||
from .event import async_track_time_interval
|
||||
from .frame import report
|
||||
from .json import JSONEncoder
|
||||
from .storage import Store
|
||||
|
||||
|
@ -96,7 +97,9 @@ class StoredState:
|
|||
|
||||
async def async_load(hass: HomeAssistant) -> None:
|
||||
"""Load the restore state task."""
|
||||
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)
|
||||
restore_state = RestoreStateData(hass)
|
||||
await restore_state.async_setup()
|
||||
hass.data[DATA_RESTORE_STATE] = restore_state
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -108,25 +111,26 @@ def async_get(hass: HomeAssistant) -> RestoreStateData:
|
|||
class RestoreStateData:
|
||||
"""Helper class for managing the helper saved data."""
|
||||
|
||||
@staticmethod
|
||||
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
|
||||
"""Get the instance of this data helper."""
|
||||
data = RestoreStateData(hass)
|
||||
await data.async_load()
|
||||
|
||||
async def hass_start(hass: HomeAssistant) -> None:
|
||||
"""Start the restore state task."""
|
||||
data.async_setup_dump()
|
||||
|
||||
start.async_at_start(hass, hass_start)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
|
||||
"""Dump states now."""
|
||||
await async_get(hass).async_dump_states()
|
||||
|
||||
@classmethod
|
||||
async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData:
|
||||
"""Return the instance of this class."""
|
||||
# Nothing should actually be calling this anymore, but we'll keep it
|
||||
# around for a while to avoid breaking custom components.
|
||||
#
|
||||
# In fact they should not be accessing this at all.
|
||||
report(
|
||||
"restore_state.RestoreStateData.async_get_instance is deprecated, "
|
||||
"and not intended to be called by custom components; Please"
|
||||
"refactor your code to use RestoreEntity instead;"
|
||||
" restore_state.async_get(hass) can be used in the meantime",
|
||||
)
|
||||
return async_get(hass)
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the restore state data class."""
|
||||
self.hass: HomeAssistant = hass
|
||||
|
@ -136,6 +140,16 @@ class RestoreStateData:
|
|||
self.last_states: dict[str, StoredState] = {}
|
||||
self.entities: dict[str, RestoreEntity] = {}
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up up the instance of this data helper."""
|
||||
await self.async_load()
|
||||
|
||||
async def hass_start(hass: HomeAssistant) -> None:
|
||||
"""Start the restore state task."""
|
||||
self.async_setup_dump()
|
||||
|
||||
start.async_at_start(self.hass, hass_start)
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the instance of this data helper."""
|
||||
try:
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
"""The tests for the Restore component."""
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import CoreState, HomeAssistant, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.reload import async_get_platform_without_config_entry
|
||||
from homeassistant.helpers.restore_state import (
|
||||
DATA_RESTORE_STATE,
|
||||
STORAGE_KEY,
|
||||
|
@ -16,9 +23,20 @@ from homeassistant.helpers.restore_state import (
|
|||
async_get,
|
||||
async_load,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
async_fire_time_changed,
|
||||
mock_entity_platform,
|
||||
mock_integration,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DOMAIN = "test_domain"
|
||||
PLATFORM = "test_platform"
|
||||
|
||||
|
||||
async def test_caching_data(hass: HomeAssistant) -> None:
|
||||
|
@ -68,6 +86,20 @@ async def test_caching_data(hass: HomeAssistant) -> None:
|
|||
assert mock_write_data.called
|
||||
|
||||
|
||||
async def test_async_get_instance_backwards_compatibility(hass: HomeAssistant) -> None:
|
||||
"""Test async_get_instance backwards compatibility."""
|
||||
await async_load(hass)
|
||||
data = async_get(hass)
|
||||
# When called from core it should raise
|
||||
with pytest.raises(RuntimeError):
|
||||
await RestoreStateData.async_get_instance(hass)
|
||||
|
||||
# When called from a component it should not raise
|
||||
# but it should report
|
||||
with patch("homeassistant.helpers.restore_state.report"):
|
||||
assert data is await RestoreStateData.async_get_instance(hass)
|
||||
|
||||
|
||||
async def test_periodic_write(hass: HomeAssistant) -> None:
|
||||
"""Test that we write periodiclly but not after stop."""
|
||||
data = async_get(hass)
|
||||
|
@ -401,3 +433,89 @@ async def test_restoring_invalid_entity_id(
|
|||
|
||||
state = await entity.async_get_last_state()
|
||||
assert state is None
|
||||
|
||||
|
||||
async def test_restore_entity_end_to_end(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test restoring an entity end-to-end."""
|
||||
component_setup = Mock(return_value=True)
|
||||
|
||||
setup_called = []
|
||||
|
||||
entity_id = "test_domain.unnamed_device"
|
||||
data = async_get(hass)
|
||||
now = dt_util.utcnow()
|
||||
data.last_states = {
|
||||
entity_id: StoredState(State(entity_id, "stored"), None, now),
|
||||
}
|
||||
|
||||
class MockRestoreEntity(RestoreEntity):
|
||||
"""Mock restore entity."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the mock entity."""
|
||||
self._state: str | None = None
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state."""
|
||||
return self._state
|
||||
|
||||
async def async_added_to_hass(self) -> Coroutine[Any, Any, None]:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
self._state = (await self.async_get_last_state()).state
|
||||
|
||||
async def async_setup_platform(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the test platform."""
|
||||
async_add_entities([MockRestoreEntity()])
|
||||
setup_called.append(True)
|
||||
|
||||
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
|
||||
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
|
||||
|
||||
mock_platform = MockPlatform(async_setup_platform=async_setup_platform)
|
||||
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
|
||||
await hass.async_block_till_done()
|
||||
assert component_setup.called
|
||||
|
||||
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
|
||||
assert len(setup_called) == 1
|
||||
|
||||
platform = async_get_platform_without_config_entry(hass, PLATFORM, DOMAIN)
|
||||
assert platform.platform_name == PLATFORM
|
||||
assert platform.domain == DOMAIN
|
||||
assert hass.states.get(entity_id).state == "stored"
|
||||
|
||||
await data.async_dump_states()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
storage_data = hass_storage[STORAGE_KEY]["data"]
|
||||
assert len(storage_data) == 1
|
||||
assert storage_data[0]["state"]["entity_id"] == entity_id
|
||||
assert storage_data[0]["state"]["state"] == "stored"
|
||||
|
||||
await platform.async_reset()
|
||||
|
||||
assert hass.states.get(entity_id) is None
|
||||
|
||||
# Make sure the entity still gets saved to restore state
|
||||
# even though the platform has been reset since it should
|
||||
# not be expired yet.
|
||||
await data.async_dump_states()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
storage_data = hass_storage[STORAGE_KEY]["data"]
|
||||
assert len(storage_data) == 1
|
||||
assert storage_data[0]["state"]["entity_id"] == entity_id
|
||||
assert storage_data[0]["state"]["state"] == "stored"
|
||||
|
|
Loading…
Reference in New Issue