Reinitialize upnp device on config change (#49081)
* Store coordinator at Device * Use DeviceUpdater to follow config/location changes * Cleaning up * Fix unit tests + review changes * Don't test internalspull/49230/head
parent
ed54494b69
commit
555f508b8c
|
@ -21,7 +21,6 @@ from .const import (
|
|||
DISCOVERY_UDN,
|
||||
DOMAIN,
|
||||
DOMAIN_CONFIG,
|
||||
DOMAIN_COORDINATORS,
|
||||
DOMAIN_DEVICES,
|
||||
DOMAIN_LOCAL_IP,
|
||||
LOGGER as _LOGGER,
|
||||
|
@ -75,7 +74,6 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||
local_ip = await hass.async_add_executor_job(get_local_ip)
|
||||
hass.data[DOMAIN] = {
|
||||
DOMAIN_CONFIG: conf,
|
||||
DOMAIN_COORDINATORS: {},
|
||||
DOMAIN_DEVICES: {},
|
||||
DOMAIN_LOCAL_IP: conf.get(CONF_LOCAL_IP, local_ip),
|
||||
}
|
||||
|
@ -149,6 +147,9 @@ async def async_setup_entry(hass: HomeAssistantType, config_entry: ConfigEntry)
|
|||
hass.config_entries.async_forward_entry_setup(config_entry, "sensor")
|
||||
)
|
||||
|
||||
# Start device updater.
|
||||
await device.async_start()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
@ -160,9 +161,10 @@ async def async_unload_entry(
|
|||
|
||||
udn = config_entry.data.get(CONFIG_ENTRY_UDN)
|
||||
if udn in hass.data[DOMAIN][DOMAIN_DEVICES]:
|
||||
device = hass.data[DOMAIN][DOMAIN_DEVICES][udn]
|
||||
await device.async_stop()
|
||||
|
||||
del hass.data[DOMAIN][DOMAIN_DEVICES][udn]
|
||||
if udn in hass.data[DOMAIN][DOMAIN_COORDINATORS]:
|
||||
del hass.data[DOMAIN][DOMAIN_COORDINATORS][udn]
|
||||
|
||||
_LOGGER.debug("Deleting sensors")
|
||||
return await hass.config_entries.async_forward_entry_unload(config_entry, "sensor")
|
||||
|
|
|
@ -25,7 +25,7 @@ from .const import (
|
|||
DISCOVERY_UNIQUE_ID,
|
||||
DISCOVERY_USN,
|
||||
DOMAIN,
|
||||
DOMAIN_COORDINATORS,
|
||||
DOMAIN_DEVICES,
|
||||
LOGGER as _LOGGER,
|
||||
)
|
||||
from .device import Device
|
||||
|
@ -252,7 +252,7 @@ class UpnpOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
"""Manage the options."""
|
||||
if user_input is not None:
|
||||
udn = self.config_entry.data[CONFIG_ENTRY_UDN]
|
||||
coordinator = self.hass.data[DOMAIN][DOMAIN_COORDINATORS][udn]
|
||||
coordinator = self.hass.data[DOMAIN][DOMAIN_DEVICES][udn].coordinator
|
||||
update_interval_sec = user_input.get(
|
||||
CONFIG_ENTRY_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
|
||||
)
|
||||
|
|
|
@ -9,7 +9,6 @@ LOGGER = logging.getLogger(__package__)
|
|||
CONF_LOCAL_IP = "local_ip"
|
||||
DOMAIN = "upnp"
|
||||
DOMAIN_CONFIG = "config"
|
||||
DOMAIN_COORDINATORS = "coordinators"
|
||||
DOMAIN_DEVICES = "devices"
|
||||
DOMAIN_LOCAL_IP = "local_ip"
|
||||
BYTES_RECEIVED = "bytes_received"
|
||||
|
|
|
@ -8,10 +8,12 @@ from urllib.parse import urlparse
|
|||
|
||||
from async_upnp_client import UpnpFactory
|
||||
from async_upnp_client.aiohttp import AiohttpSessionRequester
|
||||
from async_upnp_client.device_updater import DeviceUpdater
|
||||
from async_upnp_client.profiles.igd import IgdDevice
|
||||
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .const import (
|
||||
|
@ -34,23 +36,29 @@ from .const import (
|
|||
)
|
||||
|
||||
|
||||
def _get_local_ip(hass: HomeAssistantType) -> IPv4Address | None:
|
||||
"""Get the configured local ip."""
|
||||
if DOMAIN in hass.data and DOMAIN_CONFIG in hass.data[DOMAIN]:
|
||||
local_ip = hass.data[DOMAIN][DOMAIN_CONFIG].get(CONF_LOCAL_IP)
|
||||
if local_ip:
|
||||
return IPv4Address(local_ip)
|
||||
return None
|
||||
|
||||
|
||||
class Device:
|
||||
"""Home Assistant representation of a UPnP/IGD device."""
|
||||
|
||||
def __init__(self, igd_device):
|
||||
def __init__(self, igd_device: IgdDevice, device_updater: DeviceUpdater) -> None:
|
||||
"""Initialize UPnP/IGD device."""
|
||||
self._igd_device: IgdDevice = igd_device
|
||||
self._igd_device = igd_device
|
||||
self._device_updater = device_updater
|
||||
self.coordinator: DataUpdateCoordinator = None
|
||||
|
||||
@classmethod
|
||||
async def async_discover(cls, hass: HomeAssistantType) -> list[Mapping]:
|
||||
"""Discover UPnP/IGD devices."""
|
||||
_LOGGER.debug("Discovering UPnP/IGD devices")
|
||||
local_ip = None
|
||||
if DOMAIN in hass.data and DOMAIN_CONFIG in hass.data[DOMAIN]:
|
||||
local_ip = hass.data[DOMAIN][DOMAIN_CONFIG].get(CONF_LOCAL_IP)
|
||||
if local_ip:
|
||||
local_ip = IPv4Address(local_ip)
|
||||
|
||||
local_ip = _get_local_ip(hass)
|
||||
discoveries = await IgdDevice.async_search(source_ip=local_ip, timeout=10)
|
||||
|
||||
# Supplement/standardize discovery.
|
||||
|
@ -81,17 +89,32 @@ class Device:
|
|||
cls, hass: HomeAssistantType, ssdp_location: str
|
||||
) -> Device:
|
||||
"""Create UPnP/IGD device."""
|
||||
# build async_upnp_client requester
|
||||
# Build async_upnp_client requester.
|
||||
session = async_get_clientsession(hass)
|
||||
requester = AiohttpSessionRequester(session, True, 10)
|
||||
|
||||
# create async_upnp_client device
|
||||
# Create async_upnp_client device.
|
||||
factory = UpnpFactory(requester, disable_state_variable_validation=True)
|
||||
upnp_device = await factory.async_create_device(ssdp_location)
|
||||
|
||||
# Create profile wrapper.
|
||||
igd_device = IgdDevice(upnp_device, None)
|
||||
|
||||
return cls(igd_device)
|
||||
# Create updater.
|
||||
local_ip = _get_local_ip(hass)
|
||||
device_updater = DeviceUpdater(
|
||||
device=upnp_device, factory=factory, source_ip=local_ip
|
||||
)
|
||||
|
||||
return cls(igd_device, device_updater)
|
||||
|
||||
async def async_start(self) -> None:
|
||||
"""Start the device updater."""
|
||||
await self._device_updater.async_start()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop the device updater."""
|
||||
await self._device_updater.async_stop()
|
||||
|
||||
@property
|
||||
def udn(self) -> str:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, Callable, Mapping
|
||||
|
||||
from homeassistant.components.sensor import SensorEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
@ -23,7 +23,6 @@ from .const import (
|
|||
DATA_RATE_PACKETS_PER_SECOND,
|
||||
DEFAULT_SCAN_INTERVAL,
|
||||
DOMAIN,
|
||||
DOMAIN_COORDINATORS,
|
||||
DOMAIN_DEVICES,
|
||||
KIBIBYTE,
|
||||
LOGGER as _LOGGER,
|
||||
|
@ -83,7 +82,7 @@ async def async_setup_platform(
|
|||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass, config_entry: ConfigEntry, async_add_entities
|
||||
hass: HomeAssistantType, config_entry: ConfigEntry, async_add_entities: Callable
|
||||
) -> None:
|
||||
"""Set up the UPnP/IGD sensors."""
|
||||
udn = config_entry.data[CONFIG_ENTRY_UDN]
|
||||
|
@ -102,8 +101,9 @@ async def async_setup_entry(
|
|||
update_method=device.async_get_traffic_data,
|
||||
update_interval=update_interval,
|
||||
)
|
||||
device.coordinator = coordinator
|
||||
|
||||
await coordinator.async_refresh()
|
||||
hass.data[DOMAIN][DOMAIN_COORDINATORS][udn] = coordinator
|
||||
|
||||
sensors = [
|
||||
RawUpnpSensor(coordinator, device, SENSOR_TYPES[BYTES_RECEIVED]),
|
||||
|
@ -126,14 +126,11 @@ class UpnpSensor(CoordinatorEntity, SensorEntity):
|
|||
coordinator: DataUpdateCoordinator[Mapping[str, Any]],
|
||||
device: Device,
|
||||
sensor_type: Mapping[str, str],
|
||||
update_multiplier: int = 2,
|
||||
) -> None:
|
||||
"""Initialize the base sensor."""
|
||||
super().__init__(coordinator)
|
||||
self._device = device
|
||||
self._sensor_type = sensor_type
|
||||
self._update_counter_max = update_multiplier
|
||||
self._update_counter = 0
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Mock device for testing purposes."""
|
||||
|
||||
from typing import Mapping
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from homeassistant.components.upnp.const import (
|
||||
BYTES_RECEIVED,
|
||||
|
@ -10,7 +11,7 @@ from homeassistant.components.upnp.const import (
|
|||
TIMESTAMP,
|
||||
)
|
||||
from homeassistant.components.upnp.device import Device
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util import dt
|
||||
|
||||
|
||||
class MockDevice(Device):
|
||||
|
@ -19,8 +20,10 @@ class MockDevice(Device):
|
|||
def __init__(self, udn: str) -> None:
|
||||
"""Initialize mock device."""
|
||||
igd_device = object()
|
||||
super().__init__(igd_device)
|
||||
mock_device_updater = AsyncMock()
|
||||
super().__init__(igd_device, mock_device_updater)
|
||||
self._udn = udn
|
||||
self.times_polled = 0
|
||||
|
||||
@classmethod
|
||||
async def async_create_device(cls, hass, ssdp_location) -> "MockDevice":
|
||||
|
@ -59,8 +62,9 @@ class MockDevice(Device):
|
|||
|
||||
async def async_get_traffic_data(self) -> Mapping[str, any]:
|
||||
"""Get traffic data."""
|
||||
self.times_polled += 1
|
||||
return {
|
||||
TIMESTAMP: dt_util.utcnow(),
|
||||
TIMESTAMP: dt.utcnow(),
|
||||
BYTES_RECEIVED: 0,
|
||||
BYTES_SENT: 0,
|
||||
PACKETS_RECEIVED: 0,
|
||||
|
|
|
@ -19,15 +19,15 @@ from homeassistant.components.upnp.const import (
|
|||
DISCOVERY_UNIQUE_ID,
|
||||
DISCOVERY_USN,
|
||||
DOMAIN,
|
||||
DOMAIN_COORDINATORS,
|
||||
)
|
||||
from homeassistant.components.upnp.device import Device
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt
|
||||
|
||||
from .mock_device import MockDevice
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.common import MockConfigEntry, async_fire_time_changed
|
||||
|
||||
|
||||
async def test_flow_ssdp_discovery(hass: HomeAssistantType):
|
||||
|
@ -325,10 +325,12 @@ async def test_options_flow(hass: HomeAssistantType):
|
|||
# Initialisation of component.
|
||||
await async_setup_component(hass, "upnp", config)
|
||||
await hass.async_block_till_done()
|
||||
mock_device.times_polled = 0 # Reset.
|
||||
|
||||
# DataUpdateCoordinator gets a default of 30 seconds for updates.
|
||||
coordinator = hass.data[DOMAIN][DOMAIN_COORDINATORS][mock_device.udn]
|
||||
assert coordinator.update_interval == timedelta(seconds=DEFAULT_SCAN_INTERVAL)
|
||||
# Forward time, ensure single poll after 30 (default) seconds.
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=31))
|
||||
await hass.async_block_till_done()
|
||||
assert mock_device.times_polled == 1
|
||||
|
||||
# Options flow with no input results in form.
|
||||
result = await hass.config_entries.options.async_init(
|
||||
|
@ -346,5 +348,18 @@ async def test_options_flow(hass: HomeAssistantType):
|
|||
CONFIG_ENTRY_SCAN_INTERVAL: 60,
|
||||
}
|
||||
|
||||
# Also updates DataUpdateCoordinator.
|
||||
assert coordinator.update_interval == timedelta(seconds=60)
|
||||
# Forward time, ensure single poll after 60 seconds, still from original setting.
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=61))
|
||||
await hass.async_block_till_done()
|
||||
assert mock_device.times_polled == 2
|
||||
|
||||
# Now the updated interval takes effect.
|
||||
# Forward time, ensure single poll after 120 seconds.
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=121))
|
||||
await hass.async_block_till_done()
|
||||
assert mock_device.times_polled == 3
|
||||
|
||||
# Forward time, ensure single poll after 180 seconds.
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=181))
|
||||
await hass.async_block_till_done()
|
||||
assert mock_device.times_polled == 4
|
||||
|
|
Loading…
Reference in New Issue