diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index 1f3fd48140e..3cc786f7c86 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -9,12 +9,7 @@ from typing import Any from aiohttp import CookieJar import aiounifi -from aiounifi.interfaces.messages import ( - DATA_CLIENT_REMOVED, - DATA_DPI_GROUP, - DATA_DPI_GROUP_REMOVED, - DATA_EVENT, -) +from aiounifi.interfaces.messages import DATA_CLIENT_REMOVED, DATA_EVENT from aiounifi.models.event import EventKey from aiounifi.websocket import WebsocketSignal, WebsocketState import async_timeout @@ -247,14 +242,6 @@ class UniFiController: self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED] ) - elif DATA_DPI_GROUP in data: - async_dispatcher_send(self.hass, self.signal_update) - - elif DATA_DPI_GROUP_REMOVED in data: - async_dispatcher_send( - self.hass, self.signal_remove, data[DATA_DPI_GROUP_REMOVED] - ) - @property def signal_reachable(self) -> str: """Integration specific event to signal a change in connection status.""" diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index 5e129cc402e..ed55b155559 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -33,7 +33,6 @@ from homeassistant.helpers.restore_state import RestoreEntity from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN from .unifi_client import UniFiClient -from .unifi_entity_base import UniFiBase BLOCK_SWITCH = "block" DPI_SWITCH = "dpi" @@ -88,7 +87,7 @@ async def async_setup_entry( @callback def items_added( clients: set = controller.api.clients, - dpi_groups: set = controller.api.dpi_groups, + devices: set = controller.api.devices, ) -> None: """Update the values of the controller.""" if controller.option_block_clients: @@ -97,9 +96,6 @@ async def async_setup_entry( if controller.option_poe_clients: add_poe_entities(controller, async_add_entities, clients, known_poe_clients) - if controller.option_dpi_restrictions: - add_dpi_entities(controller, async_add_entities, dpi_groups) - for signal in (controller.signal_update, controller.signal_options_update): config_entry.async_on_unload( async_dispatcher_connect(hass, signal, items_added) @@ -120,6 +116,20 @@ async def async_setup_entry( for index in controller.api.outlets: async_add_outlet_switch(ItemEvent.ADDED, index) + def async_add_dpi_switch(_: ItemEvent, obj_id: str) -> None: + """Add DPI switch from UniFi controller.""" + if ( + not controller.option_dpi_restrictions + or not controller.api.dpi_groups[obj_id].dpiapp_ids + ): + return + async_add_entities([UnifiDPIRestrictionSwitch(obj_id, controller)]) + + controller.api.ports.subscribe(async_add_dpi_switch, ItemEvent.ADDED) + + for dpi_group_id in controller.api.dpi_groups: + async_add_dpi_switch(ItemEvent.ADDED, dpi_group_id) + @callback def async_add_poe_switch(_: ItemEvent, obj_id: str) -> None: """Add port PoE switch from UniFi controller.""" @@ -198,23 +208,6 @@ def add_poe_entities(controller, async_add_entities, clients, known_poe_clients) async_add_entities(switches) -@callback -def add_dpi_entities(controller, async_add_entities, dpi_groups): - """Add new switch entities from the controller.""" - switches = [] - - for group in dpi_groups: - if ( - group in controller.entities[DOMAIN][DPI_SWITCH] - or not dpi_groups[group].dpiapp_ids - ): - continue - - switches.append(UniFiDPIRestrictionSwitch(dpi_groups[group], controller)) - - async_add_entities(switches) - - class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity): """Representation of a client that uses POE.""" @@ -367,132 +360,139 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchEntity): await self.remove_item({self.client.mac}) -class UniFiDPIRestrictionSwitch(UniFiBase, SwitchEntity): +class UnifiDPIRestrictionSwitch(SwitchEntity): """Representation of a DPI restriction group.""" - DOMAIN = DOMAIN - TYPE = DPI_SWITCH - _attr_entity_category = EntityCategory.CONFIG - def __init__(self, dpi_group, controller): + def __init__(self, obj_id: str, controller): """Set up dpi switch.""" - super().__init__(dpi_group, controller) + controller.entities[DOMAIN][DPI_SWITCH].add(obj_id) + self._obj_id = obj_id + self.controller = controller - self._is_enabled = self.calculate_enabled() + dpi_group = controller.api.dpi_groups[obj_id] self._known_app_ids = dpi_group.dpiapp_ids - @property - def key(self) -> Any: - """Return item key.""" - return self._item.id + self._attr_available = controller.available + self._attr_is_on = self.calculate_enabled() + self._attr_name = dpi_group.name + self._attr_unique_id = dpi_group.id + self._attr_device_info = DeviceInfo( + entry_type=DeviceEntryType.SERVICE, + identifiers={(DOMAIN, f"unifi_controller_{obj_id}")}, + manufacturer=ATTR_MANUFACTURER, + model="UniFi Network", + name="UniFi Network", + ) async def async_added_to_hass(self) -> None: """Register callback to known apps.""" - await super().async_added_to_hass() - - apps = self.controller.api.dpi_apps - for app_id in self._item.dpiapp_ids: - apps[app_id].register_callback(self.async_update_callback) + self.async_on_remove( + self.controller.api.dpi_groups.subscribe(self.async_signalling_callback) + ) + self.async_on_remove( + self.controller.api.dpi_apps.subscribe( + self.async_signalling_callback, ItemEvent.CHANGED + ), + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, self.controller.signal_remove, self.remove_item + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, self.controller.signal_options_update, self.options_updated + ) + ) + self.async_on_remove( + async_dispatcher_connect( + self.hass, + self.controller.signal_reachable, + self.async_signal_reachable_callback, + ) + ) async def async_will_remove_from_hass(self) -> None: - """Remove registered callbacks.""" - apps = self.controller.api.dpi_apps - for app_id in self._item.dpiapp_ids: - apps[app_id].remove_callback(self.async_update_callback) - - await super().async_will_remove_from_hass() + """Disconnect object when removed.""" + self.controller.entities[DOMAIN][DPI_SWITCH].remove(self._obj_id) @callback - def async_update_callback(self) -> None: - """Update the DPI switch state. - - Remove entity when no apps are paired with group. - Register callbacks to new apps. - Calculate and update entity state if it has changed. - """ - if not self._item.dpiapp_ids: - self.hass.async_create_task(self.remove_item({self.key})) + def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None: + """Object has new event.""" + if event == ItemEvent.DELETED: + self.hass.async_create_task(self.remove_item({self._obj_id})) return - if self._known_app_ids != self._item.dpiapp_ids: - self._known_app_ids = self._item.dpiapp_ids + dpi_group = self.controller.api.dpi_groups[self._obj_id] + if not dpi_group.dpiapp_ids: + self.hass.async_create_task(self.remove_item({self._obj_id})) + return - apps = self.controller.api.dpi_apps - for app_id in self._item.dpiapp_ids: - apps[app_id].register_callback(self.async_update_callback) + self._attr_available = self.controller.available + self._attr_is_on = self.calculate_enabled() + self.async_write_ha_state() - if (enabled := self.calculate_enabled()) != self._is_enabled: - self._is_enabled = enabled - super().async_update_callback() - - @property - def unique_id(self): - """Return a unique identifier for this switch.""" - return self._item.id - - @property - def name(self) -> str: - """Return the name of the DPI group.""" - return self._item.name + @callback + def async_signal_reachable_callback(self) -> None: + """Call when controller connection state change.""" + self.async_signalling_callback(ItemEvent.ADDED, self._obj_id) @property def icon(self): """Return the icon to use in the frontend.""" - if self._is_enabled: + if self._attr_is_on: return "mdi:network" return "mdi:network-off" def calculate_enabled(self) -> bool: """Calculate if all apps are enabled.""" + dpi_group = self.controller.api.dpi_groups[self._obj_id] return all( self.controller.api.dpi_apps[app_id].enabled - for app_id in self._item.dpiapp_ids + for app_id in dpi_group.dpiapp_ids if app_id in self.controller.api.dpi_apps ) - @property - def is_on(self): - """Return true if DPI group app restriction is enabled.""" - return self._is_enabled - async def async_turn_on(self, **kwargs: Any) -> None: """Restrict access of apps related to DPI group.""" + dpi_group = self.controller.api.dpi_groups[self._obj_id] return await asyncio.gather( *[ self.controller.api.request( DPIRestrictionAppEnableRequest.create(app_id, True) ) - for app_id in self._item.dpiapp_ids + for app_id in dpi_group.dpiapp_ids ] ) async def async_turn_off(self, **kwargs: Any) -> None: """Remove restriction of apps related to DPI group.""" + dpi_group = self.controller.api.dpi_groups[self._obj_id] return await asyncio.gather( *[ self.controller.api.request( DPIRestrictionAppEnableRequest.create(app_id, False) ) - for app_id in self._item.dpiapp_ids + for app_id in dpi_group.dpiapp_ids ] ) async def options_updated(self) -> None: """Config entry options are updated, remove entity if option is disabled.""" if not self.controller.option_dpi_restrictions: - await self.remove_item({self.key}) + await self.remove_item({self._attr_unique_id}) - @property - def device_info(self) -> DeviceInfo: - """Return a service description for device registry.""" - return DeviceInfo( - entry_type=DeviceEntryType.SERVICE, - identifiers={(DOMAIN, f"unifi_controller_{self._item.site_id}")}, - manufacturer=ATTR_MANUFACTURER, - model="UniFi Network", - name="UniFi Network", - ) + async def remove_item(self, keys: set) -> None: + """Remove entity if key is part of set.""" + if self._attr_unique_id not in keys: + return + + if self.registry_entry: + er.async_get(self.hass).async_remove(self.entity_id) + else: + await self.async_remove(force_remove=True) class UnifiOutletSwitch(SwitchEntity): diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index 12ef6f9b965..db0b358179c 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -761,7 +761,6 @@ async def test_remove_switches(hass, aioclient_mock, mock_unifi_websocket): mock_unifi_websocket(data=DPI_GROUP_REMOVED_EVENT) await hass.async_block_till_done() - await hass.async_block_till_done() assert hass.states.get("switch.block_media_streaming") is None assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 @@ -852,10 +851,21 @@ async def test_dpi_switches(hass, aioclient_mock, mock_unifi_websocket): assert hass.states.get("switch.block_media_streaming").state == STATE_OFF + # Availability signalling + + # Controller disconnects + mock_unifi_websocket(state=WebsocketState.DISCONNECTED) + await hass.async_block_till_done() + assert hass.states.get("switch.block_media_streaming").state == STATE_UNAVAILABLE + + # Controller reconnects + mock_unifi_websocket(state=WebsocketState.RUNNING) + await hass.async_block_till_done() + assert hass.states.get("switch.block_media_streaming").state == STATE_OFF + + # Remove app mock_unifi_websocket(data=DPI_GROUP_REMOVE_APP) await hass.async_block_till_done() - await hass.async_block_till_done() - await hass.async_block_till_done() assert hass.states.get("switch.block_media_streaming") is None assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0