From 8910d265d6cf15fed4e6e98b4344031019c1016d Mon Sep 17 00:00:00 2001 From: Joakim Plate Date: Fri, 3 Jun 2022 13:55:57 +0200 Subject: [PATCH] Keep track of a context for each listener (#72702) * Remove async_remove_listener This avoids the ambuigity as to what happens if same callback is added multiple times. * Keep track of a context for each listener This allow a update coordinator to adapt what data to request on update from the backing service based on which entities are enabled. * Clone list before calling callbacks The callbacks can end up unregistering and modifying the dict while iterating. * Only yield actual values * Add a test for update context * Factor out iteration of _listeners to helper * Verify context is passed to coordinator * Switch to Any as type instead of object * Remove function which use was dropped earliers The use was removed in 8bee25c938a123f0da7569b4e2753598d478b900 --- .../components/bmw_connected_drive/button.py | 2 +- .../bmw_connected_drive/coordinator.py | 5 -- .../components/modern_forms/__init__.py | 9 +-- .../components/moehlenhoff_alpha2/__init__.py | 10 ++- .../components/philips_js/__init__.py | 26 ++------ .../components/system_bridge/coordinator.py | 19 ++---- homeassistant/components/toon/coordinator.py | 7 +- homeassistant/components/toon/helpers.py | 4 +- homeassistant/components/wemo/wemo_device.py | 6 -- homeassistant/components/wled/coordinator.py | 7 +- homeassistant/components/wled/helpers.py | 4 +- .../yamaha_musiccast/media_player.py | 5 +- homeassistant/helpers/update_coordinator.py | 65 +++++++++++-------- tests/helpers/test_update_coordinator.py | 41 +++++++++--- 14 files changed, 95 insertions(+), 115 deletions(-) diff --git a/homeassistant/components/bmw_connected_drive/button.py b/homeassistant/components/bmw_connected_drive/button.py index 9cec9a73ce7..baa7870ee8c 100644 --- a/homeassistant/components/bmw_connected_drive/button.py +++ b/homeassistant/components/bmw_connected_drive/button.py @@ -131,4 +131,4 @@ class BMWButton(BMWBaseEntity, ButtonEntity): # Always update HA states after a button was executed. # BMW remote services that change the vehicle's state update the local object # when executing the service, so only the HA state machine needs further updates. - self.coordinator.notify_listeners() + self.coordinator.async_update_listeners() diff --git a/homeassistant/components/bmw_connected_drive/coordinator.py b/homeassistant/components/bmw_connected_drive/coordinator.py index 47d1f358686..1443a3e1e29 100644 --- a/homeassistant/components/bmw_connected_drive/coordinator.py +++ b/homeassistant/components/bmw_connected_drive/coordinator.py @@ -74,8 +74,3 @@ class BMWDataUpdateCoordinator(DataUpdateCoordinator): if not refresh_token: data.pop(CONF_REFRESH_TOKEN) self.hass.config_entries.async_update_entry(self._entry, data=data) - - def notify_listeners(self) -> None: - """Notify all listeners to refresh HA state machine.""" - for update_callback in self._listeners: - update_callback() diff --git a/homeassistant/components/modern_forms/__init__.py b/homeassistant/components/modern_forms/__init__.py index af4f05a1536..ed4212d9444 100644 --- a/homeassistant/components/modern_forms/__init__.py +++ b/homeassistant/components/modern_forms/__init__.py @@ -74,12 +74,12 @@ def modernforms_exception_handler(func): async def handler(self, *args, **kwargs): try: await func(self, *args, **kwargs) - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() except ModernFormsConnectionError as error: _LOGGER.error("Error communicating with API: %s", error) self.coordinator.last_update_success = False - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() except ModernFormsError as error: _LOGGER.error("Invalid response from API: %s", error) @@ -108,11 +108,6 @@ class ModernFormsDataUpdateCoordinator(DataUpdateCoordinator[ModernFormsDeviceSt update_interval=SCAN_INTERVAL, ) - def update_listeners(self) -> None: - """Call update on all listeners.""" - for update_callback in self._listeners: - update_callback() - async def _async_update_data(self) -> ModernFormsDevice: """Fetch data from Modern Forms.""" try: diff --git a/homeassistant/components/moehlenhoff_alpha2/__init__.py b/homeassistant/components/moehlenhoff_alpha2/__init__.py index 86306a56033..64bdfeb4e6d 100644 --- a/homeassistant/components/moehlenhoff_alpha2/__init__.py +++ b/homeassistant/components/moehlenhoff_alpha2/__init__.py @@ -83,8 +83,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]): async def async_set_cooling(self, enabled: bool) -> None: """Enable or disable cooling mode.""" await self.base.set_cooling(enabled) - for update_callback in self._listeners: - update_callback() + self.async_update_listeners() async def async_set_target_temperature( self, heat_area_id: str, target_temperature: float @@ -117,8 +116,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]): "Failed to set target temperature, communication error with alpha2 base" ) from http_err self.data["heat_areas"][heat_area_id].update(update_data) - for update_callback in self._listeners: - update_callback() + self.async_update_listeners() async def async_set_heat_area_mode( self, heat_area_id: str, heat_area_mode: int @@ -161,5 +159,5 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]): self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[ "heat_areas" ][heat_area_id]["T_HEAT_NIGHT"] - for update_callback in self._listeners: - update_callback() + + self.async_update_listeners() diff --git a/homeassistant/components/philips_js/__init__.py b/homeassistant/components/philips_js/__init__.py index 9a317726768..29e92a6ffe3 100644 --- a/homeassistant/components/philips_js/__init__.py +++ b/homeassistant/components/philips_js/__init__.py @@ -19,14 +19,7 @@ from homeassistant.const import ( CONF_USERNAME, Platform, ) -from homeassistant.core import ( - CALLBACK_TYPE, - Context, - Event, - HassJob, - HomeAssistant, - callback, -) +from homeassistant.core import Context, HassJob, HomeAssistant, callback from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.update_coordinator import DataUpdateCoordinator @@ -121,12 +114,7 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]): self.options = options self._notify_future: asyncio.Task | None = None - @callback - def _update_listeners(): - for update_callback in self._listeners: - update_callback() - - self.turn_on = PluggableAction(_update_listeners) + self.turn_on = PluggableAction(self.async_update_listeners) super().__init__( hass, @@ -193,15 +181,9 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]): self._notify_future = asyncio.create_task(self._notify_task()) @callback - def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None: + def _unschedule_refresh(self) -> None: """Remove data update.""" - super().async_remove_listener(update_callback) - if not self._listeners: - self._async_notify_stop() - - @callback - def _async_stop_refresh(self, event: Event) -> None: - super()._async_stop_refresh(event) + super()._unschedule_refresh() self._async_notify_stop() @callback diff --git a/homeassistant/components/system_bridge/coordinator.py b/homeassistant/components/system_bridge/coordinator.py index 89a0c85c1d9..a7343116cde 100644 --- a/homeassistant/components/system_bridge/coordinator.py +++ b/homeassistant/components/system_bridge/coordinator.py @@ -75,11 +75,6 @@ class SystemBridgeDataUpdateCoordinator( hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30) ) - def update_listeners(self) -> None: - """Call update on all listeners.""" - for update_callback in self._listeners: - update_callback() - async def async_get_data( self, modules: list[str], @@ -113,7 +108,7 @@ class SystemBridgeDataUpdateCoordinator( self.unsub() self.unsub = None self.last_update_success = False - self.update_listeners() + self.async_update_listeners() except (ConnectionClosedException, ConnectionResetError) as exception: self.logger.info( "Websocket connection closed for %s. Will retry: %s", @@ -124,7 +119,7 @@ class SystemBridgeDataUpdateCoordinator( self.unsub() self.unsub = None self.last_update_success = False - self.update_listeners() + self.async_update_listeners() except ConnectionErrorException as exception: self.logger.warning( "Connection error occurred for %s. Will retry: %s", @@ -135,7 +130,7 @@ class SystemBridgeDataUpdateCoordinator( self.unsub() self.unsub = None self.last_update_success = False - self.update_listeners() + self.async_update_listeners() async def _setup_websocket(self) -> None: """Use WebSocket for updates.""" @@ -151,7 +146,7 @@ class SystemBridgeDataUpdateCoordinator( self.unsub() self.unsub = None self.last_update_success = False - self.update_listeners() + self.async_update_listeners() except ConnectionErrorException as exception: self.logger.warning( "Connection error occurred for %s. Will retry: %s", @@ -159,7 +154,7 @@ class SystemBridgeDataUpdateCoordinator( exception, ) self.last_update_success = False - self.update_listeners() + self.async_update_listeners() except asyncio.TimeoutError as exception: self.logger.warning( "Timed out waiting for %s. Will retry: %s", @@ -167,11 +162,11 @@ class SystemBridgeDataUpdateCoordinator( exception, ) self.last_update_success = False - self.update_listeners() + self.async_update_listeners() self.hass.async_create_task(self._listen_for_data()) self.last_update_success = True - self.update_listeners() + self.async_update_listeners() async def close_websocket(_) -> None: """Close WebSocket connection.""" diff --git a/homeassistant/components/toon/coordinator.py b/homeassistant/components/toon/coordinator.py index 81c09931fbd..5819ff12743 100644 --- a/homeassistant/components/toon/coordinator.py +++ b/homeassistant/components/toon/coordinator.py @@ -47,11 +47,6 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]): hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL ) - def update_listeners(self) -> None: - """Call update on all listeners.""" - for update_callback in self._listeners: - update_callback() - async def register_webhook(self, event: Event | None = None) -> None: """Register a webhook with Toon to get live updates.""" if CONF_WEBHOOK_ID not in self.entry.data: @@ -128,7 +123,7 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]): try: await self.toon.update(data["updateDataSet"]) - self.update_listeners() + self.async_update_listeners() except ToonError as err: _LOGGER.error("Could not process data received from Toon webhook - %s", err) diff --git a/homeassistant/components/toon/helpers.py b/homeassistant/components/toon/helpers.py index 405ecc36d7f..4fb4daede65 100644 --- a/homeassistant/components/toon/helpers.py +++ b/homeassistant/components/toon/helpers.py @@ -16,12 +16,12 @@ def toon_exception_handler(func): async def handler(self, *args, **kwargs): try: await func(self, *args, **kwargs) - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() except ToonConnectionError as error: _LOGGER.error("Error communicating with API: %s", error) self.coordinator.last_update_success = False - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() except ToonError as error: _LOGGER.error("Invalid response from API: %s", error) diff --git a/homeassistant/components/wemo/wemo_device.py b/homeassistant/components/wemo/wemo_device.py index 8f5e6864059..1f3e07881c8 100644 --- a/homeassistant/components/wemo/wemo_device.py +++ b/homeassistant/components/wemo/wemo_device.py @@ -123,12 +123,6 @@ class DeviceCoordinator(DataUpdateCoordinator): except ActionException as err: raise UpdateFailed("WeMo update failed") from err - @callback - def async_update_listeners(self) -> None: - """Update all listeners.""" - for update_callback in self._listeners: - update_callback() - def _device_info(wemo: WeMoDevice) -> DeviceInfo: return DeviceInfo( diff --git a/homeassistant/components/wled/coordinator.py b/homeassistant/components/wled/coordinator.py index a4cbaade8ba..81017779fbb 100644 --- a/homeassistant/components/wled/coordinator.py +++ b/homeassistant/components/wled/coordinator.py @@ -54,11 +54,6 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]): self.data is not None and len(self.data.state.segments) > 1 ) - def update_listeners(self) -> None: - """Call update on all listeners.""" - for update_callback in self._listeners: - update_callback() - @callback def _use_websocket(self) -> None: """Use WebSocket for updates, instead of polling.""" @@ -81,7 +76,7 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]): self.logger.info(err) except WLEDError as err: self.last_update_success = False - self.update_listeners() + self.async_update_listeners() self.logger.error(err) # Ensure we are disconnected diff --git a/homeassistant/components/wled/helpers.py b/homeassistant/components/wled/helpers.py index 66cd8b13b42..77e288bb34d 100644 --- a/homeassistant/components/wled/helpers.py +++ b/homeassistant/components/wled/helpers.py @@ -15,11 +15,11 @@ def wled_exception_handler(func): async def handler(self, *args, **kwargs): try: await func(self, *args, **kwargs) - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() except WLEDConnectionError as error: self.coordinator.last_update_success = False - self.coordinator.update_listeners() + self.coordinator.async_update_listeners() raise HomeAssistantError("Error communicating with WLED API") from error except WLEDError as error: diff --git a/homeassistant/components/yamaha_musiccast/media_player.py b/homeassistant/components/yamaha_musiccast/media_player.py index 954942b2c6b..cee6253531b 100644 --- a/homeassistant/components/yamaha_musiccast/media_player.py +++ b/homeassistant/components/yamaha_musiccast/media_player.py @@ -106,7 +106,9 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity): self.coordinator.musiccast.register_group_update_callback( self.update_all_mc_entities ) - self.coordinator.async_add_listener(self.async_schedule_check_client_list) + self.async_on_remove( + self.coordinator.async_add_listener(self.async_schedule_check_client_list) + ) async def async_will_remove_from_hass(self): """Entity being removed from hass.""" @@ -116,7 +118,6 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity): self.coordinator.musiccast.remove_group_update_callback( self.update_all_mc_entities ) - self.coordinator.async_remove_listener(self.async_schedule_check_client_list) @property def should_poll(self): diff --git a/homeassistant/helpers/update_coordinator.py b/homeassistant/helpers/update_coordinator.py index f7ad8e013cb..f671e1b973a 100644 --- a/homeassistant/helpers/update_coordinator.py +++ b/homeassistant/helpers/update_coordinator.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Generator from datetime import datetime, timedelta import logging from time import monotonic @@ -13,7 +13,7 @@ import aiohttp import requests from homeassistant import config_entries -from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.util.dt import utcnow @@ -61,7 +61,7 @@ class DataUpdateCoordinator(Generic[_T]): # when it was already checked during setup. self.data: _T = None # type: ignore[assignment] - self._listeners: list[CALLBACK_TYPE] = [] + self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {} self._job = HassJob(self._handle_refresh_interval) self._unsub_refresh: CALLBACK_TYPE | None = None self._request_refresh_task: asyncio.TimerHandle | None = None @@ -82,32 +82,46 @@ class DataUpdateCoordinator(Generic[_T]): self._debounced_refresh = request_refresh_debouncer @callback - def async_add_listener(self, update_callback: CALLBACK_TYPE) -> Callable[[], None]: + def async_add_listener( + self, update_callback: CALLBACK_TYPE, context: Any = None + ) -> Callable[[], None]: """Listen for data updates.""" schedule_refresh = not self._listeners - self._listeners.append(update_callback) + @callback + def remove_listener() -> None: + """Remove update listener.""" + self._listeners.pop(remove_listener) + if not self._listeners: + self._unschedule_refresh() + + self._listeners[remove_listener] = (update_callback, context) # This is the first listener, set up interval. if schedule_refresh: self._schedule_refresh() - @callback - def remove_listener() -> None: - """Remove update listener.""" - self.async_remove_listener(update_callback) - return remove_listener @callback - def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None: - """Remove data update.""" - self._listeners.remove(update_callback) + def async_update_listeners(self) -> None: + """Update all registered listeners.""" + for update_callback, _ in list(self._listeners.values()): + update_callback() - if not self._listeners and self._unsub_refresh: + @callback + def _unschedule_refresh(self) -> None: + """Unschedule any pending refresh since there is no longer any listeners.""" + if self._unsub_refresh: self._unsub_refresh() self._unsub_refresh = None + def async_contexts(self) -> Generator[Any, None, None]: + """Return all registered contexts.""" + yield from ( + context for _, context in self._listeners.values() if context is not None + ) + @callback def _schedule_refresh(self) -> None: """Schedule a refresh.""" @@ -266,8 +280,7 @@ class DataUpdateCoordinator(Generic[_T]): if not auth_failed and self._listeners and not self.hass.is_stopping: self._schedule_refresh() - for update_callback in self._listeners: - update_callback() + self.async_update_listeners() @callback def async_set_updated_data(self, data: _T) -> None: @@ -288,24 +301,18 @@ class DataUpdateCoordinator(Generic[_T]): if self._listeners: self._schedule_refresh() - for update_callback in self._listeners: - update_callback() - - @callback - def _async_stop_refresh(self, _: Event) -> None: - """Stop refreshing when Home Assistant is stopping.""" - self.update_interval = None - if self._unsub_refresh: - self._unsub_refresh() - self._unsub_refresh = None + self.async_update_listeners() class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]): """A class for entities using DataUpdateCoordinator.""" - def __init__(self, coordinator: _DataUpdateCoordinatorT) -> None: + def __init__( + self, coordinator: _DataUpdateCoordinatorT, context: Any = None + ) -> None: """Create the entity with a DataUpdateCoordinator.""" self.coordinator = coordinator + self.coordinator_context = context @property def should_poll(self) -> bool: @@ -321,7 +328,9 @@ class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]): """When entity is added to hass.""" await super().async_added_to_hass() self.async_on_remove( - self.coordinator.async_add_listener(self._handle_coordinator_update) + self.coordinator.async_add_listener( + self._handle_coordinator_update, self.coordinator_context + ) ) @callback diff --git a/tests/helpers/test_update_coordinator.py b/tests/helpers/test_update_coordinator.py index 7023798f2b4..0d0970a4756 100644 --- a/tests/helpers/test_update_coordinator.py +++ b/tests/helpers/test_update_coordinator.py @@ -109,11 +109,29 @@ async def test_async_refresh(crd): await crd.async_refresh() assert updates == [2] - # Test unsubscribing through method - crd.async_add_listener(update_callback) - crd.async_remove_listener(update_callback) + +async def test_update_context(crd: update_coordinator.DataUpdateCoordinator[int]): + """Test update contexts for the update coordinator.""" await crd.async_refresh() - assert updates == [2] + assert not set(crd.async_contexts()) + + def update_callback1(): + pass + + def update_callback2(): + pass + + unsub1 = crd.async_add_listener(update_callback1, 1) + assert set(crd.async_contexts()) == {1} + + unsub2 = crd.async_add_listener(update_callback2, 2) + assert set(crd.async_contexts()) == {1, 2} + + unsub1() + assert set(crd.async_contexts()) == {2} + + unsub2() + assert not set(crd.async_contexts()) async def test_request_refresh(crd): @@ -191,7 +209,7 @@ async def test_update_interval(hass, crd): # Add subscriber update_callback = Mock() - crd.async_add_listener(update_callback) + unsub = crd.async_add_listener(update_callback) # Test twice we update with subscriber async_fire_time_changed(hass, utcnow() + crd.update_interval) @@ -203,7 +221,7 @@ async def test_update_interval(hass, crd): assert crd.data == 2 # Test removing listener - crd.async_remove_listener(update_callback) + unsub() async_fire_time_changed(hass, utcnow() + crd.update_interval) await hass.async_block_till_done() @@ -222,7 +240,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval): # Add subscriber update_callback = Mock() - crd.async_add_listener(update_callback) + unsub = crd.async_add_listener(update_callback) # Test twice we don't update with subscriber with no update interval async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL) @@ -234,7 +252,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval): assert crd.data is None # Test removing listener - crd.async_remove_listener(update_callback) + unsub() async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL) await hass.async_block_till_done() @@ -253,9 +271,10 @@ async def test_refresh_recover(crd, caplog): assert "Fetching test data recovered" in caplog.text -async def test_coordinator_entity(crd): +async def test_coordinator_entity(crd: update_coordinator.DataUpdateCoordinator[int]): """Test the CoordinatorEntity class.""" - entity = update_coordinator.CoordinatorEntity(crd) + context = object() + entity = update_coordinator.CoordinatorEntity(crd, context) assert entity.should_poll is False @@ -278,6 +297,8 @@ async def test_coordinator_entity(crd): await entity.async_update() assert entity.available is False + assert list(crd.async_contexts()) == [context] + async def test_async_set_updated_data(crd): """Test async_set_updated_data for update coordinator."""