Small cleanups to ESPHome callbacks (#107428)

pull/105955/head
J. Nick Koston 2024-01-07 07:39:33 -10:00 committed by GitHub
parent 15ce70606f
commit 901b9365b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 53 additions and 28 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any, Final, TypedDict, cast from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
@ -163,11 +164,18 @@ class RuntimeEntryData:
"""Register to receive callbacks when static info changes for an EntityInfo type.""" """Register to receive callbacks when static info changes for an EntityInfo type."""
callbacks = self.entity_info_callbacks.setdefault(entity_info_type, []) callbacks = self.entity_info_callbacks.setdefault(entity_info_type, [])
callbacks.append(callback_) callbacks.append(callback_)
return partial(
self._async_unsubscribe_register_static_info, callbacks, callback_
)
def _unsub() -> None: @callback
callbacks.remove(callback_) def _async_unsubscribe_register_static_info(
self,
return _unsub callbacks: list[Callable[[list[EntityInfo]], None]],
callback_: Callable[[list[EntityInfo]], None],
) -> None:
"""Unsubscribe to when static info is registered."""
callbacks.remove(callback_)
@callback @callback
def async_register_key_static_info_remove_callback( def async_register_key_static_info_remove_callback(
@ -179,11 +187,16 @@ class RuntimeEntryData:
callback_key = (type(static_info), static_info.key) callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, []) callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
callbacks.append(callback_) callbacks.append(callback_)
return partial(self._async_unsubscribe_static_key_remove, callbacks, callback_)
def _unsub() -> None: @callback
callbacks.remove(callback_) def _async_unsubscribe_static_key_remove(
self,
return _unsub callbacks: list[Callable[[], Coroutine[Any, Any, None]]],
callback_: Callable[[], Coroutine[Any, Any, None]],
) -> None:
"""Unsubscribe to when static info is removed."""
callbacks.remove(callback_)
@callback @callback
def async_register_key_static_info_updated_callback( def async_register_key_static_info_updated_callback(
@ -195,11 +208,18 @@ class RuntimeEntryData:
callback_key = (type(static_info), static_info.key) callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, []) callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
callbacks.append(callback_) callbacks.append(callback_)
return partial(
self._async_unsubscribe_static_key_info_updated, callbacks, callback_
)
def _unsub() -> None: @callback
callbacks.remove(callback_) def _async_unsubscribe_static_key_info_updated(
self,
return _unsub callbacks: list[Callable[[EntityInfo], None]],
callback_: Callable[[EntityInfo], None],
) -> None:
"""Unsubscribe to when static info is updated ."""
callbacks.remove(callback_)
@callback @callback
def async_set_assist_pipeline_state(self, state: bool) -> None: def async_set_assist_pipeline_state(self, state: bool) -> None:
@ -208,16 +228,20 @@ class RuntimeEntryData:
for update_callback in self.assist_pipeline_update_callbacks: for update_callback in self.assist_pipeline_update_callbacks:
update_callback() update_callback()
@callback
def async_subscribe_assist_pipeline_update( def async_subscribe_assist_pipeline_update(
self, update_callback: Callable[[], None] self, update_callback: Callable[[], None]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Subscribe to assist pipeline updates.""" """Subscribe to assist pipeline updates."""
def _unsubscribe() -> None:
self.assist_pipeline_update_callbacks.remove(update_callback)
self.assist_pipeline_update_callbacks.append(update_callback) self.assist_pipeline_update_callbacks.append(update_callback)
return _unsubscribe return partial(self._async_unsubscribe_assist_pipeline_update, update_callback)
@callback
def _async_unsubscribe_assist_pipeline_update(
self, update_callback: Callable[[], None]
) -> None:
"""Unsubscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.remove(update_callback)
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None: async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
"""Schedule the removal of an entity.""" """Schedule the removal of an entity."""
@ -232,19 +256,16 @@ class RuntimeEntryData:
@callback @callback
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None: def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
"""Call static info updated callbacks.""" """Call static info updated callbacks."""
callbacks = self.entity_info_key_updated_callbacks
for static_info in static_infos: for static_info in static_infos:
callback_key = (type(static_info), static_info.key) for callback_ in callbacks.get((type(static_info), static_info.key), ()):
for callback_ in self.entity_info_key_updated_callbacks.get(
callback_key, []
):
callback_(static_info) callback_(static_info)
async def _ensure_platforms_loaded( async def _ensure_platforms_loaded(
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform] self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
) -> None: ) -> None:
async with self.platform_load_lock: async with self.platform_load_lock:
needed = platforms - self.loaded_platforms if needed := platforms - self.loaded_platforms:
if needed:
await hass.config_entries.async_forward_entry_setups(entry, needed) await hass.config_entries.async_forward_entry_setups(entry, needed)
self.loaded_platforms |= needed self.loaded_platforms |= needed
@ -305,12 +326,16 @@ class RuntimeEntryData:
entity_callback: Callable[[], None], entity_callback: Callable[[], None],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Subscribe to state updates.""" """Subscribe to state updates."""
subscription_key = (state_type, state_key)
self.state_subscriptions[subscription_key] = entity_callback
return partial(self._async_unsubscribe_state_update, subscription_key)
def _unsubscribe() -> None: @callback
self.state_subscriptions.pop((state_type, state_key)) def _async_unsubscribe_state_update(
self, subscription_key: tuple[type[EntityState], int]
self.state_subscriptions[(state_type, state_key)] = entity_callback ) -> None:
return _unsubscribe """Unsubscribe to state updates."""
self.state_subscriptions.pop(subscription_key)
@callback @callback
def async_update_state(self, state: EntityState) -> None: def async_update_state(self, state: EntityState) -> None: