From d94421e1a4bae6ecc9663929482553536b284aad Mon Sep 17 00:00:00 2001 From: wittypluck Date: Sun, 14 Jan 2024 15:19:43 +0100 Subject: [PATCH] Reset UniFi bandwidth sensor when client misses heartbeat (#104522) * Reset UniFi bandwidth sensor when client misses heartbeat * Fix initialization sequence * Code simplification: remove heartbeat_timedelta, unique_id and tracker logic * Add unit tests * Remove unused _is_connected attribute * Remove redundant async_initiate_state * Make is_connected_fn optional, heartbeat detection will only happen if not None * Add checks on is_connected_fn --- homeassistant/components/unifi/sensor.py | 61 +++++++++++++++++++++++- tests/components/unifi/test_sensor.py | 28 ++++++++++- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/unifi/sensor.py b/homeassistant/components/unifi/sensor.py index c7b851a8fbb..ef158b99e4e 100644 --- a/homeassistant/components/unifi/sensor.py +++ b/homeassistant/components/unifi/sensor.py @@ -32,7 +32,8 @@ from homeassistant.components.sensor import ( ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import EntityCategory, UnitOfDataRate, UnitOfPower -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event as core_Event, HomeAssistant, callback +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util @@ -132,6 +133,20 @@ def async_device_outlet_supported_fn(controller: UniFiController, obj_id: str) - return controller.api.devices[obj_id].outlet_ac_power_budget is not None +@callback +def async_client_is_connected_fn(controller: UniFiController, obj_id: str) -> bool: + """Check if client was last seen recently.""" + client = controller.api.clients[obj_id] + + if ( + dt_util.utcnow() - dt_util.utc_from_timestamp(client.last_seen or 0) + > controller.option_detection_time + ): + return False + + return True + + @dataclass(frozen=True) class UnifiSensorEntityDescriptionMixin(Generic[HandlerT, ApiItemT]): """Validate and load entities from different UniFi handlers.""" @@ -153,6 +168,8 @@ class UnifiSensorEntityDescription( ): """Class describing UniFi sensor entity.""" + is_connected_fn: Callable[[UniFiController, str], bool] | None = None + ENTITY_DESCRIPTIONS: tuple[UnifiSensorEntityDescription, ...] = ( UnifiSensorEntityDescription[Clients, Client]( @@ -169,6 +186,7 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSensorEntityDescription, ...] = ( device_info_fn=async_client_device_info_fn, event_is_on=None, event_to_subscribe=None, + is_connected_fn=async_client_is_connected_fn, name_fn=lambda _: "RX", object_fn=lambda api, obj_id: api.clients[obj_id], should_poll=False, @@ -190,6 +208,7 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSensorEntityDescription, ...] = ( device_info_fn=async_client_device_info_fn, event_is_on=None, event_to_subscribe=None, + is_connected_fn=async_client_is_connected_fn, name_fn=lambda _: "TX", object_fn=lambda api, obj_id: api.clients[obj_id], should_poll=False, @@ -388,6 +407,16 @@ class UnifiSensorEntity(UnifiEntity[HandlerT, ApiItemT], SensorEntity): entity_description: UnifiSensorEntityDescription[HandlerT, ApiItemT] + @callback + def _make_disconnected(self, *_: core_Event) -> None: + """No heart beat by device. + + Reset sensor value to 0 when client device is disconnected + """ + if self._attr_native_value != 0: + self._attr_native_value = 0 + self.async_write_ha_state() + @callback def async_update_state(self, event: ItemEvent, obj_id: str) -> None: """Update entity state. @@ -398,3 +427,33 @@ class UnifiSensorEntity(UnifiEntity[HandlerT, ApiItemT], SensorEntity): obj = description.object_fn(self.controller.api, self._obj_id) if (value := description.value_fn(self.controller, obj)) != self.native_value: self._attr_native_value = value + + if description.is_connected_fn is not None: + # Send heartbeat if client is connected + if description.is_connected_fn(self.controller, self._obj_id): + self.controller.async_heartbeat( + self._attr_unique_id, + dt_util.utcnow() + self.controller.option_detection_time, + ) + + async def async_added_to_hass(self) -> None: + """Register callbacks.""" + await super().async_added_to_hass() + + if self.entity_description.is_connected_fn is not None: + # Register callback for missed heartbeat + self.async_on_remove( + async_dispatcher_connect( + self.hass, + f"{self.controller.signal_heartbeat_missed}_{self.unique_id}", + self._make_disconnected, + ) + ) + + async def async_will_remove_from_hass(self) -> None: + """Disconnect object when removed.""" + await super().async_will_remove_from_hass() + + if self.entity_description.is_connected_fn is not None: + # Remove heartbeat registration + self.controller.async_heartbeat(self._attr_unique_id) diff --git a/tests/components/unifi/test_sensor.py b/tests/components/unifi/test_sensor.py index 6eb6c05209c..1a3c81ec4c4 100644 --- a/tests/components/unifi/test_sensor.py +++ b/tests/components/unifi/test_sensor.py @@ -5,7 +5,7 @@ from unittest.mock import patch from aiounifi.models.device import DeviceState from aiounifi.models.message import MessageKey -from freezegun.api import FrozenDateTimeFactory +from freezegun.api import FrozenDateTimeFactory, freeze_time import pytest from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN @@ -22,6 +22,7 @@ from homeassistant.components.unifi.const import ( CONF_TRACK_CLIENTS, CONF_TRACK_DEVICES, DEVICE_STATES, + DOMAIN as UNIFI_DOMAIN, ) from homeassistant.config_entries import RELOAD_AFTER_UPDATE_DELAY from homeassistant.const import ATTR_DEVICE_CLASS, STATE_UNAVAILABLE, EntityCategory @@ -393,6 +394,31 @@ async def test_bandwidth_sensors( assert hass.states.get("sensor.wireless_client_rx").state == "3456.0" assert hass.states.get("sensor.wireless_client_tx").state == "7891.0" + # Verify reset sensor after heartbeat expires + + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + new_time = dt_util.utcnow() + wireless_client["last_seen"] = dt_util.as_timestamp(new_time) + + mock_unifi_websocket(message=MessageKey.CLIENT, data=wireless_client) + await hass.async_block_till_done() + + with freeze_time(new_time): + async_fire_time_changed(hass, new_time) + await hass.async_block_till_done() + + assert hass.states.get("sensor.wireless_client_rx").state == "3456.0" + assert hass.states.get("sensor.wireless_client_tx").state == "7891.0" + + new_time = new_time + controller.option_detection_time + timedelta(seconds=1) + + with freeze_time(new_time): + async_fire_time_changed(hass, new_time) + await hass.async_block_till_done() + + assert hass.states.get("sensor.wireless_client_rx").state == "0" + assert hass.states.get("sensor.wireless_client_tx").state == "0" + # Disable option options[CONF_ALLOW_BANDWIDTH_SENSORS] = False