Fix esphome not removing entities when static info changes (#95202)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/95211/head
J. Nick Koston 2023-06-25 21:31:31 -05:00 committed by GitHub
parent d700415045
commit 3b7095c63b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 244 additions and 104 deletions

View File

@ -675,6 +675,7 @@ async def _cleanup_instance(
data.disconnect_callbacks = []
for cleanup_callback in data.cleanup_callbacks:
cleanup_callback()
await data.async_cleanup()
await data.client.disconnect()
return data

View File

@ -73,7 +73,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="alarm_control_panel",
info_type=AlarmControlPanelInfo,
entity_type=EsphomeAlarmControlPanel,
state_type=AlarmControlPanelEntityState,

View File

@ -29,7 +29,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="binary_sensor",
info_type=BinarySensorInfo,
entity_type=EsphomeBinarySensor,
state_type=BinarySensorState,

View File

@ -23,7 +23,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="button",
info_type=ButtonInfo,
entity_type=EsphomeButton,
state_type=EntityState,

View File

@ -27,7 +27,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="camera",
info_type=CameraInfo,
entity_type=EsphomeCamera,
state_type=CameraState,

View File

@ -72,7 +72,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="climate",
info_type=ClimateInfo,
entity_type=EsphomeClimateEntity,
state_type=ClimateState,

View File

@ -32,7 +32,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="cover",
info_type=CoverInfo,
entity_type=EsphomeCover,
state_type=CoverState,

View File

@ -1,7 +1,7 @@
"""Diagnostics support for ESPHome."""
from __future__ import annotations
from typing import Any, cast
from typing import Any
from homeassistant.components.bluetooth import async_scanner_by_source
from homeassistant.components.diagnostics import async_redact_data
@ -28,7 +28,6 @@ async def async_get_config_entry_diagnostics(
entry_data = DomainData.get(hass).get_entry_data(config_entry)
if (storage_data := await entry_data.store.async_load()) is not None:
storage_data = cast("dict[str, Any]", storage_data)
diag["storage_data"] = storage_data
if config_entry.unique_id and (

View File

@ -12,10 +12,9 @@ from typing_extensions import Self
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.storage import Store
from .const import DOMAIN
from .entry_data import RuntimeEntryData
from .entry_data import ESPHomeStorage, RuntimeEntryData
STORAGE_VERSION = 1
MAX_CACHED_SERVICES = 128
@ -26,7 +25,7 @@ class DomainData:
"""Define a class that stores global esphome data in hass.data[DOMAIN]."""
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict)
_stores: dict[str, ESPHomeStorage] = field(default_factory=dict)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES)
)
@ -83,11 +82,13 @@ class DomainData:
"""Check whether the given entry is loaded."""
return entry.entry_id in self._entry_datas
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store:
def get_or_create_store(
self, hass: HomeAssistant, entry: ConfigEntry
) -> ESPHomeStorage:
"""Get or create a Store instance for the given config entry."""
return self._stores.setdefault(
entry.entry_id,
Store(
ESPHomeStorage(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
),
)

View File

@ -27,7 +27,6 @@ import homeassistant.helpers.config_validation as cv
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.entity import DeviceInfo, Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -49,7 +48,6 @@ async def platform_async_setup_entry(
entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
*,
component_key: str,
info_type: type[_InfoT],
entity_type: type[_EntityT],
state_type: type[_StateT],
@ -60,42 +58,35 @@ async def platform_async_setup_entry(
info and state updates.
"""
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
entry_data.info[component_key] = {}
entry_data.old_info[component_key] = {}
entry_data.info[info_type] = {}
entry_data.state.setdefault(state_type, {})
@callback
def async_list_entities(infos: list[EntityInfo]) -> None:
"""Update entities of this platform when entities are listed."""
old_infos = entry_data.info[component_key]
current_infos = entry_data.info[info_type]
new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = []
for info in infos:
if info.key in old_infos:
# Update existing entity
old_infos.pop(info.key)
else:
if not current_infos.pop(info.key, None):
# Create new entity
entity = entity_type(entry_data, component_key, info, state_type)
entity = entity_type(entry_data, info, state_type)
add_entities.append(entity)
new_infos[info.key] = info
# Remove old entities
for info in old_infos.values():
entry_data.async_remove_entity(hass, component_key, info.key)
# First copy the now-old info into the backup object
entry_data.old_info[component_key] = entry_data.info[component_key]
# Then update the actual info
entry_data.info[component_key] = new_infos
for key, new_info in new_infos.items():
async_dispatcher_send(
hass,
entry_data.signal_component_key_static_info_updated(component_key, key),
new_info,
# Anything still in current_infos is now gone
if current_infos:
hass.async_create_task(
entry_data.async_remove_entities(current_infos.values())
)
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities:
# Add entities to Home Assistant
async_add_entities(add_entities)
@ -154,14 +145,12 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
def __init__(
self,
entry_data: RuntimeEntryData,
component_key: str,
entity_info: EntityInfo,
state_type: type[_StateT],
) -> None:
"""Initialize."""
self._entry_data = entry_data
self._on_entry_data_changed()
self._component_key = component_key
self._key = entity_info.key
self._state_type = state_type
self._on_static_info_update(entity_info)
@ -178,13 +167,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
"""Register callbacks."""
entry_data = self._entry_data
hass = self.hass
component_key = self._component_key
key = self._key
self.async_on_remove(
async_dispatcher_connect(
hass,
f"esphome_{self._entry_id}_remove_{component_key}_{key}",
entry_data.async_register_key_static_info_remove_callback(
self._static_info,
functools.partial(self.async_remove, force_remove=True),
)
)
@ -201,10 +188,8 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
)
)
self.async_on_remove(
async_dispatcher_connect(
hass,
entry_data.signal_component_key_static_info_updated(component_key, key),
self._on_static_info_update,
entry_data.async_register_key_static_info_updated_callback(
self._static_info, self._on_static_info_update
)
)
self._update_state_from_entry_data()

View File

@ -2,10 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field
import logging
from typing import Any, cast
from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
from aioesphomeapi import (
COMPONENT_TYPE_TO_INFO,
@ -41,6 +41,8 @@ from homeassistant.helpers.storage import Store
from .dashboard import async_get_dashboard
INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()}
_SENTINEL = object()
SAVE_DELAY = 120
_LOGGER = logging.getLogger(__name__)
@ -65,26 +67,31 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], Platform] = {
}
class StoreData(TypedDict, total=False):
"""ESPHome storage data."""
device_info: dict[str, Any]
services: list[dict[str, Any]]
api_version: dict[str, Any]
class ESPHomeStorage(Store[StoreData]):
"""ESPHome Storage."""
@dataclass
class RuntimeEntryData:
"""Store runtime data for esphome config entries."""
entry_id: str
client: APIClient
store: Store
store: ESPHomeStorage
state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict)
# When the disconnect callback is called, we mark all states
# as stale so we will always dispatch a state update when the
# device reconnects. This is the same format as state_subscriptions.
stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set)
info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
# A second list of EntityInfo objects
# This is necessary for when an entity is being removed. HA requires
# some static info to be accessible during removal (unique_id, maybe others)
# If an entity can't find anything in the info array, it will look for info here.
old_info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict)
services: dict[int, UserService] = field(default_factory=dict)
available: bool = False
device_info: DeviceInfo | None = None
@ -96,7 +103,8 @@ class RuntimeEntryData:
] = field(default_factory=dict)
loaded_platforms: set[Platform] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: dict[str, Any] | None = None
_storage_contents: StoreData | None = None
_pending_storage: Callable[[], StoreData] | None = None
ble_connections_free: int = 0
ble_connections_limit: int = 0
_ble_connection_free_futures: list[asyncio.Future[int]] = field(
@ -109,6 +117,12 @@ class RuntimeEntryData:
entity_info_callbacks: dict[
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
] = field(default_factory=dict)
entity_info_key_remove_callbacks: dict[
tuple[type[EntityInfo], int], list[Callable[[], Coroutine[Any, Any, None]]]
] = field(default_factory=dict)
entity_info_key_updated_callbacks: dict[
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict)
@property
@ -133,12 +147,6 @@ class RuntimeEntryData:
"""Return the signal to listen to for updates on static info."""
return f"esphome_{self.entry_id}_on_list"
def signal_component_key_static_info_updated(
self, component_key: str, key: int
) -> str:
"""Return the signal to listen to for updates on static info for a specific component_key and key."""
return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}"
@callback
def async_register_static_info_callback(
self,
@ -154,6 +162,38 @@ class RuntimeEntryData:
return _unsub
@callback
def async_register_key_static_info_remove_callback(
self,
static_info: EntityInfo,
callback_: Callable[[], Coroutine[Any, Any, None]],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info is removed for a specific key."""
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def async_register_key_static_info_updated_callback(
self,
static_info: EntityInfo,
callback_: Callable[[EntityInfo], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info is updated for a specific key."""
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None:
"""Update the BLE connection limits."""
@ -203,13 +243,25 @@ class RuntimeEntryData:
self.assist_pipeline_update_callbacks.append(update_callback)
return _unsubscribe
@callback
def async_remove_entity(
self, hass: HomeAssistant, component_key: str, key: int
) -> None:
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
"""Schedule the removal of an entity."""
signal = f"esphome_{self.entry_id}_remove_{component_key}_{key}"
async_dispatcher_send(hass, signal)
callbacks: list[Coroutine[Any, Any, None]] = []
for static_info in static_infos:
callback_key = (type(static_info), static_info.key)
if key_callbacks := self.entity_info_key_remove_callbacks.get(callback_key):
callbacks.extend([callback_() for callback_ in key_callbacks])
if callbacks:
await asyncio.gather(*callbacks)
@callback
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
"""Call static info updated callbacks."""
for static_info in static_infos:
callback_key = (type(static_info), static_info.key)
for callback_ in self.entity_info_key_updated_callbacks.get(
callback_key, []
):
callback_(static_info)
async def _ensure_platforms_loaded(
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
@ -288,7 +340,7 @@ class RuntimeEntryData:
and subscription_key not in stale_state
and not (
type(state) is SensorState # pylint: disable=unidiomatic-typecheck
and (platform_info := self.info.get(Platform.SENSOR))
and (platform_info := self.info.get(SensorInfo))
and (entity_info := platform_info.get(state.key))
and (cast(SensorInfo, entity_info)).force_update
)
@ -326,47 +378,57 @@ class RuntimeEntryData:
"""Load the retained data from store and return de-serialized data."""
if (restored := await self.store.async_load()) is None:
return [], []
restored = cast("dict[str, Any]", restored)
self._storage_contents = restored.copy()
self.device_info = DeviceInfo.from_dict(restored.pop("device_info"))
self.api_version = APIVersion.from_dict(restored.pop("api_version", {}))
infos = []
infos: list[EntityInfo] = []
for comp_type, restored_infos in restored.items():
if TYPE_CHECKING:
restored_infos = cast(list[dict[str, Any]], restored_infos)
if comp_type not in COMPONENT_TYPE_TO_INFO:
continue
for info in restored_infos:
cls = COMPONENT_TYPE_TO_INFO[comp_type]
infos.append(cls.from_dict(info))
services = []
for service in restored.get("services", []):
services.append(UserService.from_dict(service))
services = [
UserService.from_dict(service) for service in restored.pop("services", [])
]
return infos, services
async def async_save_to_store(self) -> None:
"""Generate dynamic data to store and save it to the filesystem."""
if self.device_info is None:
raise ValueError("device_info is not set yet")
store_data: dict[str, Any] = {
store_data: StoreData = {
"device_info": self.device_info.to_dict(),
"services": [],
"api_version": self.api_version.to_dict(),
}
for comp_type, infos in self.info.items():
store_data[comp_type] = [info.to_dict() for info in infos.values()]
for info_type, infos in self.info.items():
comp_type = INFO_TO_COMPONENT_TYPE[info_type]
store_data[comp_type] = [info.to_dict() for info in infos.values()] # type: ignore[literal-required]
for service in self.services.values():
store_data["services"].append(service.to_dict())
if store_data == self._storage_contents:
return
def _memorized_storage() -> dict[str, Any]:
def _memorized_storage() -> StoreData:
self._pending_storage = None
self._storage_contents = store_data
return store_data
self._pending_storage = _memorized_storage
self.store.async_delay_save(_memorized_storage, SAVE_DELAY)
async def async_cleanup(self) -> None:
"""Cleanup the entry data when disconnected or unloading."""
if self._pending_storage:
# Ensure we save the data if we are unloading before the
# save delay has passed.
await self.store.async_save(self._pending_storage())
async def async_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:

View File

@ -40,7 +40,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="fan",
info_type=FanInfo,
entity_type=EsphomeFan,
state_type=FanState,

View File

@ -48,7 +48,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="light",
info_type=LightInfo,
entity_type=EsphomeLight,
state_type=LightState,

View File

@ -26,7 +26,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="lock",
info_type=LockInfo,
entity_type=EsphomeLock,
state_type=LockEntityState,

View File

@ -43,7 +43,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="media_player",
info_type=MediaPlayerInfo,
entity_type=EsphomeMediaPlayer,
state_type=MediaPlayerEntityState,

View File

@ -34,7 +34,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="number",
info_type=NumberInfo,
entity_type=EsphomeNumber,
state_type=NumberState,

View File

@ -29,7 +29,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="select",
info_type=SelectInfo,
entity_type=EsphomeSelect,
state_type=SelectState,

View File

@ -41,7 +41,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="sensor",
info_type=SensorInfo,
entity_type=EsphomeSensor,
state_type=SensorState,
@ -50,7 +49,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="text_sensor",
info_type=TextSensorInfo,
entity_type=EsphomeTextSensor,
state_type=TextSensorState,

View File

@ -26,7 +26,6 @@ async def async_setup_entry(
hass,
entry,
async_add_entities,
component_key="switch",
info_type=SwitchInfo,
entity_type=EsphomeSwitch,
state_type=SwitchState,

View File

@ -169,19 +169,22 @@ async def _mock_generic_device_entry(
mock_device_info: dict[str, Any],
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
states: list[EntityState],
entry: MockConfigEntry | None = None,
) -> MockESPHomeDevice:
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "test.local",
CONF_PORT: 6053,
CONF_PASSWORD: "",
},
options={
CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS
},
)
entry.add_to_hass(hass)
if not entry:
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "test.local",
CONF_PORT: 6053,
CONF_PASSWORD: "",
},
options={
CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS
},
)
entry.add_to_hass(hass)
mock_device = MockESPHomeDevice(entry)
device_info = DeviceInfo(
@ -290,9 +293,10 @@ async def mock_esphome_device(
entity_info: list[EntityInfo],
user_service: list[UserService],
states: list[EntityState],
entry: MockConfigEntry | None = None,
) -> MockESPHomeDevice:
return await _mock_generic_device_entry(
hass, mock_client, {}, (entity_info, user_service), states
hass, mock_client, {}, (entity_info, user_service), states, entry
)
return _mock_device

View File

@ -0,0 +1,103 @@
"""Test ESPHome binary sensors."""
from collections.abc import Awaitable, Callable
from typing import Any
from aioesphomeapi import (
APIClient,
BinarySensorInfo,
BinarySensorState,
EntityInfo,
EntityState,
UserService,
)
from homeassistant.const import ATTR_RESTORED, STATE_ON
from homeassistant.core import HomeAssistant
from .conftest import MockESPHomeDevice
async def test_entities_removed(
hass: HomeAssistant,
mock_client: APIClient,
hass_storage: dict[str, Any],
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a generic binary_sensor where has_state is false."""
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
BinarySensorInfo(
object_id="mybinary_sensor_to_be_removed",
key=2,
name="my binary_sensor to be removed",
unique_id="mybinary_sensor_to_be_removed",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=2, state=True, missing_state=False),
]
user_service = []
mock_device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
entry = mock_device.entry
entry_id = entry.entry_id
storage_key = f"esphome.{entry_id}"
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is not None
assert state.state == STATE_ON
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
]
mock_device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
entry=entry,
)
assert mock_device.entry.entry_id == entry_id
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is None
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1