Make RestoreStateData.async_get_instance backwards compatible (#93924)

pull/93937/head
J. Nick Koston 2023-06-01 12:31:17 -05:00 committed by GitHub
parent 5a8daf06e8
commit 457bc4571d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 17 deletions

View File

@ -16,6 +16,7 @@ import homeassistant.util.dt as dt_util
from . import start
from .entity import Entity
from .event import async_track_time_interval
from .frame import report
from .json import JSONEncoder
from .storage import Store
@ -96,7 +97,9 @@ class StoredState:
async def async_load(hass: HomeAssistant) -> None:
"""Load the restore state task."""
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)
restore_state = RestoreStateData(hass)
await restore_state.async_setup()
hass.data[DATA_RESTORE_STATE] = restore_state
@callback
@ -108,25 +111,26 @@ def async_get(hass: HomeAssistant) -> RestoreStateData:
class RestoreStateData:
"""Helper class for managing the helper saved data."""
@staticmethod
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
"""Get the instance of this data helper."""
data = RestoreStateData(hass)
await data.async_load()
async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task."""
data.async_setup_dump()
start.async_at_start(hass, hass_start)
return data
@classmethod
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
"""Dump states now."""
await async_get(hass).async_dump_states()
@classmethod
async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData:
"""Return the instance of this class."""
# Nothing should actually be calling this anymore, but we'll keep it
# around for a while to avoid breaking custom components.
#
# In fact they should not be accessing this at all.
report(
"restore_state.RestoreStateData.async_get_instance is deprecated, "
"and not intended to be called by custom components; Please"
"refactor your code to use RestoreEntity instead;"
" restore_state.async_get(hass) can be used in the meantime",
)
return async_get(hass)
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class."""
self.hass: HomeAssistant = hass
@ -136,6 +140,16 @@ class RestoreStateData:
self.last_states: dict[str, StoredState] = {}
self.entities: dict[str, RestoreEntity] = {}
async def async_setup(self) -> None:
"""Set up up the instance of this data helper."""
await self.async_load()
async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task."""
self.async_setup_dump()
start.async_at_start(self.hass, hass_start)
async def async_load(self) -> None:
"""Load the instance of this data helper."""
try:

View File

@ -1,12 +1,19 @@
"""The tests for the Restore component."""
from collections.abc import Coroutine
from datetime import datetime, timedelta
import logging
from typing import Any
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_get_platform_without_config_entry
from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE,
STORAGE_KEY,
@ -16,9 +23,20 @@ from homeassistant.helpers.restore_state import (
async_get,
async_load,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed
from tests.common import (
MockModule,
MockPlatform,
async_fire_time_changed,
mock_entity_platform,
mock_integration,
)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
PLATFORM = "test_platform"
async def test_caching_data(hass: HomeAssistant) -> None:
@ -68,6 +86,20 @@ async def test_caching_data(hass: HomeAssistant) -> None:
assert mock_write_data.called
async def test_async_get_instance_backwards_compatibility(hass: HomeAssistant) -> None:
"""Test async_get_instance backwards compatibility."""
await async_load(hass)
data = async_get(hass)
# When called from core it should raise
with pytest.raises(RuntimeError):
await RestoreStateData.async_get_instance(hass)
# When called from a component it should not raise
# but it should report
with patch("homeassistant.helpers.restore_state.report"):
assert data is await RestoreStateData.async_get_instance(hass)
async def test_periodic_write(hass: HomeAssistant) -> None:
"""Test that we write periodiclly but not after stop."""
data = async_get(hass)
@ -401,3 +433,89 @@ async def test_restoring_invalid_entity_id(
state = await entity.async_get_last_state()
assert state is None
async def test_restore_entity_end_to_end(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test restoring an entity end-to-end."""
component_setup = Mock(return_value=True)
setup_called = []
entity_id = "test_domain.unnamed_device"
data = async_get(hass)
now = dt_util.utcnow()
data.last_states = {
entity_id: StoredState(State(entity_id, "stored"), None, now),
}
class MockRestoreEntity(RestoreEntity):
"""Mock restore entity."""
def __init__(self):
"""Initialize the mock entity."""
self._state: str | None = None
@property
def state(self):
"""Return the state."""
return self._state
async def async_added_to_hass(self) -> Coroutine[Any, Any, None]:
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
self._state = (await self.async_get_last_state()).state
async def async_setup_platform(
hass: HomeAssistant,
config: ConfigType,
async_add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up the test platform."""
async_add_entities([MockRestoreEntity()])
setup_called.append(True)
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=async_setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
platform = async_get_platform_without_config_entry(hass, PLATFORM, DOMAIN)
assert platform.platform_name == PLATFORM
assert platform.domain == DOMAIN
assert hass.states.get(entity_id).state == "stored"
await data.async_dump_states()
await hass.async_block_till_done()
storage_data = hass_storage[STORAGE_KEY]["data"]
assert len(storage_data) == 1
assert storage_data[0]["state"]["entity_id"] == entity_id
assert storage_data[0]["state"]["state"] == "stored"
await platform.async_reset()
assert hass.states.get(entity_id) is None
# Make sure the entity still gets saved to restore state
# even though the platform has been reset since it should
# not be expired yet.
await data.async_dump_states()
await hass.async_block_till_done()
storage_data = hass_storage[STORAGE_KEY]["data"]
assert len(storage_data) == 1
assert storage_data[0]["state"]["entity_id"] == entity_id
assert storage_data[0]["state"]["state"] == "stored"