From b195d5d1db2df456effcbc25d8a684f422e6b767 Mon Sep 17 00:00:00 2001 From: Guido Schmitz Date: Fri, 5 May 2023 23:01:57 +0200 Subject: [PATCH] Assemble platforms upfront in devolo Home Network (#80126) * Assemble platforms upfront in devolo Home Network * Add tests * Optimize mocks * Use async_forward_entry_setups * Adapt tests to newly added switch platform --- .../devolo_home_network/__init__.py | 30 +++++++++++++++---- .../devolo_home_network/binary_sensor.py | 15 +++++----- .../components/devolo_home_network/const.py | 8 ----- .../devolo_home_network/device_tracker.py | 9 +++--- .../devolo_home_network/conftest.py | 14 +++++++++ .../devolo_home_network/test_init.py | 30 +++++++++++++++++++ 6 files changed, 79 insertions(+), 27 deletions(-) diff --git a/homeassistant/components/devolo_home_network/__init__.py b/homeassistant/components/devolo_home_network/__init__.py index 5fdb75bb5f9..d2c7b62a399 100644 --- a/homeassistant/components/devolo_home_network/__init__.py +++ b/homeassistant/components/devolo_home_network/__init__.py @@ -20,8 +20,13 @@ from devolo_plc_api.plcnet_api import LogicalNetwork from homeassistant.components import zeroconf from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD, EVENT_HOMEASSISTANT_STOP -from homeassistant.core import Event, HomeAssistant +from homeassistant.const import ( + CONF_IP_ADDRESS, + CONF_PASSWORD, + EVENT_HOMEASSISTANT_STOP, + Platform, +) +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed @@ -32,7 +37,6 @@ from .const import ( DOMAIN, LONG_UPDATE_INTERVAL, NEIGHBORING_WIFI_NETWORKS, - PLATFORMS, SHORT_UPDATE_INTERVAL, SWITCH_GUEST_WIFI, SWITCH_LEDS, @@ -156,7 +160,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: for coordinator in coordinators.values(): await coordinator.async_config_entry_first_refresh() - await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + await hass.config_entries.async_forward_entry_setups(entry, platforms(device)) entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, disconnect) @@ -167,9 +171,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + device: Device = hass.data[DOMAIN][entry.entry_id]["device"] + unload_ok = await hass.config_entries.async_unload_platforms( + entry, platforms(device) + ) if unload_ok: - await hass.data[DOMAIN][entry.entry_id]["device"].async_disconnect() + await device.async_disconnect() hass.data[DOMAIN].pop(entry.entry_id) return unload_ok + + +@callback +def platforms(device: Device) -> set[Platform]: + """Assemble supported platforms.""" + supported_platforms = {Platform.SENSOR, Platform.SWITCH} + if device.plcnet: + supported_platforms.add(Platform.BINARY_SENSOR) + if device.device and "wifi1" in device.device.features: + supported_platforms.add(Platform.DEVICE_TRACKER) + return supported_platforms diff --git a/homeassistant/components/devolo_home_network/binary_sensor.py b/homeassistant/components/devolo_home_network/binary_sensor.py index 809dc9086be..b8f2551a891 100644 --- a/homeassistant/components/devolo_home_network/binary_sensor.py +++ b/homeassistant/components/devolo_home_network/binary_sensor.py @@ -68,15 +68,14 @@ async def async_setup_entry( ]["coordinators"] entities: list[BinarySensorEntity] = [] - if device.plcnet: - entities.append( - DevoloBinarySensorEntity( - entry, - coordinators[CONNECTED_PLC_DEVICES], - SENSOR_TYPES[CONNECTED_TO_ROUTER], - device, - ) + entities.append( + DevoloBinarySensorEntity( + entry, + coordinators[CONNECTED_PLC_DEVICES], + SENSOR_TYPES[CONNECTED_TO_ROUTER], + device, ) + ) async_add_entities(entities) diff --git a/homeassistant/components/devolo_home_network/const.py b/homeassistant/components/devolo_home_network/const.py index fffe9b5d482..193a0dc9a15 100644 --- a/homeassistant/components/devolo_home_network/const.py +++ b/homeassistant/components/devolo_home_network/const.py @@ -9,15 +9,7 @@ from devolo_plc_api.device_api import ( WIFI_VAP_MAIN_AP, ) -from homeassistant.const import Platform - DOMAIN = "devolo_home_network" -PLATFORMS = [ - Platform.BINARY_SENSOR, - Platform.DEVICE_TRACKER, - Platform.SENSOR, - Platform.SWITCH, -] PRODUCT = "product" SERIAL_NUMBER = "serial_number" diff --git a/homeassistant/components/devolo_home_network/device_tracker.py b/homeassistant/components/devolo_home_network/device_tracker.py index eb6e9cf6ec6..c73e08abed2 100644 --- a/homeassistant/components/devolo_home_network/device_tracker.py +++ b/homeassistant/components/devolo_home_network/device_tracker.py @@ -73,11 +73,10 @@ async def async_setup_entry( async_add_entities(missing) - if device.device and "wifi1" in device.device.features: - restore_entities() - entry.async_on_unload( - coordinators[CONNECTED_WIFI_CLIENTS].async_add_listener(new_device_callback) - ) + restore_entities() + entry.async_on_unload( + coordinators[CONNECTED_WIFI_CLIENTS].async_add_listener(new_device_callback) + ) class DevoloScannerEntity( diff --git a/tests/components/devolo_home_network/conftest.py b/tests/components/devolo_home_network/conftest.py index 193b0d700ee..1eb91f7a48f 100644 --- a/tests/components/devolo_home_network/conftest.py +++ b/tests/components/devolo_home_network/conftest.py @@ -19,6 +19,20 @@ def mock_device(): yield device +@pytest.fixture() +def mock_repeater_device(mock_device: MockDevice): + """Mock connecting to a devolo home network repeater device.""" + mock_device.plcnet = None + yield mock_device + + +@pytest.fixture() +def mock_nonwifi_device(mock_device: MockDevice): + """Mock connecting to a devolo home network device without wifi.""" + mock_device.device.features = ["reset", "update", "led", "intmtg"] + yield mock_device + + @pytest.fixture(name="info") def mock_validate_input(): """Mock setup entry and user input.""" diff --git a/tests/components/devolo_home_network/test_init.py b/tests/components/devolo_home_network/test_init.py index 1b24c21f2bc..536053e149a 100644 --- a/tests/components/devolo_home_network/test_init.py +++ b/tests/components/devolo_home_network/test_init.py @@ -4,10 +4,15 @@ from unittest.mock import patch from devolo_plc_api.exceptions.device import DeviceNotFound import pytest +from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR +from homeassistant.components.device_tracker import DOMAIN as DEVICE_TRACKER from homeassistant.components.devolo_home_network.const import DOMAIN +from homeassistant.components.sensor import DOMAIN as SENSOR +from homeassistant.components.switch import DOMAIN as SWITCH from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_IP_ADDRESS, EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import async_get_platforms from . import configure_integration from .const import IP @@ -73,3 +78,28 @@ async def test_hass_stop(hass: HomeAssistant, mock_device: MockDevice) -> None: hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() mock_device.async_disconnect.assert_called_once() + + +@pytest.mark.parametrize( + "device, expected_platforms", + [ + ["mock_device", (BINARY_SENSOR, DEVICE_TRACKER, SENSOR, SWITCH)], + ["mock_repeater_device", (DEVICE_TRACKER, SENSOR, SWITCH)], + ["mock_nonwifi_device", (BINARY_SENSOR, SENSOR, SWITCH)], + ], +) +async def test_platforms( + hass: HomeAssistant, + device: str, + expected_platforms: set[str], + request: pytest.FixtureRequest, +): + """Test platform assembly.""" + request.getfixturevalue(device) + entry = configure_integration(hass) + + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + platforms = [platform.domain for platform in async_get_platforms(hass, DOMAIN)] + assert len(platforms) == len(expected_platforms) + assert all(platform in platforms for platform in expected_platforms)