Small cleanups to ESPHome callbacks (#107428)
parent
15ce70606f
commit
901b9365b4
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
|
||||
|
||||
|
@ -163,11 +164,18 @@ class RuntimeEntryData:
|
|||
"""Register to receive callbacks when static info changes for an EntityInfo type."""
|
||||
callbacks = self.entity_info_callbacks.setdefault(entity_info_type, [])
|
||||
callbacks.append(callback_)
|
||||
return partial(
|
||||
self._async_unsubscribe_register_static_info, callbacks, callback_
|
||||
)
|
||||
|
||||
def _unsub() -> None:
|
||||
callbacks.remove(callback_)
|
||||
|
||||
return _unsub
|
||||
@callback
|
||||
def _async_unsubscribe_register_static_info(
|
||||
self,
|
||||
callbacks: list[Callable[[list[EntityInfo]], None]],
|
||||
callback_: Callable[[list[EntityInfo]], None],
|
||||
) -> None:
|
||||
"""Unsubscribe to when static info is registered."""
|
||||
callbacks.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_register_key_static_info_remove_callback(
|
||||
|
@ -179,11 +187,16 @@ class RuntimeEntryData:
|
|||
callback_key = (type(static_info), static_info.key)
|
||||
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
return partial(self._async_unsubscribe_static_key_remove, callbacks, callback_)
|
||||
|
||||
def _unsub() -> None:
|
||||
callbacks.remove(callback_)
|
||||
|
||||
return _unsub
|
||||
@callback
|
||||
def _async_unsubscribe_static_key_remove(
|
||||
self,
|
||||
callbacks: list[Callable[[], Coroutine[Any, Any, None]]],
|
||||
callback_: Callable[[], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Unsubscribe to when static info is removed."""
|
||||
callbacks.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_register_key_static_info_updated_callback(
|
||||
|
@ -195,11 +208,18 @@ class RuntimeEntryData:
|
|||
callback_key = (type(static_info), static_info.key)
|
||||
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
return partial(
|
||||
self._async_unsubscribe_static_key_info_updated, callbacks, callback_
|
||||
)
|
||||
|
||||
def _unsub() -> None:
|
||||
callbacks.remove(callback_)
|
||||
|
||||
return _unsub
|
||||
@callback
|
||||
def _async_unsubscribe_static_key_info_updated(
|
||||
self,
|
||||
callbacks: list[Callable[[EntityInfo], None]],
|
||||
callback_: Callable[[EntityInfo], None],
|
||||
) -> None:
|
||||
"""Unsubscribe to when static info is updated ."""
|
||||
callbacks.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_set_assist_pipeline_state(self, state: bool) -> None:
|
||||
|
@ -208,16 +228,20 @@ class RuntimeEntryData:
|
|||
for update_callback in self.assist_pipeline_update_callbacks:
|
||||
update_callback()
|
||||
|
||||
@callback
|
||||
def async_subscribe_assist_pipeline_update(
|
||||
self, update_callback: Callable[[], None]
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribe to assist pipeline updates."""
|
||||
|
||||
def _unsubscribe() -> None:
|
||||
self.assist_pipeline_update_callbacks.remove(update_callback)
|
||||
|
||||
self.assist_pipeline_update_callbacks.append(update_callback)
|
||||
return _unsubscribe
|
||||
return partial(self._async_unsubscribe_assist_pipeline_update, update_callback)
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe_assist_pipeline_update(
|
||||
self, update_callback: Callable[[], None]
|
||||
) -> None:
|
||||
"""Unsubscribe to assist pipeline updates."""
|
||||
self.assist_pipeline_update_callbacks.remove(update_callback)
|
||||
|
||||
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
|
||||
"""Schedule the removal of an entity."""
|
||||
|
@ -232,19 +256,16 @@ class RuntimeEntryData:
|
|||
@callback
|
||||
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
|
||||
"""Call static info updated callbacks."""
|
||||
callbacks = self.entity_info_key_updated_callbacks
|
||||
for static_info in static_infos:
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
for callback_ in self.entity_info_key_updated_callbacks.get(
|
||||
callback_key, []
|
||||
):
|
||||
for callback_ in callbacks.get((type(static_info), static_info.key), ()):
|
||||
callback_(static_info)
|
||||
|
||||
async def _ensure_platforms_loaded(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
|
||||
) -> None:
|
||||
async with self.platform_load_lock:
|
||||
needed = platforms - self.loaded_platforms
|
||||
if needed:
|
||||
if needed := platforms - self.loaded_platforms:
|
||||
await hass.config_entries.async_forward_entry_setups(entry, needed)
|
||||
self.loaded_platforms |= needed
|
||||
|
||||
|
@ -305,12 +326,16 @@ class RuntimeEntryData:
|
|||
entity_callback: Callable[[], None],
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribe to state updates."""
|
||||
subscription_key = (state_type, state_key)
|
||||
self.state_subscriptions[subscription_key] = entity_callback
|
||||
return partial(self._async_unsubscribe_state_update, subscription_key)
|
||||
|
||||
def _unsubscribe() -> None:
|
||||
self.state_subscriptions.pop((state_type, state_key))
|
||||
|
||||
self.state_subscriptions[(state_type, state_key)] = entity_callback
|
||||
return _unsubscribe
|
||||
@callback
|
||||
def _async_unsubscribe_state_update(
|
||||
self, subscription_key: tuple[type[EntityState], int]
|
||||
) -> None:
|
||||
"""Unsubscribe to state updates."""
|
||||
self.state_subscriptions.pop(subscription_key)
|
||||
|
||||
@callback
|
||||
def async_update_state(self, state: EntityState) -> None:
|
||||
|
|
Loading…
Reference in New Issue