UniFi - refactor entity management (#34367)

* Move removal of sensor entities into a base class

* Fix martins comments on sensors

* Reflect sensor changes on device_tracker platform

* Reflect sensor changes on switch platform

* Improve layering

* Make sure to clean up entity and device registry when removing entities

* Fix martins comments
pull/34440/head
Robert Svensson 2020-04-19 21:30:06 +02:00 committed by GitHub
parent a80ce60e75
commit e5a861dc90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 252 additions and 381 deletions

View File

@ -72,6 +72,8 @@ class UniFiController:
self._site_name = None self._site_name = None
self._site_role = None self._site_role = None
self.entities = {}
@property @property
def controller_id(self): def controller_id(self):
"""Return the controller ID.""" """Return the controller ID."""

View File

@ -1,10 +1,11 @@
"""Track devices using UniFi controllers.""" """Track devices using UniFi controllers."""
import logging import logging
from homeassistant.components.device_tracker import DOMAIN as DEVICE_TRACKER_DOMAIN from homeassistant.components.device_tracker import DOMAIN
from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.config_entry import ScannerEntity
from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER
from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.components.unifi.config_flow import get_controller_from_config_entry
from homeassistant.components.unifi.unifi_entity_base import UniFiBase
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -38,30 +39,26 @@ CLIENT_STATIC_ATTRIBUTES = [
"oui", "oui",
] ]
CLIENT_TRACKER = "client"
DEVICE_TRACKER = "device"
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up device tracker for UniFi component.""" """Set up device tracker for UniFi component."""
controller = get_controller_from_config_entry(hass, config_entry) controller = get_controller_from_config_entry(hass, config_entry)
tracked = {} controller.entities[DOMAIN] = {CLIENT_TRACKER: set(), DEVICE_TRACKER: set()}
option_track_clients = controller.option_track_clients
option_track_devices = controller.option_track_devices
option_track_wired_clients = controller.option_track_wired_clients
option_ssid_filter = controller.option_ssid_filter
entity_registry = await hass.helpers.entity_registry.async_get_registry()
# Restore clients that is not a part of active clients list. # Restore clients that is not a part of active clients list.
entity_registry = await hass.helpers.entity_registry.async_get_registry()
for entity in entity_registry.entities.values(): for entity in entity_registry.entities.values():
if ( if (
entity.config_entry_id == config_entry.entry_id entity.config_entry_id == config_entry.entry_id
and entity.domain == DEVICE_TRACKER_DOMAIN and entity.domain == DOMAIN
and "-" in entity.unique_id and "-" in entity.unique_id
): ):
mac, _ = entity.unique_id.split("-", 1) mac, _ = entity.unique_id.split("-", 1)
if mac in controller.api.clients or mac not in controller.api.clients_all: if mac in controller.api.clients or mac not in controller.api.clients_all:
continue continue
@ -74,99 +71,19 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
@callback @callback
def items_added(): def items_added():
"""Update the values of the controller.""" """Update the values of the controller."""
nonlocal option_track_clients if controller.option_track_clients or controller.option_track_devices:
nonlocal option_track_devices add_entities(controller, async_add_entities)
if not option_track_clients and not option_track_devices: for signal in (controller.signal_update, controller.signal_options_update):
return controller.listeners.append(async_dispatcher_connect(hass, signal, items_added))
add_entities(controller, async_add_entities, tracked)
controller.listeners.append(
async_dispatcher_connect(hass, controller.signal_update, items_added)
)
@callback
def items_removed(mac_addresses: set) -> None:
"""Items have been removed from the controller."""
remove_entities(controller, mac_addresses, tracked, entity_registry)
controller.listeners.append(
async_dispatcher_connect(hass, controller.signal_remove, items_removed)
)
@callback
def options_updated():
"""Manage entities affected by config entry options."""
nonlocal option_track_clients
nonlocal option_track_devices
nonlocal option_track_wired_clients
nonlocal option_ssid_filter
update = False
remove = set()
for current_option, config_entry_option, tracker_class in (
(option_track_clients, controller.option_track_clients, UniFiClientTracker),
(option_track_devices, controller.option_track_devices, UniFiDeviceTracker),
):
if current_option == config_entry_option:
continue
if config_entry_option:
update = True
else:
for mac, entity in tracked.items():
if isinstance(entity, tracker_class):
remove.add(mac)
if (
controller.option_track_clients
and option_track_wired_clients != controller.option_track_wired_clients
):
if controller.option_track_wired_clients:
update = True
else:
for mac, entity in tracked.items():
if isinstance(entity, UniFiClientTracker) and entity.is_wired:
remove.add(mac)
if option_ssid_filter != controller.option_ssid_filter:
update = True
if controller.option_ssid_filter:
for mac, entity in tracked.items():
if (
isinstance(entity, UniFiClientTracker)
and not entity.is_wired
and entity.client.essid not in controller.option_ssid_filter
):
remove.add(mac)
option_track_clients = controller.option_track_clients
option_track_devices = controller.option_track_devices
option_track_wired_clients = controller.option_track_wired_clients
option_ssid_filter = controller.option_ssid_filter
remove_entities(controller, remove, tracked, entity_registry)
if update:
items_added()
controller.listeners.append(
async_dispatcher_connect(
hass, controller.signal_options_update, options_updated
)
)
items_added() items_added()
@callback @callback
def add_entities(controller, async_add_entities, tracked): def add_entities(controller, async_add_entities):
"""Add new tracker entities from the controller.""" """Add new tracker entities from the controller."""
new_tracked = [] trackers = []
for items, tracker_class, track in ( for items, tracker_class, track in (
(controller.api.clients, UniFiClientTracker, controller.option_track_clients), (controller.api.clients, UniFiClientTracker, controller.option_track_clients),
@ -175,46 +92,36 @@ def add_entities(controller, async_add_entities, tracked):
if not track: if not track:
continue continue
for item_id in items: for mac in items:
if item_id in tracked: if mac in controller.entities[DOMAIN][tracker_class.TYPE]:
continue continue
item = items[mac]
if tracker_class is UniFiClientTracker: if tracker_class is UniFiClientTracker:
client = items[item_id]
if not controller.option_track_wired_clients and client.is_wired: if item.is_wired:
continue if not controller.option_track_wired_clients:
continue
else:
if (
controller.option_ssid_filter
and item.essid not in controller.option_ssid_filter
):
continue
if ( trackers.append(tracker_class(item, controller))
controller.option_ssid_filter
and not client.is_wired
and client.essid not in controller.option_ssid_filter
):
continue
tracked[item_id] = tracker_class(items[item_id], controller) if trackers:
new_tracked.append(tracked[item_id]) async_add_entities(trackers)
if new_tracked:
async_add_entities(new_tracked)
@callback
def remove_entities(controller, mac_addresses, tracked, entity_registry):
"""Remove select tracked entities."""
for mac in mac_addresses:
if mac not in tracked:
continue
entity = tracked.pop(mac)
controller.hass.async_create_task(entity.async_remove())
class UniFiClientTracker(UniFiClient, ScannerEntity): class UniFiClientTracker(UniFiClient, ScannerEntity):
"""Representation of a network client.""" """Representation of a network client."""
TYPE = CLIENT_TRACKER
def __init__(self, client, controller): def __init__(self, client, controller):
"""Set up tracked client.""" """Set up tracked client."""
super().__init__(client, controller) super().__init__(client, controller)
@ -315,34 +222,52 @@ class UniFiClientTracker(UniFiClient, ScannerEntity):
return attributes return attributes
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_track_clients:
await self.async_remove()
class UniFiDeviceTracker(ScannerEntity): elif self.is_wired:
if not self.controller.option_track_wired_clients:
await self.async_remove()
else:
if (
self.controller.option_ssid_filter
and self.client.essid not in self.controller.option_ssid_filter
):
await self.async_remove()
class UniFiDeviceTracker(UniFiBase, ScannerEntity):
"""Representation of a network infrastructure device.""" """Representation of a network infrastructure device."""
TYPE = DEVICE_TRACKER
def __init__(self, device, controller): def __init__(self, device, controller):
"""Set up tracked device.""" """Set up tracked device."""
super().__init__(controller)
self.device = device self.device = device
self.controller = controller
@property
def mac(self):
"""Return MAC of device."""
return self.device.mac
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to device events.""" """Subscribe to device events."""
await super().async_added_to_hass()
LOGGER.debug("New device %s (%s)", self.entity_id, self.device.mac) LOGGER.debug("New device %s (%s)", self.entity_id, self.device.mac)
self.device.register_callback(self.async_update_callback) self.device.register_callback(self.async_update_callback)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_reachable, self.async_update_callback
)
)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Disconnect device object when removed.""" """Disconnect device object when removed."""
await super().async_will_remove_from_hass()
self.device.remove_callback(self.async_update_callback) self.device.remove_callback(self.async_update_callback)
@callback @callback
def async_update_callback(self): def async_update_callback(self):
"""Update the sensor's state.""" """Update the sensor's state."""
LOGGER.debug("Updating device %s (%s)", self.entity_id, self.device.mac) LOGGER.debug("Updating device %s (%s)", self.entity_id, self.device.mac)
self.async_write_ha_state() self.async_write_ha_state()
@property @property
@ -410,7 +335,7 @@ class UniFiDeviceTracker(ScannerEntity):
return attributes return attributes
@property async def options_updated(self) -> None:
def should_poll(self): """Config entry options are updated, remove entity if option is disabled."""
"""No polling needed.""" if not self.controller.option_track_devices:
return True await self.async_remove()

View File

@ -1,6 +1,7 @@
"""Support for bandwidth sensors with UniFi clients.""" """Support for bandwidth sensors with UniFi clients."""
import logging import logging
from homeassistant.components.sensor import DOMAIN
from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.components.unifi.config_flow import get_controller_from_config_entry
from homeassistant.const import DATA_MEGABYTES from homeassistant.const import DATA_MEGABYTES
from homeassistant.core import callback from homeassistant.core import callback
@ -10,6 +11,9 @@ from .unifi_client import UniFiClient
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
RX_SENSOR = "rx"
TX_SENSOR = "tx"
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
"""Sensor platform doesn't support configuration through configuration.yaml.""" """Sensor platform doesn't support configuration through configuration.yaml."""
@ -18,144 +22,74 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up sensors for UniFi integration.""" """Set up sensors for UniFi integration."""
controller = get_controller_from_config_entry(hass, config_entry) controller = get_controller_from_config_entry(hass, config_entry)
sensors = {} controller.entities[DOMAIN] = {RX_SENSOR: set(), TX_SENSOR: set()}
option_allow_bandwidth_sensors = controller.option_allow_bandwidth_sensors
entity_registry = await hass.helpers.entity_registry.async_get_registry()
@callback @callback
def items_added(): def items_added():
"""Update the values of the controller.""" """Update the values of the controller."""
nonlocal option_allow_bandwidth_sensors if controller.option_allow_bandwidth_sensors:
add_entities(controller, async_add_entities)
if not option_allow_bandwidth_sensors: for signal in (controller.signal_update, controller.signal_options_update):
return controller.listeners.append(async_dispatcher_connect(hass, signal, items_added))
add_entities(controller, async_add_entities, sensors)
controller.listeners.append(
async_dispatcher_connect(hass, controller.signal_update, items_added)
)
@callback
def items_removed(mac_addresses: set) -> None:
"""Items have been removed from the controller."""
remove_entities(controller, mac_addresses, sensors, entity_registry)
controller.listeners.append(
async_dispatcher_connect(hass, controller.signal_remove, items_removed)
)
@callback
def options_updated():
"""Update the values of the controller."""
nonlocal option_allow_bandwidth_sensors
if option_allow_bandwidth_sensors != controller.option_allow_bandwidth_sensors:
option_allow_bandwidth_sensors = controller.option_allow_bandwidth_sensors
if option_allow_bandwidth_sensors:
items_added()
else:
for sensor in sensors.values():
hass.async_create_task(sensor.async_remove())
sensors.clear()
controller.listeners.append(
async_dispatcher_connect(
hass, controller.signal_options_update, options_updated
)
)
items_added() items_added()
@callback @callback
def add_entities(controller, async_add_entities, sensors): def add_entities(controller, async_add_entities):
"""Add new sensor entities from the controller.""" """Add new sensor entities from the controller."""
new_sensors = [] sensors = []
for client_id in controller.api.clients: for mac in controller.api.clients:
for direction, sensor_class in ( for sensor_class in (UniFiRxBandwidthSensor, UniFiTxBandwidthSensor):
("rx", UniFiRxBandwidthSensor), if mac not in controller.entities[DOMAIN][sensor_class.TYPE]:
("tx", UniFiTxBandwidthSensor), sensors.append(sensor_class(controller.api.clients[mac], controller))
):
item_id = f"{direction}-{client_id}"
if item_id in sensors: if sensors:
continue async_add_entities(sensors)
sensors[item_id] = sensor_class(
controller.api.clients[client_id], controller
)
new_sensors.append(sensors[item_id])
if new_sensors:
async_add_entities(new_sensors)
@callback class UniFiBandwidthSensor(UniFiClient):
def remove_entities(controller, mac_addresses, sensors, entity_registry): """UniFi bandwidth sensor base class."""
"""Remove select sensor entities."""
for mac in mac_addresses:
for direction in ("rx", "tx"):
item_id = f"{direction}-{mac}"
if item_id not in sensors:
continue
entity = sensors.pop(item_id)
controller.hass.async_create_task(entity.async_remove())
class UniFiRxBandwidthSensor(UniFiClient):
"""Receiving bandwidth sensor."""
@property @property
def state(self): def name(self) -> str:
"""Return the state of the sensor."""
if self._is_wired:
return self.client.wired_rx_bytes / 1000000
return self.client.raw.get("rx_bytes", 0) / 1000000
@property
def name(self):
"""Return the name of the client.""" """Return the name of the client."""
name = self.client.name or self.client.hostname return f"{super().name} {self.TYPE.upper()}"
return f"{name} RX"
@property @property
def unique_id(self): def unit_of_measurement(self) -> str:
"""Return a unique identifier for this bandwidth sensor."""
return f"rx-{self.client.mac}"
@property
def unit_of_measurement(self):
"""Return the unit of measurement of this entity.""" """Return the unit of measurement of this entity."""
return DATA_MEGABYTES return DATA_MEGABYTES
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_allow_bandwidth_sensors:
await self.async_remove()
class UniFiTxBandwidthSensor(UniFiRxBandwidthSensor):
"""Transmitting bandwidth sensor.""" class UniFiRxBandwidthSensor(UniFiBandwidthSensor):
"""Receiving bandwidth sensor."""
TYPE = RX_SENSOR
@property @property
def state(self): def state(self) -> int:
"""Return the state of the sensor."""
if self._is_wired:
return self.client.wired_rx_bytes / 1000000
return self.client.rx_bytes / 1000000
class UniFiTxBandwidthSensor(UniFiBandwidthSensor):
"""Transmitting bandwidth sensor."""
TYPE = TX_SENSOR
@property
def state(self) -> int:
"""Return the state of the sensor.""" """Return the state of the sensor."""
if self._is_wired: if self._is_wired:
return self.client.wired_tx_bytes / 1000000 return self.client.wired_tx_bytes / 1000000
return self.client.raw.get("tx_bytes", 0) / 1000000 return self.client.tx_bytes / 1000000
@property
def name(self):
"""Return the name of the client."""
name = self.client.name or self.client.hostname
return f"{name} TX"
@property
def unique_id(self):
"""Return a unique identifier for this bandwidth sensor."""
return f"tx-{self.client.mac}"

View File

@ -1,7 +1,7 @@
"""Support for devices connected to UniFi POE.""" """Support for devices connected to UniFi POE."""
import logging import logging
from homeassistant.components.switch import SwitchDevice from homeassistant.components.switch import DOMAIN, SwitchDevice
from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.components.unifi.config_flow import get_controller_from_config_entry
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -11,6 +11,9 @@ from .unifi_client import UniFiClient
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
BLOCK_SWITCH = "block"
POE_SWITCH = "poe"
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
"""Component doesn't support configuration through configuration.yaml.""" """Component doesn't support configuration through configuration.yaml."""
@ -22,24 +25,20 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
Switches are controlling network access and switch ports with POE. Switches are controlling network access and switch ports with POE.
""" """
controller = get_controller_from_config_entry(hass, config_entry) controller = get_controller_from_config_entry(hass, config_entry)
controller.entities[DOMAIN] = {BLOCK_SWITCH: set(), POE_SWITCH: set()}
if controller.site_role != "admin": if controller.site_role != "admin":
return return
switches = {}
switches_off = [] switches_off = []
option_block_clients = controller.option_block_clients
option_poe_clients = controller.option_poe_clients
entity_registry = await hass.helpers.entity_registry.async_get_registry()
# Restore clients that is not a part of active clients list. # Restore clients that is not a part of active clients list.
entity_registry = await hass.helpers.entity_registry.async_get_registry()
for entity in entity_registry.entities.values(): for entity in entity_registry.entities.values():
if ( if (
entity.config_entry_id == config_entry.entry_id entity.config_entry_id == config_entry.entry_id
and entity.unique_id.startswith("poe-") and entity.unique_id.startswith(f"{POE_SWITCH}-")
): ):
_, mac = entity.unique_id.split("-", 1) _, mac = entity.unique_id.split("-", 1)
@ -57,110 +56,53 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
@callback @callback
def items_added(): def items_added():
"""Update the values of the controller.""" """Update the values of the controller."""
add_entities(controller, async_add_entities, switches, switches_off) if controller.option_block_clients or controller.option_poe_clients:
add_entities(controller, async_add_entities, switches_off)
controller.listeners.append( for signal in (controller.signal_update, controller.signal_options_update):
async_dispatcher_connect(hass, controller.signal_update, items_added) controller.listeners.append(async_dispatcher_connect(hass, signal, items_added))
)
@callback
def items_removed(mac_addresses: set) -> None:
"""Items have been removed from the controller."""
remove_entities(controller, mac_addresses, switches, entity_registry)
controller.listeners.append(
async_dispatcher_connect(hass, controller.signal_remove, items_removed)
)
@callback
def options_updated():
"""Manage entities affected by config entry options."""
nonlocal option_block_clients
nonlocal option_poe_clients
update = set()
remove = set()
if option_block_clients != controller.option_block_clients:
option_block_clients = controller.option_block_clients
for block_client_id, entity in switches.items():
if not isinstance(entity, UniFiBlockClientSwitch):
continue
if entity.client.mac in option_block_clients:
update.add(block_client_id)
else:
remove.add(block_client_id)
if option_poe_clients != controller.option_poe_clients:
option_poe_clients = controller.option_poe_clients
if option_poe_clients:
update.add("poe_clients_enabled")
else:
for poe_client_id, entity in switches.items():
if isinstance(entity, UniFiPOEClientSwitch):
remove.add(poe_client_id)
for client_id in remove:
entity = switches.pop(client_id)
hass.async_create_task(entity.async_remove())
if len(update) != len(option_block_clients):
items_added()
controller.listeners.append(
async_dispatcher_connect(
hass, controller.signal_options_update, options_updated
)
)
items_added() items_added()
switches_off.clear() switches_off.clear()
@callback @callback
def add_entities(controller, async_add_entities, switches, switches_off): def add_entities(controller, async_add_entities, switches_off):
"""Add new switch entities from the controller.""" """Add new switch entities from the controller."""
new_switches = [] switches = []
devices = controller.api.devices
for client_id in controller.option_block_clients: for mac in controller.option_block_clients:
client = None if mac in controller.entities[DOMAIN][BLOCK_SWITCH]:
block_client_id = f"block-{client_id}"
if block_client_id in switches:
continue continue
if client_id in controller.api.clients: client = None
client = controller.api.clients[client_id]
elif client_id in controller.api.clients_all: if mac in controller.api.clients:
client = controller.api.clients_all[client_id] client = controller.api.clients[mac]
elif mac in controller.api.clients_all:
client = controller.api.clients_all[mac]
if not client: if not client:
continue continue
switches[block_client_id] = UniFiBlockClientSwitch(client, controller) switches.append(UniFiBlockClientSwitch(client, controller))
new_switches.append(switches[block_client_id])
if controller.option_poe_clients: if controller.option_poe_clients:
for client_id in controller.api.clients: devices = controller.api.devices
poe_client_id = f"poe-{client_id}" for mac in controller.api.clients:
if poe_client_id in switches: poe_client_id = f"{POE_SWITCH}-{mac}"
if mac in controller.entities[DOMAIN][POE_SWITCH]:
continue continue
client = controller.api.clients[client_id] client = controller.api.clients[mac]
if poe_client_id in switches_off: if poe_client_id not in switches_off and (
pass mac in controller.wireless_clients
# Network device with active POE
elif (
client_id in controller.wireless_clients
or client.sw_mac not in devices or client.sw_mac not in devices
or not devices[client.sw_mac].ports[client.sw_port].port_poe or not devices[client.sw_mac].ports[client.sw_port].port_poe
or not devices[client.sw_mac].ports[client.sw_port].poe_enable or not devices[client.sw_mac].ports[client.sw_port].poe_enable
@ -187,31 +129,17 @@ def add_entities(controller, async_add_entities, switches, switches_off):
if multi_clients_on_port: if multi_clients_on_port:
continue continue
switches[poe_client_id] = UniFiPOEClientSwitch(client, controller) switches.append(UniFiPOEClientSwitch(client, controller))
new_switches.append(switches[poe_client_id])
if new_switches: if switches:
async_add_entities(new_switches) async_add_entities(switches)
@callback
def remove_entities(controller, mac_addresses, switches, entity_registry):
"""Remove select switch entities."""
for mac in mac_addresses:
for switch_type in ("block", "poe"):
item_id = f"{switch_type}-{mac}"
if item_id not in switches:
continue
entity = switches.pop(item_id)
controller.hass.async_create_task(entity.async_remove())
class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity): class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity):
"""Representation of a client that uses POE.""" """Representation of a client that uses POE."""
TYPE = POE_SWITCH
def __init__(self, client, controller): def __init__(self, client, controller):
"""Set up POE switch.""" """Set up POE switch."""
super().__init__(client, controller) super().__init__(client, controller)
@ -225,7 +153,6 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity):
await super().async_added_to_hass() await super().async_added_to_hass()
state = await self.async_get_last_state() state = await self.async_get_last_state()
if state is None: if state is None:
return return
@ -238,11 +165,6 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity):
if not self.client.sw_port: if not self.client.sw_port:
self.client.raw["sw_port"] = state.attributes["port"] self.client.raw["sw_port"] = state.attributes["port"]
@property
def unique_id(self):
"""Return a unique identifier for this switch."""
return f"poe-{self.client.mac}"
@property @property
def is_on(self): def is_on(self):
"""Return true if POE is active.""" """Return true if POE is active."""
@ -301,14 +223,16 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity):
self.client.sw_port, self.client.sw_port,
) )
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_poe_clients:
await self.async_remove()
class UniFiBlockClientSwitch(UniFiClient, SwitchDevice): class UniFiBlockClientSwitch(UniFiClient, SwitchDevice):
"""Representation of a blockable client.""" """Representation of a blockable client."""
@property TYPE = BLOCK_SWITCH
def unique_id(self):
"""Return a unique identifier for this switch."""
return f"block-{self.client.mac}"
@property @property
def is_on(self): def is_on(self):
@ -329,3 +253,8 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchDevice):
if self.is_blocked: if self.is_blocked:
return "mdi:network-off" return "mdi:network-off"
return "mdi:network" return "mdi:network"
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if self.client.mac not in self.controller.option_block_clients:
await self.async_remove()

View File

@ -15,10 +15,9 @@ from aiounifi.events import (
WIRELESS_CLIENT_UNBLOCKED, WIRELESS_CLIENT_UNBLOCKED,
) )
from homeassistant.components.unifi.unifi_entity_base import UniFiBase
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@ -32,31 +31,33 @@ WIRELESS_CLIENT = (
) )
class UniFiClient(Entity): class UniFiClient(UniFiBase):
"""Base class for UniFi clients.""" """Base class for UniFi clients."""
def __init__(self, client, controller) -> None: def __init__(self, client, controller) -> None:
"""Set up client.""" """Set up client."""
super().__init__(controller)
self.client = client self.client = client
self.controller = controller
self._is_wired = self.client.mac not in controller.wireless_clients self._is_wired = self.client.mac not in controller.wireless_clients
self.is_blocked = self.client.blocked self.is_blocked = self.client.blocked
self.wired_connection = None self.wired_connection = None
self.wireless_connection = None self.wireless_connection = None
@property
def mac(self):
"""Return MAC of client."""
return self.client.mac
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Client entity created.""" """Client entity created."""
await super().async_added_to_hass()
LOGGER.debug("New client %s (%s)", self.entity_id, self.client.mac) LOGGER.debug("New client %s (%s)", self.entity_id, self.client.mac)
self.client.register_callback(self.async_update_callback) self.client.register_callback(self.async_update_callback)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_reachable, self.async_update_callback
)
)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Disconnect client object when removed.""" """Disconnect client object when removed."""
await super().async_will_remove_from_hass()
self.client.remove_callback(self.async_update_callback) self.client.remove_callback(self.async_update_callback)
@callback @callback
@ -93,6 +94,11 @@ class UniFiClient(Entity):
return self.client.is_wired return self.client.is_wired
return self._is_wired return self._is_wired
@property
def unique_id(self):
"""Return a unique identifier for this switch."""
return f"{self.TYPE}-{self.client.mac}"
@property @property
def name(self) -> str: def name(self) -> str:
"""Return the name of the client.""" """Return the name of the client."""
@ -107,8 +113,3 @@ class UniFiClient(Entity):
def device_info(self) -> dict: def device_info(self) -> dict:
"""Return a client description for device registry.""" """Return a client description for device registry."""
return {"connections": {(CONNECTION_NETWORK_MAC, self.client.mac)}} return {"connections": {(CONNECTION_NETWORK_MAC, self.client.mac)}}
@property
def should_poll(self) -> bool:
"""No polling needed."""
return True

View File

@ -0,0 +1,80 @@
"""Base class for UniFi entities."""
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_registry import async_entries_for_device
class UniFiBase(Entity):
"""UniFi entity base class."""
TYPE = ""
def __init__(self, controller) -> None:
"""Set up UniFi entity base."""
self.controller = controller
@property
def mac(self):
"""Return MAC of entity."""
raise NotImplementedError
async def async_added_to_hass(self) -> None:
"""Entity created."""
self.controller.entities[self.platform.domain][self.TYPE].add(self.mac)
for signal, method in (
(self.controller.signal_reachable, self.async_update_callback),
(self.controller.signal_options_update, self.options_updated),
(self.controller.signal_remove, self.remove_item),
):
self.async_on_remove(async_dispatcher_connect(self.hass, signal, method))
async def async_will_remove_from_hass(self) -> None:
"""Disconnect object when removed."""
self.controller.entities[self.platform.domain][self.TYPE].remove(self.mac)
async def async_remove(self):
"""Clean up when removing entity.
Remove entity if no entry in entity registry exist.
Remove entity registry entry if no entry in device registry exist.
Remove device registry entry if there is only one linked entity (this entity).
Remove entity registry entry if there are more than one entity linked to the device registry entry.
"""
entity_registry = await self.hass.helpers.entity_registry.async_get_registry()
entity_entry = entity_registry.async_get(self.entity_id)
if not entity_entry:
await super().async_remove()
return
device_registry = await self.hass.helpers.device_registry.async_get_registry()
device_entry = device_registry.async_get(entity_entry.device_id)
if not device_entry:
entity_registry.async_remove(self.entity_id)
return
if len(async_entries_for_device(entity_registry, entity_entry.device_id)) == 1:
device_registry.async_remove_device(device_entry.id)
return
entity_registry.async_remove(self.entity_id)
@callback
def async_update_callback(self):
"""Update the entity's state."""
raise NotImplementedError
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
raise NotImplementedError
async def remove_item(self, mac_addresses: set) -> None:
"""Remove entity if MAC is part of set."""
if self.mac in mac_addresses:
await self.async_remove()
@property
def should_poll(self) -> bool:
"""No polling needed."""
return True

View File

@ -207,7 +207,7 @@ async def test_reset_after_successful_setup(hass):
"""Calling reset when the entry has been setup.""" """Calling reset when the entry has been setup."""
controller = await setup_unifi_integration(hass) controller = await setup_unifi_integration(hass)
assert len(controller.listeners) == 9 assert len(controller.listeners) == 6
result = await controller.async_reset() result = await controller.async_reset()
await hass.async_block_till_done() await hass.async_block_till_done()