Fix esphome not removing entities when static info changes (#95202)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/95211/head
parent
d700415045
commit
3b7095c63b
|
@ -675,6 +675,7 @@ async def _cleanup_instance(
|
|||
data.disconnect_callbacks = []
|
||||
for cleanup_callback in data.cleanup_callbacks:
|
||||
cleanup_callback()
|
||||
await data.async_cleanup()
|
||||
await data.client.disconnect()
|
||||
return data
|
||||
|
||||
|
|
|
@ -73,7 +73,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="alarm_control_panel",
|
||||
info_type=AlarmControlPanelInfo,
|
||||
entity_type=EsphomeAlarmControlPanel,
|
||||
state_type=AlarmControlPanelEntityState,
|
||||
|
|
|
@ -29,7 +29,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="binary_sensor",
|
||||
info_type=BinarySensorInfo,
|
||||
entity_type=EsphomeBinarySensor,
|
||||
state_type=BinarySensorState,
|
||||
|
|
|
@ -23,7 +23,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="button",
|
||||
info_type=ButtonInfo,
|
||||
entity_type=EsphomeButton,
|
||||
state_type=EntityState,
|
||||
|
|
|
@ -27,7 +27,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="camera",
|
||||
info_type=CameraInfo,
|
||||
entity_type=EsphomeCamera,
|
||||
state_type=CameraState,
|
||||
|
|
|
@ -72,7 +72,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="climate",
|
||||
info_type=ClimateInfo,
|
||||
entity_type=EsphomeClimateEntity,
|
||||
state_type=ClimateState,
|
||||
|
|
|
@ -32,7 +32,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="cover",
|
||||
info_type=CoverInfo,
|
||||
entity_type=EsphomeCover,
|
||||
state_type=CoverState,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Diagnostics support for ESPHome."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.bluetooth import async_scanner_by_source
|
||||
from homeassistant.components.diagnostics import async_redact_data
|
||||
|
@ -28,7 +28,6 @@ async def async_get_config_entry_diagnostics(
|
|||
entry_data = DomainData.get(hass).get_entry_data(config_entry)
|
||||
|
||||
if (storage_data := await entry_data.store.async_load()) is not None:
|
||||
storage_data = cast("dict[str, Any]", storage_data)
|
||||
diag["storage_data"] = storage_data
|
||||
|
||||
if config_entry.unique_id and (
|
||||
|
|
|
@ -12,10 +12,9 @@ from typing_extensions import Self
|
|||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.storage import Store
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entry_data import RuntimeEntryData
|
||||
from .entry_data import ESPHomeStorage, RuntimeEntryData
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
MAX_CACHED_SERVICES = 128
|
||||
|
@ -26,7 +25,7 @@ class DomainData:
|
|||
"""Define a class that stores global esphome data in hass.data[DOMAIN]."""
|
||||
|
||||
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
|
||||
_stores: dict[str, Store] = field(default_factory=dict)
|
||||
_stores: dict[str, ESPHomeStorage] = field(default_factory=dict)
|
||||
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
|
||||
default_factory=lambda: LRU(MAX_CACHED_SERVICES)
|
||||
)
|
||||
|
@ -83,11 +82,13 @@ class DomainData:
|
|||
"""Check whether the given entry is loaded."""
|
||||
return entry.entry_id in self._entry_datas
|
||||
|
||||
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store:
|
||||
def get_or_create_store(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> ESPHomeStorage:
|
||||
"""Get or create a Store instance for the given config entry."""
|
||||
return self._stores.setdefault(
|
||||
entry.entry_id,
|
||||
Store(
|
||||
ESPHomeStorage(
|
||||
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
|
||||
),
|
||||
)
|
||||
|
|
|
@ -27,7 +27,6 @@ import homeassistant.helpers.config_validation as cv
|
|||
import homeassistant.helpers.device_registry as dr
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.entity import DeviceInfo, Entity
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
@ -49,7 +48,6 @@ async def platform_async_setup_entry(
|
|||
entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
*,
|
||||
component_key: str,
|
||||
info_type: type[_InfoT],
|
||||
entity_type: type[_EntityT],
|
||||
state_type: type[_StateT],
|
||||
|
@ -60,42 +58,35 @@ async def platform_async_setup_entry(
|
|||
info and state updates.
|
||||
"""
|
||||
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
|
||||
entry_data.info[component_key] = {}
|
||||
entry_data.old_info[component_key] = {}
|
||||
entry_data.info[info_type] = {}
|
||||
entry_data.state.setdefault(state_type, {})
|
||||
|
||||
@callback
|
||||
def async_list_entities(infos: list[EntityInfo]) -> None:
|
||||
"""Update entities of this platform when entities are listed."""
|
||||
old_infos = entry_data.info[component_key]
|
||||
current_infos = entry_data.info[info_type]
|
||||
new_infos: dict[int, EntityInfo] = {}
|
||||
add_entities: list[_EntityT] = []
|
||||
|
||||
for info in infos:
|
||||
if info.key in old_infos:
|
||||
# Update existing entity
|
||||
old_infos.pop(info.key)
|
||||
else:
|
||||
if not current_infos.pop(info.key, None):
|
||||
# Create new entity
|
||||
entity = entity_type(entry_data, component_key, info, state_type)
|
||||
entity = entity_type(entry_data, info, state_type)
|
||||
add_entities.append(entity)
|
||||
new_infos[info.key] = info
|
||||
|
||||
# Remove old entities
|
||||
for info in old_infos.values():
|
||||
entry_data.async_remove_entity(hass, component_key, info.key)
|
||||
|
||||
# First copy the now-old info into the backup object
|
||||
entry_data.old_info[component_key] = entry_data.info[component_key]
|
||||
# Then update the actual info
|
||||
entry_data.info[component_key] = new_infos
|
||||
|
||||
for key, new_info in new_infos.items():
|
||||
async_dispatcher_send(
|
||||
hass,
|
||||
entry_data.signal_component_key_static_info_updated(component_key, key),
|
||||
new_info,
|
||||
# Anything still in current_infos is now gone
|
||||
if current_infos:
|
||||
hass.async_create_task(
|
||||
entry_data.async_remove_entities(current_infos.values())
|
||||
)
|
||||
|
||||
# Then update the actual info
|
||||
entry_data.info[info_type] = new_infos
|
||||
|
||||
if new_infos:
|
||||
entry_data.async_update_entity_infos(new_infos.values())
|
||||
|
||||
if add_entities:
|
||||
# Add entities to Home Assistant
|
||||
async_add_entities(add_entities)
|
||||
|
@ -154,14 +145,12 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
def __init__(
|
||||
self,
|
||||
entry_data: RuntimeEntryData,
|
||||
component_key: str,
|
||||
entity_info: EntityInfo,
|
||||
state_type: type[_StateT],
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
self._entry_data = entry_data
|
||||
self._on_entry_data_changed()
|
||||
self._component_key = component_key
|
||||
self._key = entity_info.key
|
||||
self._state_type = state_type
|
||||
self._on_static_info_update(entity_info)
|
||||
|
@ -178,13 +167,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
"""Register callbacks."""
|
||||
entry_data = self._entry_data
|
||||
hass = self.hass
|
||||
component_key = self._component_key
|
||||
key = self._key
|
||||
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
hass,
|
||||
f"esphome_{self._entry_id}_remove_{component_key}_{key}",
|
||||
entry_data.async_register_key_static_info_remove_callback(
|
||||
self._static_info,
|
||||
functools.partial(self.async_remove, force_remove=True),
|
||||
)
|
||||
)
|
||||
|
@ -201,10 +188,8 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
hass,
|
||||
entry_data.signal_component_key_static_info_updated(component_key, key),
|
||||
self._on_static_info_update,
|
||||
entry_data.async_register_key_static_info_updated_callback(
|
||||
self._static_info, self._on_static_info_update
|
||||
)
|
||||
)
|
||||
self._update_state_from_entry_data()
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
|
||||
|
||||
from aioesphomeapi import (
|
||||
COMPONENT_TYPE_TO_INFO,
|
||||
|
@ -41,6 +41,8 @@ from homeassistant.helpers.storage import Store
|
|||
|
||||
from .dashboard import async_get_dashboard
|
||||
|
||||
INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()}
|
||||
|
||||
_SENTINEL = object()
|
||||
SAVE_DELAY = 120
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -65,26 +67,31 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], Platform] = {
|
|||
}
|
||||
|
||||
|
||||
class StoreData(TypedDict, total=False):
|
||||
"""ESPHome storage data."""
|
||||
|
||||
device_info: dict[str, Any]
|
||||
services: list[dict[str, Any]]
|
||||
api_version: dict[str, Any]
|
||||
|
||||
|
||||
class ESPHomeStorage(Store[StoreData]):
|
||||
"""ESPHome Storage."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeEntryData:
|
||||
"""Store runtime data for esphome config entries."""
|
||||
|
||||
entry_id: str
|
||||
client: APIClient
|
||||
store: Store
|
||||
store: ESPHomeStorage
|
||||
state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict)
|
||||
# When the disconnect callback is called, we mark all states
|
||||
# as stale so we will always dispatch a state update when the
|
||||
# device reconnects. This is the same format as state_subscriptions.
|
||||
stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set)
|
||||
info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
|
||||
|
||||
# A second list of EntityInfo objects
|
||||
# This is necessary for when an entity is being removed. HA requires
|
||||
# some static info to be accessible during removal (unique_id, maybe others)
|
||||
# If an entity can't find anything in the info array, it will look for info here.
|
||||
old_info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
|
||||
|
||||
info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict)
|
||||
services: dict[int, UserService] = field(default_factory=dict)
|
||||
available: bool = False
|
||||
device_info: DeviceInfo | None = None
|
||||
|
@ -96,7 +103,8 @@ class RuntimeEntryData:
|
|||
] = field(default_factory=dict)
|
||||
loaded_platforms: set[Platform] = field(default_factory=set)
|
||||
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_storage_contents: dict[str, Any] | None = None
|
||||
_storage_contents: StoreData | None = None
|
||||
_pending_storage: Callable[[], StoreData] | None = None
|
||||
ble_connections_free: int = 0
|
||||
ble_connections_limit: int = 0
|
||||
_ble_connection_free_futures: list[asyncio.Future[int]] = field(
|
||||
|
@ -109,6 +117,12 @@ class RuntimeEntryData:
|
|||
entity_info_callbacks: dict[
|
||||
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
|
||||
] = field(default_factory=dict)
|
||||
entity_info_key_remove_callbacks: dict[
|
||||
tuple[type[EntityInfo], int], list[Callable[[], Coroutine[Any, Any, None]]]
|
||||
] = field(default_factory=dict)
|
||||
entity_info_key_updated_callbacks: dict[
|
||||
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
|
||||
] = field(default_factory=dict)
|
||||
original_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
|
@ -133,12 +147,6 @@ class RuntimeEntryData:
|
|||
"""Return the signal to listen to for updates on static info."""
|
||||
return f"esphome_{self.entry_id}_on_list"
|
||||
|
||||
def signal_component_key_static_info_updated(
|
||||
self, component_key: str, key: int
|
||||
) -> str:
|
||||
"""Return the signal to listen to for updates on static info for a specific component_key and key."""
|
||||
return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}"
|
||||
|
||||
@callback
|
||||
def async_register_static_info_callback(
|
||||
self,
|
||||
|
@ -154,6 +162,38 @@ class RuntimeEntryData:
|
|||
|
||||
return _unsub
|
||||
|
||||
@callback
|
||||
def async_register_key_static_info_remove_callback(
|
||||
self,
|
||||
static_info: EntityInfo,
|
||||
callback_: Callable[[], Coroutine[Any, Any, None]],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when static info is removed for a specific key."""
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
|
||||
def _unsub() -> None:
|
||||
callbacks.remove(callback_)
|
||||
|
||||
return _unsub
|
||||
|
||||
@callback
|
||||
def async_register_key_static_info_updated_callback(
|
||||
self,
|
||||
static_info: EntityInfo,
|
||||
callback_: Callable[[EntityInfo], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when static info is updated for a specific key."""
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
|
||||
def _unsub() -> None:
|
||||
callbacks.remove(callback_)
|
||||
|
||||
return _unsub
|
||||
|
||||
@callback
|
||||
def async_update_ble_connection_limits(self, free: int, limit: int) -> None:
|
||||
"""Update the BLE connection limits."""
|
||||
|
@ -203,13 +243,25 @@ class RuntimeEntryData:
|
|||
self.assist_pipeline_update_callbacks.append(update_callback)
|
||||
return _unsubscribe
|
||||
|
||||
@callback
|
||||
def async_remove_entity(
|
||||
self, hass: HomeAssistant, component_key: str, key: int
|
||||
) -> None:
|
||||
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
|
||||
"""Schedule the removal of an entity."""
|
||||
signal = f"esphome_{self.entry_id}_remove_{component_key}_{key}"
|
||||
async_dispatcher_send(hass, signal)
|
||||
callbacks: list[Coroutine[Any, Any, None]] = []
|
||||
for static_info in static_infos:
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
if key_callbacks := self.entity_info_key_remove_callbacks.get(callback_key):
|
||||
callbacks.extend([callback_() for callback_ in key_callbacks])
|
||||
if callbacks:
|
||||
await asyncio.gather(*callbacks)
|
||||
|
||||
@callback
|
||||
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
|
||||
"""Call static info updated callbacks."""
|
||||
for static_info in static_infos:
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
for callback_ in self.entity_info_key_updated_callbacks.get(
|
||||
callback_key, []
|
||||
):
|
||||
callback_(static_info)
|
||||
|
||||
async def _ensure_platforms_loaded(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
|
||||
|
@ -288,7 +340,7 @@ class RuntimeEntryData:
|
|||
and subscription_key not in stale_state
|
||||
and not (
|
||||
type(state) is SensorState # pylint: disable=unidiomatic-typecheck
|
||||
and (platform_info := self.info.get(Platform.SENSOR))
|
||||
and (platform_info := self.info.get(SensorInfo))
|
||||
and (entity_info := platform_info.get(state.key))
|
||||
and (cast(SensorInfo, entity_info)).force_update
|
||||
)
|
||||
|
@ -326,47 +378,57 @@ class RuntimeEntryData:
|
|||
"""Load the retained data from store and return de-serialized data."""
|
||||
if (restored := await self.store.async_load()) is None:
|
||||
return [], []
|
||||
restored = cast("dict[str, Any]", restored)
|
||||
self._storage_contents = restored.copy()
|
||||
|
||||
self.device_info = DeviceInfo.from_dict(restored.pop("device_info"))
|
||||
self.api_version = APIVersion.from_dict(restored.pop("api_version", {}))
|
||||
infos = []
|
||||
infos: list[EntityInfo] = []
|
||||
for comp_type, restored_infos in restored.items():
|
||||
if TYPE_CHECKING:
|
||||
restored_infos = cast(list[dict[str, Any]], restored_infos)
|
||||
if comp_type not in COMPONENT_TYPE_TO_INFO:
|
||||
continue
|
||||
for info in restored_infos:
|
||||
cls = COMPONENT_TYPE_TO_INFO[comp_type]
|
||||
infos.append(cls.from_dict(info))
|
||||
services = []
|
||||
for service in restored.get("services", []):
|
||||
services.append(UserService.from_dict(service))
|
||||
services = [
|
||||
UserService.from_dict(service) for service in restored.pop("services", [])
|
||||
]
|
||||
return infos, services
|
||||
|
||||
async def async_save_to_store(self) -> None:
|
||||
"""Generate dynamic data to store and save it to the filesystem."""
|
||||
if self.device_info is None:
|
||||
raise ValueError("device_info is not set yet")
|
||||
store_data: dict[str, Any] = {
|
||||
store_data: StoreData = {
|
||||
"device_info": self.device_info.to_dict(),
|
||||
"services": [],
|
||||
"api_version": self.api_version.to_dict(),
|
||||
}
|
||||
|
||||
for comp_type, infos in self.info.items():
|
||||
store_data[comp_type] = [info.to_dict() for info in infos.values()]
|
||||
for info_type, infos in self.info.items():
|
||||
comp_type = INFO_TO_COMPONENT_TYPE[info_type]
|
||||
store_data[comp_type] = [info.to_dict() for info in infos.values()] # type: ignore[literal-required]
|
||||
for service in self.services.values():
|
||||
store_data["services"].append(service.to_dict())
|
||||
|
||||
if store_data == self._storage_contents:
|
||||
return
|
||||
|
||||
def _memorized_storage() -> dict[str, Any]:
|
||||
def _memorized_storage() -> StoreData:
|
||||
self._pending_storage = None
|
||||
self._storage_contents = store_data
|
||||
return store_data
|
||||
|
||||
self._pending_storage = _memorized_storage
|
||||
self.store.async_delay_save(_memorized_storage, SAVE_DELAY)
|
||||
|
||||
async def async_cleanup(self) -> None:
|
||||
"""Cleanup the entry data when disconnected or unloading."""
|
||||
if self._pending_storage:
|
||||
# Ensure we save the data if we are unloading before the
|
||||
# save delay has passed.
|
||||
await self.store.async_save(self._pending_storage())
|
||||
|
||||
async def async_update_listener(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> None:
|
||||
|
|
|
@ -40,7 +40,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="fan",
|
||||
info_type=FanInfo,
|
||||
entity_type=EsphomeFan,
|
||||
state_type=FanState,
|
||||
|
|
|
@ -48,7 +48,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="light",
|
||||
info_type=LightInfo,
|
||||
entity_type=EsphomeLight,
|
||||
state_type=LightState,
|
||||
|
|
|
@ -26,7 +26,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="lock",
|
||||
info_type=LockInfo,
|
||||
entity_type=EsphomeLock,
|
||||
state_type=LockEntityState,
|
||||
|
|
|
@ -43,7 +43,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="media_player",
|
||||
info_type=MediaPlayerInfo,
|
||||
entity_type=EsphomeMediaPlayer,
|
||||
state_type=MediaPlayerEntityState,
|
||||
|
|
|
@ -34,7 +34,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="number",
|
||||
info_type=NumberInfo,
|
||||
entity_type=EsphomeNumber,
|
||||
state_type=NumberState,
|
||||
|
|
|
@ -29,7 +29,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="select",
|
||||
info_type=SelectInfo,
|
||||
entity_type=EsphomeSelect,
|
||||
state_type=SelectState,
|
||||
|
|
|
@ -41,7 +41,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="sensor",
|
||||
info_type=SensorInfo,
|
||||
entity_type=EsphomeSensor,
|
||||
state_type=SensorState,
|
||||
|
@ -50,7 +49,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="text_sensor",
|
||||
info_type=TextSensorInfo,
|
||||
entity_type=EsphomeTextSensor,
|
||||
state_type=TextSensorState,
|
||||
|
|
|
@ -26,7 +26,6 @@ async def async_setup_entry(
|
|||
hass,
|
||||
entry,
|
||||
async_add_entities,
|
||||
component_key="switch",
|
||||
info_type=SwitchInfo,
|
||||
entity_type=EsphomeSwitch,
|
||||
state_type=SwitchState,
|
||||
|
|
|
@ -169,19 +169,22 @@ async def _mock_generic_device_entry(
|
|||
mock_device_info: dict[str, Any],
|
||||
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
|
||||
states: list[EntityState],
|
||||
entry: MockConfigEntry | None = None,
|
||||
) -> MockESPHomeDevice:
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
options={
|
||||
CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
if not entry:
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
options={
|
||||
CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
mock_device = MockESPHomeDevice(entry)
|
||||
|
||||
device_info = DeviceInfo(
|
||||
|
@ -290,9 +293,10 @@ async def mock_esphome_device(
|
|||
entity_info: list[EntityInfo],
|
||||
user_service: list[UserService],
|
||||
states: list[EntityState],
|
||||
entry: MockConfigEntry | None = None,
|
||||
) -> MockESPHomeDevice:
|
||||
return await _mock_generic_device_entry(
|
||||
hass, mock_client, {}, (entity_info, user_service), states
|
||||
hass, mock_client, {}, (entity_info, user_service), states, entry
|
||||
)
|
||||
|
||||
return _mock_device
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
"""Test ESPHome binary sensors."""
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
BinarySensorInfo,
|
||||
BinarySensorState,
|
||||
EntityInfo,
|
||||
EntityState,
|
||||
UserService,
|
||||
)
|
||||
|
||||
from homeassistant.const import ATTR_RESTORED, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
|
||||
async def test_entities_removed(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
hass_storage: dict[str, Any],
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test a generic binary_sensor where has_state is false."""
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor",
|
||||
key=1,
|
||||
name="my binary_sensor",
|
||||
unique_id="my_binary_sensor",
|
||||
),
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor_to_be_removed",
|
||||
key=2,
|
||||
name="my binary_sensor to be removed",
|
||||
unique_id="mybinary_sensor_to_be_removed",
|
||||
),
|
||||
]
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
BinarySensorState(key=2, state=True, missing_state=False),
|
||||
]
|
||||
user_service = []
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=entity_info,
|
||||
user_service=user_service,
|
||||
states=states,
|
||||
)
|
||||
entry = mock_device.entry
|
||||
entry_id = entry.entry_id
|
||||
storage_key = f"esphome.{entry_id}"
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
|
||||
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor")
|
||||
assert state is not None
|
||||
assert state.attributes[ATTR_RESTORED] is True
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
assert state.attributes[ATTR_RESTORED] is True
|
||||
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor",
|
||||
key=1,
|
||||
name="my binary_sensor",
|
||||
unique_id="my_binary_sensor",
|
||||
),
|
||||
]
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
]
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=entity_info,
|
||||
user_service=user_service,
|
||||
states=states,
|
||||
entry=entry,
|
||||
)
|
||||
assert mock_device.entry.entry_id == entry_id
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
|
||||
assert state is None
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
|
Loading…
Reference in New Issue