diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index d9e5b199748..723141a94a2 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Coroutine, Iterable from dataclasses import dataclass, field +from functools import partial import logging from typing import TYPE_CHECKING, Any, Final, TypedDict, cast @@ -163,11 +164,18 @@ class RuntimeEntryData: """Register to receive callbacks when static info changes for an EntityInfo type.""" callbacks = self.entity_info_callbacks.setdefault(entity_info_type, []) callbacks.append(callback_) + return partial( + self._async_unsubscribe_register_static_info, callbacks, callback_ + ) - def _unsub() -> None: - callbacks.remove(callback_) - - return _unsub + @callback + def _async_unsubscribe_register_static_info( + self, + callbacks: list[Callable[[list[EntityInfo]], None]], + callback_: Callable[[list[EntityInfo]], None], + ) -> None: + """Unsubscribe to when static info is registered.""" + callbacks.remove(callback_) @callback def async_register_key_static_info_remove_callback( @@ -179,11 +187,16 @@ class RuntimeEntryData: callback_key = (type(static_info), static_info.key) callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, []) callbacks.append(callback_) + return partial(self._async_unsubscribe_static_key_remove, callbacks, callback_) - def _unsub() -> None: - callbacks.remove(callback_) - - return _unsub + @callback + def _async_unsubscribe_static_key_remove( + self, + callbacks: list[Callable[[], Coroutine[Any, Any, None]]], + callback_: Callable[[], Coroutine[Any, Any, None]], + ) -> None: + """Unsubscribe to when static info is removed.""" + callbacks.remove(callback_) @callback def async_register_key_static_info_updated_callback( @@ -195,11 +208,18 @@ class RuntimeEntryData: callback_key = (type(static_info), static_info.key) callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, []) callbacks.append(callback_) + return partial( + self._async_unsubscribe_static_key_info_updated, callbacks, callback_ + ) - def _unsub() -> None: - callbacks.remove(callback_) - - return _unsub + @callback + def _async_unsubscribe_static_key_info_updated( + self, + callbacks: list[Callable[[EntityInfo], None]], + callback_: Callable[[EntityInfo], None], + ) -> None: + """Unsubscribe to when static info is updated .""" + callbacks.remove(callback_) @callback def async_set_assist_pipeline_state(self, state: bool) -> None: @@ -208,16 +228,20 @@ class RuntimeEntryData: for update_callback in self.assist_pipeline_update_callbacks: update_callback() + @callback def async_subscribe_assist_pipeline_update( self, update_callback: Callable[[], None] ) -> Callable[[], None]: """Subscribe to assist pipeline updates.""" - - def _unsubscribe() -> None: - self.assist_pipeline_update_callbacks.remove(update_callback) - self.assist_pipeline_update_callbacks.append(update_callback) - return _unsubscribe + return partial(self._async_unsubscribe_assist_pipeline_update, update_callback) + + @callback + def _async_unsubscribe_assist_pipeline_update( + self, update_callback: Callable[[], None] + ) -> None: + """Unsubscribe to assist pipeline updates.""" + self.assist_pipeline_update_callbacks.remove(update_callback) async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None: """Schedule the removal of an entity.""" @@ -232,19 +256,16 @@ class RuntimeEntryData: @callback def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None: """Call static info updated callbacks.""" + callbacks = self.entity_info_key_updated_callbacks for static_info in static_infos: - callback_key = (type(static_info), static_info.key) - for callback_ in self.entity_info_key_updated_callbacks.get( - callback_key, [] - ): + for callback_ in callbacks.get((type(static_info), static_info.key), ()): callback_(static_info) async def _ensure_platforms_loaded( self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform] ) -> None: async with self.platform_load_lock: - needed = platforms - self.loaded_platforms - if needed: + if needed := platforms - self.loaded_platforms: await hass.config_entries.async_forward_entry_setups(entry, needed) self.loaded_platforms |= needed @@ -305,12 +326,16 @@ class RuntimeEntryData: entity_callback: Callable[[], None], ) -> Callable[[], None]: """Subscribe to state updates.""" + subscription_key = (state_type, state_key) + self.state_subscriptions[subscription_key] = entity_callback + return partial(self._async_unsubscribe_state_update, subscription_key) - def _unsubscribe() -> None: - self.state_subscriptions.pop((state_type, state_key)) - - self.state_subscriptions[(state_type, state_key)] = entity_callback - return _unsubscribe + @callback + def _async_unsubscribe_state_update( + self, subscription_key: tuple[type[EntityState], int] + ) -> None: + """Unsubscribe to state updates.""" + self.state_subscriptions.pop(subscription_key) @callback def async_update_state(self, state: EntityState) -> None: