Keep track of a context for each listener (#72702)
* Remove async_remove_listener
This avoids the ambuigity as to what happens if same callback is added multiple times.
* Keep track of a context for each listener
This allow a update coordinator to adapt what data to request on update from the backing service based on which entities are enabled.
* Clone list before calling callbacks
The callbacks can end up unregistering and modifying the dict while iterating.
* Only yield actual values
* Add a test for update context
* Factor out iteration of _listeners to helper
* Verify context is passed to coordinator
* Switch to Any as type instead of object
* Remove function which use was dropped earliers
The use was removed in 8bee25c938
pull/72988/head
parent
a28fa5377a
commit
8910d265d6
|
@ -131,4 +131,4 @@ class BMWButton(BMWBaseEntity, ButtonEntity):
|
|||
# Always update HA states after a button was executed.
|
||||
# BMW remote services that change the vehicle's state update the local object
|
||||
# when executing the service, so only the HA state machine needs further updates.
|
||||
self.coordinator.notify_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
|
|
@ -74,8 +74,3 @@ class BMWDataUpdateCoordinator(DataUpdateCoordinator):
|
|||
if not refresh_token:
|
||||
data.pop(CONF_REFRESH_TOKEN)
|
||||
self.hass.config_entries.async_update_entry(self._entry, data=data)
|
||||
|
||||
def notify_listeners(self) -> None:
|
||||
"""Notify all listeners to refresh HA state machine."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
|
|
@ -74,12 +74,12 @@ def modernforms_exception_handler(func):
|
|||
async def handler(self, *args, **kwargs):
|
||||
try:
|
||||
await func(self, *args, **kwargs)
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
||||
except ModernFormsConnectionError as error:
|
||||
_LOGGER.error("Error communicating with API: %s", error)
|
||||
self.coordinator.last_update_success = False
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
||||
except ModernFormsError as error:
|
||||
_LOGGER.error("Invalid response from API: %s", error)
|
||||
|
@ -108,11 +108,6 @@ class ModernFormsDataUpdateCoordinator(DataUpdateCoordinator[ModernFormsDeviceSt
|
|||
update_interval=SCAN_INTERVAL,
|
||||
)
|
||||
|
||||
def update_listeners(self) -> None:
|
||||
"""Call update on all listeners."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
async def _async_update_data(self) -> ModernFormsDevice:
|
||||
"""Fetch data from Modern Forms."""
|
||||
try:
|
||||
|
|
|
@ -83,8 +83,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||
async def async_set_cooling(self, enabled: bool) -> None:
|
||||
"""Enable or disable cooling mode."""
|
||||
await self.base.set_cooling(enabled)
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
self.async_update_listeners()
|
||||
|
||||
async def async_set_target_temperature(
|
||||
self, heat_area_id: str, target_temperature: float
|
||||
|
@ -117,8 +116,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||
"Failed to set target temperature, communication error with alpha2 base"
|
||||
) from http_err
|
||||
self.data["heat_areas"][heat_area_id].update(update_data)
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
self.async_update_listeners()
|
||||
|
||||
async def async_set_heat_area_mode(
|
||||
self, heat_area_id: str, heat_area_mode: int
|
||||
|
@ -161,5 +159,5 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||
self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[
|
||||
"heat_areas"
|
||||
][heat_area_id]["T_HEAT_NIGHT"]
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
self.async_update_listeners()
|
||||
|
|
|
@ -19,14 +19,7 @@ from homeassistant.const import (
|
|||
CONF_USERNAME,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Context,
|
||||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.core import Context, HassJob, HomeAssistant, callback
|
||||
from homeassistant.helpers.debounce import Debouncer
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||
|
||||
|
@ -121,12 +114,7 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
|
|||
self.options = options
|
||||
self._notify_future: asyncio.Task | None = None
|
||||
|
||||
@callback
|
||||
def _update_listeners():
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
self.turn_on = PluggableAction(_update_listeners)
|
||||
self.turn_on = PluggableAction(self.async_update_listeners)
|
||||
|
||||
super().__init__(
|
||||
hass,
|
||||
|
@ -193,15 +181,9 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
|
|||
self._notify_future = asyncio.create_task(self._notify_task())
|
||||
|
||||
@callback
|
||||
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
|
||||
def _unschedule_refresh(self) -> None:
|
||||
"""Remove data update."""
|
||||
super().async_remove_listener(update_callback)
|
||||
if not self._listeners:
|
||||
self._async_notify_stop()
|
||||
|
||||
@callback
|
||||
def _async_stop_refresh(self, event: Event) -> None:
|
||||
super()._async_stop_refresh(event)
|
||||
super()._unschedule_refresh()
|
||||
self._async_notify_stop()
|
||||
|
||||
@callback
|
||||
|
|
|
@ -75,11 +75,6 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30)
|
||||
)
|
||||
|
||||
def update_listeners(self) -> None:
|
||||
"""Call update on all listeners."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
async def async_get_data(
|
||||
self,
|
||||
modules: list[str],
|
||||
|
@ -113,7 +108,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
self.unsub()
|
||||
self.unsub = None
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
except (ConnectionClosedException, ConnectionResetError) as exception:
|
||||
self.logger.info(
|
||||
"Websocket connection closed for %s. Will retry: %s",
|
||||
|
@ -124,7 +119,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
self.unsub()
|
||||
self.unsub = None
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
except ConnectionErrorException as exception:
|
||||
self.logger.warning(
|
||||
"Connection error occurred for %s. Will retry: %s",
|
||||
|
@ -135,7 +130,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
self.unsub()
|
||||
self.unsub = None
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
|
||||
async def _setup_websocket(self) -> None:
|
||||
"""Use WebSocket for updates."""
|
||||
|
@ -151,7 +146,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
self.unsub()
|
||||
self.unsub = None
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
except ConnectionErrorException as exception:
|
||||
self.logger.warning(
|
||||
"Connection error occurred for %s. Will retry: %s",
|
||||
|
@ -159,7 +154,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
exception,
|
||||
)
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
except asyncio.TimeoutError as exception:
|
||||
self.logger.warning(
|
||||
"Timed out waiting for %s. Will retry: %s",
|
||||
|
@ -167,11 +162,11 @@ class SystemBridgeDataUpdateCoordinator(
|
|||
exception,
|
||||
)
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
|
||||
self.hass.async_create_task(self._listen_for_data())
|
||||
self.last_update_success = True
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
|
||||
async def close_websocket(_) -> None:
|
||||
"""Close WebSocket connection."""
|
||||
|
|
|
@ -47,11 +47,6 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
|
|||
hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL
|
||||
)
|
||||
|
||||
def update_listeners(self) -> None:
|
||||
"""Call update on all listeners."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
async def register_webhook(self, event: Event | None = None) -> None:
|
||||
"""Register a webhook with Toon to get live updates."""
|
||||
if CONF_WEBHOOK_ID not in self.entry.data:
|
||||
|
@ -128,7 +123,7 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
|
|||
|
||||
try:
|
||||
await self.toon.update(data["updateDataSet"])
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
except ToonError as err:
|
||||
_LOGGER.error("Could not process data received from Toon webhook - %s", err)
|
||||
|
||||
|
|
|
@ -16,12 +16,12 @@ def toon_exception_handler(func):
|
|||
async def handler(self, *args, **kwargs):
|
||||
try:
|
||||
await func(self, *args, **kwargs)
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
||||
except ToonConnectionError as error:
|
||||
_LOGGER.error("Error communicating with API: %s", error)
|
||||
self.coordinator.last_update_success = False
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
||||
except ToonError as error:
|
||||
_LOGGER.error("Invalid response from API: %s", error)
|
||||
|
|
|
@ -123,12 +123,6 @@ class DeviceCoordinator(DataUpdateCoordinator):
|
|||
except ActionException as err:
|
||||
raise UpdateFailed("WeMo update failed") from err
|
||||
|
||||
@callback
|
||||
def async_update_listeners(self) -> None:
|
||||
"""Update all listeners."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
|
||||
def _device_info(wemo: WeMoDevice) -> DeviceInfo:
|
||||
return DeviceInfo(
|
||||
|
|
|
@ -54,11 +54,6 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
|
|||
self.data is not None and len(self.data.state.segments) > 1
|
||||
)
|
||||
|
||||
def update_listeners(self) -> None:
|
||||
"""Call update on all listeners."""
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
@callback
|
||||
def _use_websocket(self) -> None:
|
||||
"""Use WebSocket for updates, instead of polling."""
|
||||
|
@ -81,7 +76,7 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
|
|||
self.logger.info(err)
|
||||
except WLEDError as err:
|
||||
self.last_update_success = False
|
||||
self.update_listeners()
|
||||
self.async_update_listeners()
|
||||
self.logger.error(err)
|
||||
|
||||
# Ensure we are disconnected
|
||||
|
|
|
@ -15,11 +15,11 @@ def wled_exception_handler(func):
|
|||
async def handler(self, *args, **kwargs):
|
||||
try:
|
||||
await func(self, *args, **kwargs)
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
|
||||
except WLEDConnectionError as error:
|
||||
self.coordinator.last_update_success = False
|
||||
self.coordinator.update_listeners()
|
||||
self.coordinator.async_update_listeners()
|
||||
raise HomeAssistantError("Error communicating with WLED API") from error
|
||||
|
||||
except WLEDError as error:
|
||||
|
|
|
@ -106,7 +106,9 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
|
|||
self.coordinator.musiccast.register_group_update_callback(
|
||||
self.update_all_mc_entities
|
||||
)
|
||||
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
|
||||
self.async_on_remove(
|
||||
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Entity being removed from hass."""
|
||||
|
@ -116,7 +118,6 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
|
|||
self.coordinator.musiccast.remove_group_update_callback(
|
||||
self.update_all_mc_entities
|
||||
)
|
||||
self.coordinator.async_remove_listener(self.async_schedule_check_client_list)
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Generator
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from time import monotonic
|
||||
|
@ -13,7 +13,7 @@ import aiohttp
|
|||
import requests
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
|
@ -61,7 +61,7 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||
# when it was already checked during setup.
|
||||
self.data: _T = None # type: ignore[assignment]
|
||||
|
||||
self._listeners: list[CALLBACK_TYPE] = []
|
||||
self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {}
|
||||
self._job = HassJob(self._handle_refresh_interval)
|
||||
self._unsub_refresh: CALLBACK_TYPE | None = None
|
||||
self._request_refresh_task: asyncio.TimerHandle | None = None
|
||||
|
@ -82,32 +82,46 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||
self._debounced_refresh = request_refresh_debouncer
|
||||
|
||||
@callback
|
||||
def async_add_listener(self, update_callback: CALLBACK_TYPE) -> Callable[[], None]:
|
||||
def async_add_listener(
|
||||
self, update_callback: CALLBACK_TYPE, context: Any = None
|
||||
) -> Callable[[], None]:
|
||||
"""Listen for data updates."""
|
||||
schedule_refresh = not self._listeners
|
||||
|
||||
self._listeners.append(update_callback)
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove update listener."""
|
||||
self._listeners.pop(remove_listener)
|
||||
if not self._listeners:
|
||||
self._unschedule_refresh()
|
||||
|
||||
self._listeners[remove_listener] = (update_callback, context)
|
||||
|
||||
# This is the first listener, set up interval.
|
||||
if schedule_refresh:
|
||||
self._schedule_refresh()
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove update listener."""
|
||||
self.async_remove_listener(update_callback)
|
||||
|
||||
return remove_listener
|
||||
|
||||
@callback
|
||||
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
|
||||
"""Remove data update."""
|
||||
self._listeners.remove(update_callback)
|
||||
def async_update_listeners(self) -> None:
|
||||
"""Update all registered listeners."""
|
||||
for update_callback, _ in list(self._listeners.values()):
|
||||
update_callback()
|
||||
|
||||
if not self._listeners and self._unsub_refresh:
|
||||
@callback
|
||||
def _unschedule_refresh(self) -> None:
|
||||
"""Unschedule any pending refresh since there is no longer any listeners."""
|
||||
if self._unsub_refresh:
|
||||
self._unsub_refresh()
|
||||
self._unsub_refresh = None
|
||||
|
||||
def async_contexts(self) -> Generator[Any, None, None]:
|
||||
"""Return all registered contexts."""
|
||||
yield from (
|
||||
context for _, context in self._listeners.values() if context is not None
|
||||
)
|
||||
|
||||
@callback
|
||||
def _schedule_refresh(self) -> None:
|
||||
"""Schedule a refresh."""
|
||||
|
@ -266,8 +280,7 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||
if not auth_failed and self._listeners and not self.hass.is_stopping:
|
||||
self._schedule_refresh()
|
||||
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
self.async_update_listeners()
|
||||
|
||||
@callback
|
||||
def async_set_updated_data(self, data: _T) -> None:
|
||||
|
@ -288,24 +301,18 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||
if self._listeners:
|
||||
self._schedule_refresh()
|
||||
|
||||
for update_callback in self._listeners:
|
||||
update_callback()
|
||||
|
||||
@callback
|
||||
def _async_stop_refresh(self, _: Event) -> None:
|
||||
"""Stop refreshing when Home Assistant is stopping."""
|
||||
self.update_interval = None
|
||||
if self._unsub_refresh:
|
||||
self._unsub_refresh()
|
||||
self._unsub_refresh = None
|
||||
self.async_update_listeners()
|
||||
|
||||
|
||||
class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
|
||||
"""A class for entities using DataUpdateCoordinator."""
|
||||
|
||||
def __init__(self, coordinator: _DataUpdateCoordinatorT) -> None:
|
||||
def __init__(
|
||||
self, coordinator: _DataUpdateCoordinatorT, context: Any = None
|
||||
) -> None:
|
||||
"""Create the entity with a DataUpdateCoordinator."""
|
||||
self.coordinator = coordinator
|
||||
self.coordinator_context = context
|
||||
|
||||
@property
|
||||
def should_poll(self) -> bool:
|
||||
|
@ -321,7 +328,9 @@ class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
|
|||
"""When entity is added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
self.async_on_remove(
|
||||
self.coordinator.async_add_listener(self._handle_coordinator_update)
|
||||
self.coordinator.async_add_listener(
|
||||
self._handle_coordinator_update, self.coordinator_context
|
||||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
|
|
|
@ -109,11 +109,29 @@ async def test_async_refresh(crd):
|
|||
await crd.async_refresh()
|
||||
assert updates == [2]
|
||||
|
||||
# Test unsubscribing through method
|
||||
crd.async_add_listener(update_callback)
|
||||
crd.async_remove_listener(update_callback)
|
||||
|
||||
async def test_update_context(crd: update_coordinator.DataUpdateCoordinator[int]):
|
||||
"""Test update contexts for the update coordinator."""
|
||||
await crd.async_refresh()
|
||||
assert updates == [2]
|
||||
assert not set(crd.async_contexts())
|
||||
|
||||
def update_callback1():
|
||||
pass
|
||||
|
||||
def update_callback2():
|
||||
pass
|
||||
|
||||
unsub1 = crd.async_add_listener(update_callback1, 1)
|
||||
assert set(crd.async_contexts()) == {1}
|
||||
|
||||
unsub2 = crd.async_add_listener(update_callback2, 2)
|
||||
assert set(crd.async_contexts()) == {1, 2}
|
||||
|
||||
unsub1()
|
||||
assert set(crd.async_contexts()) == {2}
|
||||
|
||||
unsub2()
|
||||
assert not set(crd.async_contexts())
|
||||
|
||||
|
||||
async def test_request_refresh(crd):
|
||||
|
@ -191,7 +209,7 @@ async def test_update_interval(hass, crd):
|
|||
|
||||
# Add subscriber
|
||||
update_callback = Mock()
|
||||
crd.async_add_listener(update_callback)
|
||||
unsub = crd.async_add_listener(update_callback)
|
||||
|
||||
# Test twice we update with subscriber
|
||||
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
||||
|
@ -203,7 +221,7 @@ async def test_update_interval(hass, crd):
|
|||
assert crd.data == 2
|
||||
|
||||
# Test removing listener
|
||||
crd.async_remove_listener(update_callback)
|
||||
unsub()
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -222,7 +240,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
|
|||
|
||||
# Add subscriber
|
||||
update_callback = Mock()
|
||||
crd.async_add_listener(update_callback)
|
||||
unsub = crd.async_add_listener(update_callback)
|
||||
|
||||
# Test twice we don't update with subscriber with no update interval
|
||||
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
||||
|
@ -234,7 +252,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
|
|||
assert crd.data is None
|
||||
|
||||
# Test removing listener
|
||||
crd.async_remove_listener(update_callback)
|
||||
unsub()
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -253,9 +271,10 @@ async def test_refresh_recover(crd, caplog):
|
|||
assert "Fetching test data recovered" in caplog.text
|
||||
|
||||
|
||||
async def test_coordinator_entity(crd):
|
||||
async def test_coordinator_entity(crd: update_coordinator.DataUpdateCoordinator[int]):
|
||||
"""Test the CoordinatorEntity class."""
|
||||
entity = update_coordinator.CoordinatorEntity(crd)
|
||||
context = object()
|
||||
entity = update_coordinator.CoordinatorEntity(crd, context)
|
||||
|
||||
assert entity.should_poll is False
|
||||
|
||||
|
@ -278,6 +297,8 @@ async def test_coordinator_entity(crd):
|
|||
await entity.async_update()
|
||||
assert entity.available is False
|
||||
|
||||
assert list(crd.async_contexts()) == [context]
|
||||
|
||||
|
||||
async def test_async_set_updated_data(crd):
|
||||
"""Test async_set_updated_data for update coordinator."""
|
||||
|
|
Loading…
Reference in New Issue