From f504c279721ff7a00d0152c106a13c14caecc7a5 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:20:15 +0200 Subject: [PATCH] Add ability to pass the config entry explicitly in data update coordinators (#127980) * Add ability to pass the config entry explicitely in data update coordinators * Implement in accuweather * Raise if config entry not set * Move accuweather models * Fix gogogate2 * Fix rainforest_raven --- .../components/accuweather/__init__.py | 17 +--- .../components/accuweather/coordinator.py | 19 +++++ .../components/accuweather/diagnostics.py | 2 +- .../components/accuweather/sensor.py | 2 +- .../components/accuweather/system_health.py | 2 +- .../components/accuweather/weather.py | 3 +- .../rainforest_raven/coordinator.py | 1 + homeassistant/helpers/update_coordinator.py | 13 ++- tests/components/gogogate2/test_init.py | 12 +-- tests/helpers/test_update_coordinator.py | 85 +++++++++++++++---- 10 files changed, 115 insertions(+), 41 deletions(-) diff --git a/homeassistant/components/accuweather/__init__.py b/homeassistant/components/accuweather/__init__.py index 3d52df765e6..c046933d5d5 100644 --- a/homeassistant/components/accuweather/__init__.py +++ b/homeassistant/components/accuweather/__init__.py @@ -2,13 +2,11 @@ from __future__ import annotations -from dataclasses import dataclass import logging from accuweather import AccuWeather from homeassistant.components.sensor import DOMAIN as SENSOR_PLATFORM -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_API_KEY, CONF_NAME, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er @@ -16,7 +14,9 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import DOMAIN, UPDATE_INTERVAL_DAILY_FORECAST, UPDATE_INTERVAL_OBSERVATION from .coordinator import ( + AccuWeatherConfigEntry, AccuWeatherDailyForecastDataUpdateCoordinator, + AccuWeatherData, AccuWeatherObservationDataUpdateCoordinator, ) @@ -25,17 +25,6 @@ _LOGGER = logging.getLogger(__name__) PLATFORMS = [Platform.SENSOR, Platform.WEATHER] -@dataclass -class AccuWeatherData: - """Data for AccuWeather integration.""" - - coordinator_observation: AccuWeatherObservationDataUpdateCoordinator - coordinator_daily_forecast: AccuWeatherDailyForecastDataUpdateCoordinator - - -type AccuWeatherConfigEntry = ConfigEntry[AccuWeatherData] - - async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry) -> bool: """Set up AccuWeather as config entry.""" api_key: str = entry.data[CONF_API_KEY] @@ -50,6 +39,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry) coordinator_observation = AccuWeatherObservationDataUpdateCoordinator( hass, + entry, accuweather, name, "observation", @@ -58,6 +48,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AccuWeatherConfigEntry) coordinator_daily_forecast = AccuWeatherDailyForecastDataUpdateCoordinator( hass, + entry, accuweather, name, "daily forecast", diff --git a/homeassistant/components/accuweather/coordinator.py b/homeassistant/components/accuweather/coordinator.py index 26fadd6806c..40ff3ad2c87 100644 --- a/homeassistant/components/accuweather/coordinator.py +++ b/homeassistant/components/accuweather/coordinator.py @@ -1,6 +1,9 @@ """The AccuWeather coordinator.""" +from __future__ import annotations + from asyncio import timeout +from dataclasses import dataclass from datetime import timedelta import logging from typing import TYPE_CHECKING, Any @@ -8,6 +11,7 @@ from typing import TYPE_CHECKING, Any from accuweather import AccuWeather, ApiError, InvalidApiKeyError, RequestsExceededError from aiohttp.client_exceptions import ClientConnectorError +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.update_coordinator import ( @@ -23,6 +27,17 @@ EXCEPTIONS = (ApiError, ClientConnectorError, InvalidApiKeyError, RequestsExceed _LOGGER = logging.getLogger(__name__) +@dataclass +class AccuWeatherData: + """Data for AccuWeather integration.""" + + coordinator_observation: AccuWeatherObservationDataUpdateCoordinator + coordinator_daily_forecast: AccuWeatherDailyForecastDataUpdateCoordinator + + +type AccuWeatherConfigEntry = ConfigEntry[AccuWeatherData] + + class AccuWeatherObservationDataUpdateCoordinator( DataUpdateCoordinator[dict[str, Any]] ): @@ -31,6 +46,7 @@ class AccuWeatherObservationDataUpdateCoordinator( def __init__( self, hass: HomeAssistant, + config_entry: AccuWeatherConfigEntry, accuweather: AccuWeather, name: str, coordinator_type: str, @@ -48,6 +64,7 @@ class AccuWeatherObservationDataUpdateCoordinator( super().__init__( hass, _LOGGER, + config_entry=config_entry, name=f"{name} ({coordinator_type})", update_interval=update_interval, ) @@ -73,6 +90,7 @@ class AccuWeatherDailyForecastDataUpdateCoordinator( def __init__( self, hass: HomeAssistant, + config_entry: AccuWeatherConfigEntry, accuweather: AccuWeather, name: str, coordinator_type: str, @@ -90,6 +108,7 @@ class AccuWeatherDailyForecastDataUpdateCoordinator( super().__init__( hass, _LOGGER, + config_entry=config_entry, name=f"{name} ({coordinator_type})", update_interval=update_interval, ) diff --git a/homeassistant/components/accuweather/diagnostics.py b/homeassistant/components/accuweather/diagnostics.py index 85c06a6140a..9f35c47b886 100644 --- a/homeassistant/components/accuweather/diagnostics.py +++ b/homeassistant/components/accuweather/diagnostics.py @@ -8,7 +8,7 @@ from homeassistant.components.diagnostics import async_redact_data from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE from homeassistant.core import HomeAssistant -from . import AccuWeatherConfigEntry, AccuWeatherData +from .coordinator import AccuWeatherConfigEntry, AccuWeatherData TO_REDACT = {CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE} diff --git a/homeassistant/components/accuweather/sensor.py b/homeassistant/components/accuweather/sensor.py index 2f6b10b296f..001edc5f197 100644 --- a/homeassistant/components/accuweather/sensor.py +++ b/homeassistant/components/accuweather/sensor.py @@ -28,7 +28,6 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity -from . import AccuWeatherConfigEntry from .const import ( API_METRIC, ATTR_CATEGORY, @@ -41,6 +40,7 @@ from .const import ( MAX_FORECAST_DAYS, ) from .coordinator import ( + AccuWeatherConfigEntry, AccuWeatherDailyForecastDataUpdateCoordinator, AccuWeatherObservationDataUpdateCoordinator, ) diff --git a/homeassistant/components/accuweather/system_health.py b/homeassistant/components/accuweather/system_health.py index eab16498248..f5efaf3079f 100644 --- a/homeassistant/components/accuweather/system_health.py +++ b/homeassistant/components/accuweather/system_health.py @@ -9,8 +9,8 @@ from accuweather.const import ENDPOINT from homeassistant.components import system_health from homeassistant.core import HomeAssistant, callback -from . import AccuWeatherConfigEntry from .const import DOMAIN +from .coordinator import AccuWeatherConfigEntry @callback diff --git a/homeassistant/components/accuweather/weather.py b/homeassistant/components/accuweather/weather.py index 72d717f2703..7d754278d91 100644 --- a/homeassistant/components/accuweather/weather.py +++ b/homeassistant/components/accuweather/weather.py @@ -33,7 +33,6 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util.dt import utc_from_timestamp -from . import AccuWeatherConfigEntry, AccuWeatherData from .const import ( API_METRIC, ATTR_DIRECTION, @@ -43,7 +42,9 @@ from .const import ( CONDITION_MAP, ) from .coordinator import ( + AccuWeatherConfigEntry, AccuWeatherDailyForecastDataUpdateCoordinator, + AccuWeatherData, AccuWeatherObservationDataUpdateCoordinator, ) diff --git a/homeassistant/components/rainforest_raven/coordinator.py b/homeassistant/components/rainforest_raven/coordinator.py index d08a10c2670..a652d4a4e83 100644 --- a/homeassistant/components/rainforest_raven/coordinator.py +++ b/homeassistant/components/rainforest_raven/coordinator.py @@ -75,6 +75,7 @@ class RAVEnDataCoordinator(DataUpdateCoordinator): super().__init__( hass, _LOGGER, + config_entry=entry, name=DOMAIN, update_interval=timedelta(seconds=30), ) diff --git a/homeassistant/helpers/update_coordinator.py b/homeassistant/helpers/update_coordinator.py index 25cd4bc4d90..e2739bbdca9 100644 --- a/homeassistant/helpers/update_coordinator.py +++ b/homeassistant/helpers/update_coordinator.py @@ -29,6 +29,7 @@ from homeassistant.util.dt import utcnow from . import entity, event from .debounce import Debouncer +from .typing import UNDEFINED, UndefinedType REQUEST_REFRESH_DEFAULT_COOLDOWN = 10 REQUEST_REFRESH_DEFAULT_IMMEDIATE = True @@ -68,6 +69,7 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]): hass: HomeAssistant, logger: logging.Logger, *, + config_entry: config_entries.ConfigEntry | None | UndefinedType = UNDEFINED, name: str, update_interval: timedelta | None = None, update_method: Callable[[], Awaitable[_DataT]] | None = None, @@ -84,7 +86,12 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]): self._update_interval_seconds: float | None = None self.update_interval = update_interval self._shutdown_requested = False - self.config_entry = config_entries.current_entry.get() + if config_entry is UNDEFINED: + self.config_entry = config_entries.current_entry.get() + # This should be deprecated once all core integrations are updated + # to pass in the config entry explicitly. + else: + self.config_entry = config_entry self.always_update = always_update # It's None before the first successful update. @@ -277,6 +284,10 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]): fails. Additionally logging is handled by config entry setup to ensure that multiple retries do not cause log spam. """ + if self.config_entry is None: + raise ValueError( + "This method is only supported for coordinators with a config entry" + ) if await self.__wrap_async_setup(): await self._async_refresh( log_failures=False, raise_on_auth_failed=True, raise_on_entry_error=True diff --git a/tests/components/gogogate2/test_init.py b/tests/components/gogogate2/test_init.py index f7e58296a43..90765c425b4 100644 --- a/tests/components/gogogate2/test_init.py +++ b/tests/components/gogogate2/test_init.py @@ -3,11 +3,10 @@ from unittest.mock import MagicMock, patch from ismartgate import GogoGate2Api -import pytest -from homeassistant.components.gogogate2 import DEVICE_TYPE_GOGOGATE2, async_setup_entry +from homeassistant.components.gogogate2 import DEVICE_TYPE_GOGOGATE2 from homeassistant.components.gogogate2.const import DEVICE_TYPE_ISMARTGATE, DOMAIN -from homeassistant.config_entries import SOURCE_USER +from homeassistant.config_entries import SOURCE_USER, ConfigEntryState from homeassistant.const import ( CONF_DEVICE, CONF_IP_ADDRESS, @@ -15,7 +14,6 @@ from homeassistant.const import ( CONF_USERNAME, ) from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryNotReady from tests.common import MockConfigEntry @@ -97,6 +95,8 @@ async def test_api_failure_on_startup(hass: HomeAssistant) -> None: "homeassistant.components.gogogate2.common.ISmartGateApi.async_info", side_effect=TimeoutError, ), - pytest.raises(ConfigEntryNotReady), ): - await async_setup_entry(hass, config_entry) + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.SETUP_RETRY diff --git a/tests/helpers/test_update_coordinator.py b/tests/helpers/test_update_coordinator.py index d450d924f1f..48a2fe416d1 100644 --- a/tests/helpers/test_update_coordinator.py +++ b/tests/helpers/test_update_coordinator.py @@ -57,7 +57,9 @@ KNOWN_ERRORS: list[tuple[Exception, type[Exception], str]] = [ def get_crd( - hass: HomeAssistant, update_interval: timedelta | None + hass: HomeAssistant, + update_interval: timedelta | None, + config_entry: config_entries.ConfigEntry | None = None, ) -> update_coordinator.DataUpdateCoordinator[int]: """Make coordinator mocks.""" calls = 0 @@ -70,6 +72,7 @@ def get_crd( return update_coordinator.DataUpdateCoordinator[int]( hass, _LOGGER, + config_entry=config_entry, name="test", update_method=refresh, update_interval=update_interval, @@ -121,8 +124,7 @@ async def test_async_refresh( async def test_shutdown( - hass: HomeAssistant, - crd: update_coordinator.DataUpdateCoordinator[int], + hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int] ) -> None: """Test async_shutdown for update coordinator.""" assert crd.data is None @@ -158,8 +160,7 @@ async def test_shutdown( async def test_shutdown_on_entry_unload( - hass: HomeAssistant, - crd: update_coordinator.DataUpdateCoordinator[int], + hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int] ) -> None: """Test shutdown is requested on entry unload.""" entry = MockConfigEntry() @@ -191,8 +192,7 @@ async def test_shutdown_on_entry_unload( async def test_shutdown_on_hass_stop( - hass: HomeAssistant, - crd: update_coordinator.DataUpdateCoordinator[int], + hass: HomeAssistant, crd: update_coordinator.DataUpdateCoordinator[int] ) -> None: """Test shutdown can be shutdown on STOP event.""" calls = 0 @@ -539,8 +539,8 @@ async def test_stop_refresh_on_ha_stop( ["update_method", "setup_method"], ) async def test_async_config_entry_first_refresh_failure( + hass: HomeAssistant, err_msg: tuple[Exception, type[Exception], str], - crd: update_coordinator.DataUpdateCoordinator[int], method: str, caplog: pytest.LogCaptureFixture, ) -> None: @@ -550,6 +550,8 @@ async def test_async_config_entry_first_refresh_failure( will be caught by config_entries.async_setup which will log it with a decreasing level of logging once the first message is logged. """ + entry = MockConfigEntry() + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry) setattr(crd, method, AsyncMock(side_effect=err_msg[0])) with pytest.raises(ConfigEntryNotReady): @@ -572,8 +574,8 @@ async def test_async_config_entry_first_refresh_failure( ["update_method", "setup_method"], ) async def test_async_config_entry_first_refresh_failure_passed_through( + hass: HomeAssistant, err_msg: tuple[Exception, type[Exception], str], - crd: update_coordinator.DataUpdateCoordinator[int], method: str, caplog: pytest.LogCaptureFixture, ) -> None: @@ -583,6 +585,8 @@ async def test_async_config_entry_first_refresh_failure_passed_through( will be caught by config_entries.async_setup which will log it with a decreasing level of logging once the first message is logged. """ + entry = MockConfigEntry() + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry) setattr(crd, method, AsyncMock(side_effect=err_msg[0])) with pytest.raises(err_msg[1]): @@ -593,11 +597,10 @@ async def test_async_config_entry_first_refresh_failure_passed_through( assert err_msg[2] not in caplog.text -async def test_async_config_entry_first_refresh_success( - crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture -) -> None: +async def test_async_config_entry_first_refresh_success(hass: HomeAssistant) -> None: """Test first refresh successfully.""" - + entry = MockConfigEntry() + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry) crd.setup_method = AsyncMock() await crd.async_config_entry_first_refresh() @@ -605,13 +608,26 @@ async def test_async_config_entry_first_refresh_success( crd.setup_method.assert_called_once() +async def test_async_config_entry_first_refresh_no_entry(hass: HomeAssistant) -> None: + """Test first refresh successfully.""" + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, None) + crd.setup_method = AsyncMock() + with pytest.raises( + ValueError, + match="This method is only supported for coordinators with a config entry", + ): + await crd.async_config_entry_first_refresh() + + assert crd.last_update_success is True + crd.setup_method.assert_not_called() + + async def test_not_schedule_refresh_if_system_option_disable_polling( hass: HomeAssistant, ) -> None: """Test we do not schedule a refresh if disable polling in config entry.""" entry = MockConfigEntry(pref_disable_polling=True) - config_entries.current_entry.set(entry) - crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL) + crd = get_crd(hass, DEFAULT_UPDATE_INTERVAL, entry) crd.async_add_listener(lambda: None) assert crd._unsub_refresh is None @@ -651,7 +667,7 @@ async def test_async_set_update_error( async def test_only_callback_on_change_when_always_update_is_false( - crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture + crd: update_coordinator.DataUpdateCoordinator[int], ) -> None: """Test we do not callback listeners unless something has actually changed when always_update is false.""" update_callback = Mock() @@ -721,7 +737,7 @@ async def test_only_callback_on_change_when_always_update_is_false( async def test_always_callback_when_always_update_is_true( - crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture + crd: update_coordinator.DataUpdateCoordinator[int], ) -> None: """Test we callback listeners even though the data is the same when always_update is True.""" update_callback = Mock() @@ -795,3 +811,38 @@ async def test_timestamp_date_update_coordinator(hass: HomeAssistant) -> None: unsub() await crd.async_refresh() assert len(last_update_success_times) == 1 + + +async def test_config_entry(hass: HomeAssistant) -> None: + """Test behavior of coordinator.entry.""" + entry = MockConfigEntry() + + # Default without context should be None + crd = update_coordinator.DataUpdateCoordinator[int](hass, _LOGGER, name="test") + assert crd.config_entry is None + + # Explicit None is OK + crd = update_coordinator.DataUpdateCoordinator[int]( + hass, _LOGGER, name="test", config_entry=None + ) + assert crd.config_entry is None + + # Explicit entry is OK + crd = update_coordinator.DataUpdateCoordinator[int]( + hass, _LOGGER, name="test", config_entry=entry + ) + assert crd.config_entry is entry + + # set ContextVar + config_entries.current_entry.set(entry) + + # Default with ContextVar should match the ContextVar + crd = update_coordinator.DataUpdateCoordinator[int](hass, _LOGGER, name="test") + assert crd.config_entry is entry + + # Explicit entry different from ContextVar not recommended, but should work + another_entry = MockConfigEntry() + crd = update_coordinator.DataUpdateCoordinator[int]( + hass, _LOGGER, name="test", config_entry=another_entry + ) + assert crd.config_entry is another_entry