Keep track of a context for each listener (#72702)

* Remove async_remove_listener

This avoids the ambuigity as to what happens if same callback is added multiple times.

* Keep track of a context for each listener

This allow a update coordinator to adapt what data to request on update from the backing service based on which entities are enabled.

* Clone list before calling callbacks

The callbacks can end up unregistering and modifying the dict while iterating.

* Only yield actual values

* Add a test for update context

* Factor out iteration of _listeners to helper

* Verify context is passed to coordinator

* Switch to Any as type instead of object

* Remove function which use was dropped earliers

The use was removed in 8bee25c938
pull/72988/head
Joakim Plate 2022-06-03 13:55:57 +02:00 committed by GitHub
parent a28fa5377a
commit 8910d265d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 95 additions and 115 deletions

View File

@ -131,4 +131,4 @@ class BMWButton(BMWBaseEntity, ButtonEntity):
# Always update HA states after a button was executed.
# BMW remote services that change the vehicle's state update the local object
# when executing the service, so only the HA state machine needs further updates.
self.coordinator.notify_listeners()
self.coordinator.async_update_listeners()

View File

@ -74,8 +74,3 @@ class BMWDataUpdateCoordinator(DataUpdateCoordinator):
if not refresh_token:
data.pop(CONF_REFRESH_TOKEN)
self.hass.config_entries.async_update_entry(self._entry, data=data)
def notify_listeners(self) -> None:
"""Notify all listeners to refresh HA state machine."""
for update_callback in self._listeners:
update_callback()

View File

@ -74,12 +74,12 @@ def modernforms_exception_handler(func):
async def handler(self, *args, **kwargs):
try:
await func(self, *args, **kwargs)
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
except ModernFormsConnectionError as error:
_LOGGER.error("Error communicating with API: %s", error)
self.coordinator.last_update_success = False
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
except ModernFormsError as error:
_LOGGER.error("Invalid response from API: %s", error)
@ -108,11 +108,6 @@ class ModernFormsDataUpdateCoordinator(DataUpdateCoordinator[ModernFormsDeviceSt
update_interval=SCAN_INTERVAL,
)
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def _async_update_data(self) -> ModernFormsDevice:
"""Fetch data from Modern Forms."""
try:

View File

@ -83,8 +83,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
async def async_set_cooling(self, enabled: bool) -> None:
"""Enable or disable cooling mode."""
await self.base.set_cooling(enabled)
for update_callback in self._listeners:
update_callback()
self.async_update_listeners()
async def async_set_target_temperature(
self, heat_area_id: str, target_temperature: float
@ -117,8 +116,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
"Failed to set target temperature, communication error with alpha2 base"
) from http_err
self.data["heat_areas"][heat_area_id].update(update_data)
for update_callback in self._listeners:
update_callback()
self.async_update_listeners()
async def async_set_heat_area_mode(
self, heat_area_id: str, heat_area_mode: int
@ -161,5 +159,5 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[
"heat_areas"
][heat_area_id]["T_HEAT_NIGHT"]
for update_callback in self._listeners:
update_callback()
self.async_update_listeners()

View File

@ -19,14 +19,7 @@ from homeassistant.const import (
CONF_USERNAME,
Platform,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
Event,
HassJob,
HomeAssistant,
callback,
)
from homeassistant.core import Context, HassJob, HomeAssistant, callback
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
@ -121,12 +114,7 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
self.options = options
self._notify_future: asyncio.Task | None = None
@callback
def _update_listeners():
for update_callback in self._listeners:
update_callback()
self.turn_on = PluggableAction(_update_listeners)
self.turn_on = PluggableAction(self.async_update_listeners)
super().__init__(
hass,
@ -193,15 +181,9 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
self._notify_future = asyncio.create_task(self._notify_task())
@callback
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
def _unschedule_refresh(self) -> None:
"""Remove data update."""
super().async_remove_listener(update_callback)
if not self._listeners:
self._async_notify_stop()
@callback
def _async_stop_refresh(self, event: Event) -> None:
super()._async_stop_refresh(event)
super()._unschedule_refresh()
self._async_notify_stop()
@callback

View File

@ -75,11 +75,6 @@ class SystemBridgeDataUpdateCoordinator(
hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30)
)
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def async_get_data(
self,
modules: list[str],
@ -113,7 +108,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub()
self.unsub = None
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
except (ConnectionClosedException, ConnectionResetError) as exception:
self.logger.info(
"Websocket connection closed for %s. Will retry: %s",
@ -124,7 +119,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub()
self.unsub = None
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
except ConnectionErrorException as exception:
self.logger.warning(
"Connection error occurred for %s. Will retry: %s",
@ -135,7 +130,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub()
self.unsub = None
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
async def _setup_websocket(self) -> None:
"""Use WebSocket for updates."""
@ -151,7 +146,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub()
self.unsub = None
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
except ConnectionErrorException as exception:
self.logger.warning(
"Connection error occurred for %s. Will retry: %s",
@ -159,7 +154,7 @@ class SystemBridgeDataUpdateCoordinator(
exception,
)
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
except asyncio.TimeoutError as exception:
self.logger.warning(
"Timed out waiting for %s. Will retry: %s",
@ -167,11 +162,11 @@ class SystemBridgeDataUpdateCoordinator(
exception,
)
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
self.hass.async_create_task(self._listen_for_data())
self.last_update_success = True
self.update_listeners()
self.async_update_listeners()
async def close_websocket(_) -> None:
"""Close WebSocket connection."""

View File

@ -47,11 +47,6 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL
)
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def register_webhook(self, event: Event | None = None) -> None:
"""Register a webhook with Toon to get live updates."""
if CONF_WEBHOOK_ID not in self.entry.data:
@ -128,7 +123,7 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
try:
await self.toon.update(data["updateDataSet"])
self.update_listeners()
self.async_update_listeners()
except ToonError as err:
_LOGGER.error("Could not process data received from Toon webhook - %s", err)

View File

@ -16,12 +16,12 @@ def toon_exception_handler(func):
async def handler(self, *args, **kwargs):
try:
await func(self, *args, **kwargs)
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
except ToonConnectionError as error:
_LOGGER.error("Error communicating with API: %s", error)
self.coordinator.last_update_success = False
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
except ToonError as error:
_LOGGER.error("Invalid response from API: %s", error)

View File

@ -123,12 +123,6 @@ class DeviceCoordinator(DataUpdateCoordinator):
except ActionException as err:
raise UpdateFailed("WeMo update failed") from err
@callback
def async_update_listeners(self) -> None:
"""Update all listeners."""
for update_callback in self._listeners:
update_callback()
def _device_info(wemo: WeMoDevice) -> DeviceInfo:
return DeviceInfo(

View File

@ -54,11 +54,6 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
self.data is not None and len(self.data.state.segments) > 1
)
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
@callback
def _use_websocket(self) -> None:
"""Use WebSocket for updates, instead of polling."""
@ -81,7 +76,7 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
self.logger.info(err)
except WLEDError as err:
self.last_update_success = False
self.update_listeners()
self.async_update_listeners()
self.logger.error(err)
# Ensure we are disconnected

View File

@ -15,11 +15,11 @@ def wled_exception_handler(func):
async def handler(self, *args, **kwargs):
try:
await func(self, *args, **kwargs)
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
except WLEDConnectionError as error:
self.coordinator.last_update_success = False
self.coordinator.update_listeners()
self.coordinator.async_update_listeners()
raise HomeAssistantError("Error communicating with WLED API") from error
except WLEDError as error:

View File

@ -106,7 +106,9 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
self.coordinator.musiccast.register_group_update_callback(
self.update_all_mc_entities
)
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
self.async_on_remove(
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
)
async def async_will_remove_from_hass(self):
"""Entity being removed from hass."""
@ -116,7 +118,6 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
self.coordinator.musiccast.remove_group_update_callback(
self.update_all_mc_entities
)
self.coordinator.async_remove_listener(self.async_schedule_check_client_list)
@property
def should_poll(self):

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Generator
from datetime import datetime, timedelta
import logging
from time import monotonic
@ -13,7 +13,7 @@ import aiohttp
import requests
from homeassistant import config_entries
from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.util.dt import utcnow
@ -61,7 +61,7 @@ class DataUpdateCoordinator(Generic[_T]):
# when it was already checked during setup.
self.data: _T = None # type: ignore[assignment]
self._listeners: list[CALLBACK_TYPE] = []
self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {}
self._job = HassJob(self._handle_refresh_interval)
self._unsub_refresh: CALLBACK_TYPE | None = None
self._request_refresh_task: asyncio.TimerHandle | None = None
@ -82,32 +82,46 @@ class DataUpdateCoordinator(Generic[_T]):
self._debounced_refresh = request_refresh_debouncer
@callback
def async_add_listener(self, update_callback: CALLBACK_TYPE) -> Callable[[], None]:
def async_add_listener(
self, update_callback: CALLBACK_TYPE, context: Any = None
) -> Callable[[], None]:
"""Listen for data updates."""
schedule_refresh = not self._listeners
self._listeners.append(update_callback)
@callback
def remove_listener() -> None:
"""Remove update listener."""
self._listeners.pop(remove_listener)
if not self._listeners:
self._unschedule_refresh()
self._listeners[remove_listener] = (update_callback, context)
# This is the first listener, set up interval.
if schedule_refresh:
self._schedule_refresh()
@callback
def remove_listener() -> None:
"""Remove update listener."""
self.async_remove_listener(update_callback)
return remove_listener
@callback
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
"""Remove data update."""
self._listeners.remove(update_callback)
def async_update_listeners(self) -> None:
"""Update all registered listeners."""
for update_callback, _ in list(self._listeners.values()):
update_callback()
if not self._listeners and self._unsub_refresh:
@callback
def _unschedule_refresh(self) -> None:
"""Unschedule any pending refresh since there is no longer any listeners."""
if self._unsub_refresh:
self._unsub_refresh()
self._unsub_refresh = None
def async_contexts(self) -> Generator[Any, None, None]:
"""Return all registered contexts."""
yield from (
context for _, context in self._listeners.values() if context is not None
)
@callback
def _schedule_refresh(self) -> None:
"""Schedule a refresh."""
@ -266,8 +280,7 @@ class DataUpdateCoordinator(Generic[_T]):
if not auth_failed and self._listeners and not self.hass.is_stopping:
self._schedule_refresh()
for update_callback in self._listeners:
update_callback()
self.async_update_listeners()
@callback
def async_set_updated_data(self, data: _T) -> None:
@ -288,24 +301,18 @@ class DataUpdateCoordinator(Generic[_T]):
if self._listeners:
self._schedule_refresh()
for update_callback in self._listeners:
update_callback()
@callback
def _async_stop_refresh(self, _: Event) -> None:
"""Stop refreshing when Home Assistant is stopping."""
self.update_interval = None
if self._unsub_refresh:
self._unsub_refresh()
self._unsub_refresh = None
self.async_update_listeners()
class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
"""A class for entities using DataUpdateCoordinator."""
def __init__(self, coordinator: _DataUpdateCoordinatorT) -> None:
def __init__(
self, coordinator: _DataUpdateCoordinatorT, context: Any = None
) -> None:
"""Create the entity with a DataUpdateCoordinator."""
self.coordinator = coordinator
self.coordinator_context = context
@property
def should_poll(self) -> bool:
@ -321,7 +328,9 @@ class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
"""When entity is added to hass."""
await super().async_added_to_hass()
self.async_on_remove(
self.coordinator.async_add_listener(self._handle_coordinator_update)
self.coordinator.async_add_listener(
self._handle_coordinator_update, self.coordinator_context
)
)
@callback

View File

@ -109,11 +109,29 @@ async def test_async_refresh(crd):
await crd.async_refresh()
assert updates == [2]
# Test unsubscribing through method
crd.async_add_listener(update_callback)
crd.async_remove_listener(update_callback)
async def test_update_context(crd: update_coordinator.DataUpdateCoordinator[int]):
"""Test update contexts for the update coordinator."""
await crd.async_refresh()
assert updates == [2]
assert not set(crd.async_contexts())
def update_callback1():
pass
def update_callback2():
pass
unsub1 = crd.async_add_listener(update_callback1, 1)
assert set(crd.async_contexts()) == {1}
unsub2 = crd.async_add_listener(update_callback2, 2)
assert set(crd.async_contexts()) == {1, 2}
unsub1()
assert set(crd.async_contexts()) == {2}
unsub2()
assert not set(crd.async_contexts())
async def test_request_refresh(crd):
@ -191,7 +209,7 @@ async def test_update_interval(hass, crd):
# Add subscriber
update_callback = Mock()
crd.async_add_listener(update_callback)
unsub = crd.async_add_listener(update_callback)
# Test twice we update with subscriber
async_fire_time_changed(hass, utcnow() + crd.update_interval)
@ -203,7 +221,7 @@ async def test_update_interval(hass, crd):
assert crd.data == 2
# Test removing listener
crd.async_remove_listener(update_callback)
unsub()
async_fire_time_changed(hass, utcnow() + crd.update_interval)
await hass.async_block_till_done()
@ -222,7 +240,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
# Add subscriber
update_callback = Mock()
crd.async_add_listener(update_callback)
unsub = crd.async_add_listener(update_callback)
# Test twice we don't update with subscriber with no update interval
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
@ -234,7 +252,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
assert crd.data is None
# Test removing listener
crd.async_remove_listener(update_callback)
unsub()
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
await hass.async_block_till_done()
@ -253,9 +271,10 @@ async def test_refresh_recover(crd, caplog):
assert "Fetching test data recovered" in caplog.text
async def test_coordinator_entity(crd):
async def test_coordinator_entity(crd: update_coordinator.DataUpdateCoordinator[int]):
"""Test the CoordinatorEntity class."""
entity = update_coordinator.CoordinatorEntity(crd)
context = object()
entity = update_coordinator.CoordinatorEntity(crd, context)
assert entity.should_poll is False
@ -278,6 +297,8 @@ async def test_coordinator_entity(crd):
await entity.async_update()
assert entity.available is False
assert list(crd.async_contexts()) == [context]
async def test_async_set_updated_data(crd):
"""Test async_set_updated_data for update coordinator."""