Small cleanups to ESPHome callbacks (#107428)

pull/105955/head
J. Nick Koston 2024-01-07 07:39:33 -10:00 committed by GitHub
parent 15ce70606f
commit 901b9365b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 53 additions and 28 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field
from functools import partial
import logging
from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
@ -163,11 +164,18 @@ class RuntimeEntryData:
"""Register to receive callbacks when static info changes for an EntityInfo type."""
callbacks = self.entity_info_callbacks.setdefault(entity_info_type, [])
callbacks.append(callback_)
return partial(
self._async_unsubscribe_register_static_info, callbacks, callback_
)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def _async_unsubscribe_register_static_info(
self,
callbacks: list[Callable[[list[EntityInfo]], None]],
callback_: Callable[[list[EntityInfo]], None],
) -> None:
"""Unsubscribe to when static info is registered."""
callbacks.remove(callback_)
@callback
def async_register_key_static_info_remove_callback(
@ -179,11 +187,16 @@ class RuntimeEntryData:
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
return partial(self._async_unsubscribe_static_key_remove, callbacks, callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def _async_unsubscribe_static_key_remove(
self,
callbacks: list[Callable[[], Coroutine[Any, Any, None]]],
callback_: Callable[[], Coroutine[Any, Any, None]],
) -> None:
"""Unsubscribe to when static info is removed."""
callbacks.remove(callback_)
@callback
def async_register_key_static_info_updated_callback(
@ -195,11 +208,18 @@ class RuntimeEntryData:
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
return partial(
self._async_unsubscribe_static_key_info_updated, callbacks, callback_
)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def _async_unsubscribe_static_key_info_updated(
self,
callbacks: list[Callable[[EntityInfo], None]],
callback_: Callable[[EntityInfo], None],
) -> None:
"""Unsubscribe to when static info is updated ."""
callbacks.remove(callback_)
@callback
def async_set_assist_pipeline_state(self, state: bool) -> None:
@ -208,16 +228,20 @@ class RuntimeEntryData:
for update_callback in self.assist_pipeline_update_callbacks:
update_callback()
@callback
def async_subscribe_assist_pipeline_update(
self, update_callback: Callable[[], None]
) -> Callable[[], None]:
"""Subscribe to assist pipeline updates."""
def _unsubscribe() -> None:
self.assist_pipeline_update_callbacks.remove(update_callback)
self.assist_pipeline_update_callbacks.append(update_callback)
return _unsubscribe
return partial(self._async_unsubscribe_assist_pipeline_update, update_callback)
@callback
def _async_unsubscribe_assist_pipeline_update(
self, update_callback: Callable[[], None]
) -> None:
"""Unsubscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.remove(update_callback)
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
"""Schedule the removal of an entity."""
@ -232,19 +256,16 @@ class RuntimeEntryData:
@callback
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
"""Call static info updated callbacks."""
callbacks = self.entity_info_key_updated_callbacks
for static_info in static_infos:
callback_key = (type(static_info), static_info.key)
for callback_ in self.entity_info_key_updated_callbacks.get(
callback_key, []
):
for callback_ in callbacks.get((type(static_info), static_info.key), ()):
callback_(static_info)
async def _ensure_platforms_loaded(
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
) -> None:
async with self.platform_load_lock:
needed = platforms - self.loaded_platforms
if needed:
if needed := platforms - self.loaded_platforms:
await hass.config_entries.async_forward_entry_setups(entry, needed)
self.loaded_platforms |= needed
@ -305,12 +326,16 @@ class RuntimeEntryData:
entity_callback: Callable[[], None],
) -> Callable[[], None]:
"""Subscribe to state updates."""
subscription_key = (state_type, state_key)
self.state_subscriptions[subscription_key] = entity_callback
return partial(self._async_unsubscribe_state_update, subscription_key)
def _unsubscribe() -> None:
self.state_subscriptions.pop((state_type, state_key))
self.state_subscriptions[(state_type, state_key)] = entity_callback
return _unsubscribe
@callback
def _async_unsubscribe_state_update(
self, subscription_key: tuple[type[EntityState], int]
) -> None:
"""Unsubscribe to state updates."""
self.state_subscriptions.pop(subscription_key)
@callback
def async_update_state(self, state: EntityState) -> None: