Index the device registry (#37990)

pull/38005/head
J. Nick Koston 2020-07-19 20:32:05 -10:00 committed by GitHub
parent 92d72f26c7
commit 6ea5c8aed9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 144 additions and 28 deletions

View File

@ -1,7 +1,7 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import uuid import uuid
import attr import attr
@ -32,6 +32,11 @@ CONNECTION_NETWORK_MAC = "mac"
CONNECTION_UPNP = "upnp" CONNECTION_UPNP = "upnp"
CONNECTION_ZIGBEE = "zigbee" CONNECTION_ZIGBEE = "zigbee"
IDX_CONNECTIONS = "connections"
IDX_IDENTIFIERS = "identifiers"
REGISTERED_DEVICE = "registered"
DELETED_DEVICE = "deleted"
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class DeletedDeviceEntry: class DeletedDeviceEntry:
@ -98,11 +103,13 @@ class DeviceRegistry:
devices: Dict[str, DeviceEntry] devices: Dict[str, DeviceEntry]
deleted_devices: Dict[str, DeletedDeviceEntry] deleted_devices: Dict[str, DeletedDeviceEntry]
_devices_index: Dict[str, Dict[str, Dict[str, str]]]
def __init__(self, hass: HomeAssistantType) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
self.hass = hass self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._clear_index()
@callback @callback
def async_get(self, device_id: str) -> Optional[DeviceEntry]: def async_get(self, device_id: str) -> Optional[DeviceEntry]:
@ -114,24 +121,83 @@ class DeviceRegistry:
self, identifiers: set, connections: set self, identifiers: set, connections: set
) -> Optional[DeviceEntry]: ) -> Optional[DeviceEntry]:
"""Check if device is registered.""" """Check if device is registered."""
for device in self.devices.values(): device_id = self._async_get_device_id_from_index(
if any(iden in device.identifiers for iden in identifiers) or any( REGISTERED_DEVICE, identifiers, connections
conn in device.connections for conn in connections )
): if device_id is None:
return device
return None return None
return self.devices[device_id]
@callback
def _async_get_deleted_device( def _async_get_deleted_device(
self, identifiers: set, connections: set self, identifiers: set, connections: set
) -> Optional[DeletedDeviceEntry]: ) -> Optional[DeletedDeviceEntry]:
"""Check if device has previously been registered.""" """Check if device is deleted."""
for device in self.deleted_devices.values(): device_id = self._async_get_device_id_from_index(
if any(iden in device.identifiers for iden in identifiers) or any( DELETED_DEVICE, identifiers, connections
conn in device.connections for conn in connections )
): if device_id is None:
return device
return None return None
return self.deleted_devices[device_id]
def _async_get_device_id_from_index(
self, index: str, identifiers: set, connections: set
) -> Optional[str]:
"""Check if device has previously been registered."""
devices_index = self._devices_index[index]
for identifier in identifiers:
if identifier in devices_index[IDX_IDENTIFIERS]:
return devices_index[IDX_IDENTIFIERS][identifier]
if not connections:
return None
for connection in _normalize_connections(connections):
if connection in devices_index[IDX_CONNECTIONS]:
return devices_index[IDX_CONNECTIONS][connection]
return None
def _add_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
"""Add a device and index it."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE]
self.deleted_devices[device.id] = device
else:
devices_index = self._devices_index[REGISTERED_DEVICE]
self.devices[device.id] = device
_add_device_to_index(devices_index, device)
def _remove_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
"""Remove a device and remove it from the index."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE]
self.deleted_devices.pop(device.id)
else:
devices_index = self._devices_index[REGISTERED_DEVICE]
self.devices.pop(device.id)
_remove_device_from_index(devices_index, device)
def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None:
"""Update a device and the index."""
self.devices[new_device.id] = new_device
devices_index = self._devices_index[REGISTERED_DEVICE]
_remove_device_from_index(devices_index, old_device)
_add_device_to_index(devices_index, new_device)
def _clear_index(self):
"""Clear the index."""
self._devices_index = {
REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
}
def _rebuild_index(self):
"""Create the index after loading devices."""
self._clear_index()
for device in self.devices.values():
_add_device_to_index(self._devices_index[REGISTERED_DEVICE], device)
for device in self.deleted_devices.values():
_add_device_to_index(self._devices_index[DELETED_DEVICE], device)
@callback @callback
def async_get_or_create( def async_get_or_create(
@ -156,11 +222,8 @@ class DeviceRegistry:
if connections is None: if connections is None:
connections = set() connections = set()
else:
connections = { connections = _normalize_connections(connections)
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
for key, value in connections
}
device = self.async_get_device(identifiers, connections) device = self.async_get_device(identifiers, connections)
@ -169,9 +232,9 @@ class DeviceRegistry:
if deleted_device is None: if deleted_device is None:
device = DeviceEntry(is_new=True) device = DeviceEntry(is_new=True)
else: else:
self.deleted_devices.pop(deleted_device.id) self._remove_device(deleted_device)
device = deleted_device.to_device_entry() device = deleted_device.to_device_entry()
self.devices[device.id] = device self._add_device(device)
if via_device is not None: if via_device is not None:
via = self.async_get_device({via_device}, set()) via = self.async_get_device({via_device}, set())
@ -301,7 +364,8 @@ class DeviceRegistry:
if not changes: if not changes:
return old return old
new = self.devices[device_id] = attr.evolve(old, **changes) new = attr.evolve(old, **changes)
self._update_device(old, new)
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire( self.hass.bus.async_fire(
@ -317,13 +381,16 @@ class DeviceRegistry:
@callback @callback
def async_remove_device(self, device_id: str) -> None: def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry.""" """Remove a device from the device registry."""
device = self.devices.pop(device_id) device = self.devices[device_id]
self.deleted_devices[device_id] = DeletedDeviceEntry( self._remove_device(device)
self._add_device(
DeletedDeviceEntry(
config_entries=device.config_entries, config_entries=device.config_entries,
connections=device.connections, connections=device.connections,
identifiers=device.identifiers, identifiers=device.identifiers,
id=device.id, id=device.id,
) )
)
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id} EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id}
) )
@ -371,6 +438,7 @@ class DeviceRegistry:
self.devices = devices self.devices = devices
self.deleted_devices = deleted_devices self.deleted_devices = deleted_devices
self._rebuild_index()
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -422,9 +490,11 @@ class DeviceRegistry:
continue continue
if config_entries == {config_entry_id}: if config_entries == {config_entry_id}:
# Permanently remove the device from the device registry. # Permanently remove the device from the device registry.
del self.deleted_devices[deleted_device.id] self._remove_device(deleted_device)
else: else:
config_entries = config_entries - {config_entry_id} config_entries = config_entries - {config_entry_id}
# No need to reindex here since we currently
# do not have a lookup by config entry
self.deleted_devices[deleted_device.id] = attr.evolve( self.deleted_devices[deleted_device.id] = attr.evolve(
deleted_device, config_entries=config_entries deleted_device, config_entries=config_entries
) )
@ -536,3 +606,33 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
await debounced_cleanup.async_call() await debounced_cleanup.async_call()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
def _normalize_connections(connections: set) -> set:
"""Normalize connections to ensure we can match mac addresses."""
return {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
for key, value in connections
}
def _add_device_to_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
) -> None:
"""Add a device to the index."""
for identifier in device.identifiers:
devices_index[IDX_IDENTIFIERS][identifier] = device.id
for connection in device.connections:
devices_index[IDX_CONNECTIONS][connection] = device.id
def _remove_device_from_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
) -> None:
"""Remove a device from the index."""
for identifier in device.identifiers:
if identifier in devices_index[IDX_IDENTIFIERS]:
del devices_index[IDX_IDENTIFIERS][identifier]
for connection in device.connections:
if connection in devices_index[IDX_CONNECTIONS]:
del devices_index[IDX_CONNECTIONS][connection]

View File

@ -371,6 +371,7 @@ def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
registry = device_registry.DeviceRegistry(hass) registry = device_registry.DeviceRegistry(hass)
registry.devices = mock_entries or OrderedDict() registry.devices = mock_entries or OrderedDict()
registry.deleted_devices = mock_deleted_entries or OrderedDict() registry.deleted_devices = mock_deleted_entries or OrderedDict()
registry._rebuild_index()
hass.data[device_registry.DATA_REGISTRY] = registry hass.data[device_registry.DATA_REGISTRY] = registry
return registry return registry

View File

@ -562,6 +562,21 @@ async def test_update(registry):
assert updated_entry.identifiers == new_identifiers assert updated_entry.identifiers == new_identifiers
assert updated_entry.via_device_id == "98765B" assert updated_entry.via_device_id == "98765B"
assert registry.async_get_device({("hue", "456")}, {}) is None
assert registry.async_get_device({("bla", "123")}, {}) is None
assert registry.async_get_device({("hue", "654")}, {}) == updated_entry
assert registry.async_get_device({("bla", "321")}, {}) == updated_entry
assert (
registry.async_get_device(
{}, {(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}
)
== updated_entry
)
assert registry.async_get(updated_entry.id) is not None
async def test_update_remove_config_entries(hass, registry, update_events): async def test_update_remove_config_entries(hass, registry, update_events):
"""Make sure we do not get duplicate entries.""" """Make sure we do not get duplicate entries."""