Index the device registry (#37990)
parent
92d72f26c7
commit
6ea5c8aed9
|
@ -1,7 +1,7 @@
|
||||||
"""Provide a way to connect entities belonging to one device."""
|
"""Provide a way to connect entities belonging to one device."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -32,6 +32,11 @@ CONNECTION_NETWORK_MAC = "mac"
|
||||||
CONNECTION_UPNP = "upnp"
|
CONNECTION_UPNP = "upnp"
|
||||||
CONNECTION_ZIGBEE = "zigbee"
|
CONNECTION_ZIGBEE = "zigbee"
|
||||||
|
|
||||||
|
IDX_CONNECTIONS = "connections"
|
||||||
|
IDX_IDENTIFIERS = "identifiers"
|
||||||
|
REGISTERED_DEVICE = "registered"
|
||||||
|
DELETED_DEVICE = "deleted"
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class DeletedDeviceEntry:
|
class DeletedDeviceEntry:
|
||||||
|
@ -98,11 +103,13 @@ class DeviceRegistry:
|
||||||
|
|
||||||
devices: Dict[str, DeviceEntry]
|
devices: Dict[str, DeviceEntry]
|
||||||
deleted_devices: Dict[str, DeletedDeviceEntry]
|
deleted_devices: Dict[str, DeletedDeviceEntry]
|
||||||
|
_devices_index: Dict[str, Dict[str, Dict[str, str]]]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistantType) -> None:
|
def __init__(self, hass: HomeAssistantType) -> None:
|
||||||
"""Initialize the device registry."""
|
"""Initialize the device registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
self._clear_index()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
|
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
|
||||||
|
@ -114,24 +121,83 @@ class DeviceRegistry:
|
||||||
self, identifiers: set, connections: set
|
self, identifiers: set, connections: set
|
||||||
) -> Optional[DeviceEntry]:
|
) -> Optional[DeviceEntry]:
|
||||||
"""Check if device is registered."""
|
"""Check if device is registered."""
|
||||||
for device in self.devices.values():
|
device_id = self._async_get_device_id_from_index(
|
||||||
if any(iden in device.identifiers for iden in identifiers) or any(
|
REGISTERED_DEVICE, identifiers, connections
|
||||||
conn in device.connections for conn in connections
|
)
|
||||||
):
|
if device_id is None:
|
||||||
return device
|
|
||||||
return None
|
return None
|
||||||
|
return self.devices[device_id]
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_get_deleted_device(
|
def _async_get_deleted_device(
|
||||||
self, identifiers: set, connections: set
|
self, identifiers: set, connections: set
|
||||||
) -> Optional[DeletedDeviceEntry]:
|
) -> Optional[DeletedDeviceEntry]:
|
||||||
"""Check if device has previously been registered."""
|
"""Check if device is deleted."""
|
||||||
for device in self.deleted_devices.values():
|
device_id = self._async_get_device_id_from_index(
|
||||||
if any(iden in device.identifiers for iden in identifiers) or any(
|
DELETED_DEVICE, identifiers, connections
|
||||||
conn in device.connections for conn in connections
|
)
|
||||||
):
|
if device_id is None:
|
||||||
return device
|
|
||||||
return None
|
return None
|
||||||
|
return self.deleted_devices[device_id]
|
||||||
|
|
||||||
|
def _async_get_device_id_from_index(
|
||||||
|
self, index: str, identifiers: set, connections: set
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Check if device has previously been registered."""
|
||||||
|
devices_index = self._devices_index[index]
|
||||||
|
for identifier in identifiers:
|
||||||
|
if identifier in devices_index[IDX_IDENTIFIERS]:
|
||||||
|
return devices_index[IDX_IDENTIFIERS][identifier]
|
||||||
|
if not connections:
|
||||||
|
return None
|
||||||
|
for connection in _normalize_connections(connections):
|
||||||
|
if connection in devices_index[IDX_CONNECTIONS]:
|
||||||
|
return devices_index[IDX_CONNECTIONS][connection]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _add_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
|
||||||
|
"""Add a device and index it."""
|
||||||
|
if isinstance(device, DeletedDeviceEntry):
|
||||||
|
devices_index = self._devices_index[DELETED_DEVICE]
|
||||||
|
self.deleted_devices[device.id] = device
|
||||||
|
else:
|
||||||
|
devices_index = self._devices_index[REGISTERED_DEVICE]
|
||||||
|
self.devices[device.id] = device
|
||||||
|
|
||||||
|
_add_device_to_index(devices_index, device)
|
||||||
|
|
||||||
|
def _remove_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
|
||||||
|
"""Remove a device and remove it from the index."""
|
||||||
|
if isinstance(device, DeletedDeviceEntry):
|
||||||
|
devices_index = self._devices_index[DELETED_DEVICE]
|
||||||
|
self.deleted_devices.pop(device.id)
|
||||||
|
else:
|
||||||
|
devices_index = self._devices_index[REGISTERED_DEVICE]
|
||||||
|
self.devices.pop(device.id)
|
||||||
|
|
||||||
|
_remove_device_from_index(devices_index, device)
|
||||||
|
|
||||||
|
def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None:
|
||||||
|
"""Update a device and the index."""
|
||||||
|
self.devices[new_device.id] = new_device
|
||||||
|
|
||||||
|
devices_index = self._devices_index[REGISTERED_DEVICE]
|
||||||
|
_remove_device_from_index(devices_index, old_device)
|
||||||
|
_add_device_to_index(devices_index, new_device)
|
||||||
|
|
||||||
|
def _clear_index(self):
|
||||||
|
"""Clear the index."""
|
||||||
|
self._devices_index = {
|
||||||
|
REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
|
||||||
|
DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _rebuild_index(self):
|
||||||
|
"""Create the index after loading devices."""
|
||||||
|
self._clear_index()
|
||||||
|
for device in self.devices.values():
|
||||||
|
_add_device_to_index(self._devices_index[REGISTERED_DEVICE], device)
|
||||||
|
for device in self.deleted_devices.values():
|
||||||
|
_add_device_to_index(self._devices_index[DELETED_DEVICE], device)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_or_create(
|
def async_get_or_create(
|
||||||
|
@ -156,11 +222,8 @@ class DeviceRegistry:
|
||||||
|
|
||||||
if connections is None:
|
if connections is None:
|
||||||
connections = set()
|
connections = set()
|
||||||
|
else:
|
||||||
connections = {
|
connections = _normalize_connections(connections)
|
||||||
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
|
|
||||||
for key, value in connections
|
|
||||||
}
|
|
||||||
|
|
||||||
device = self.async_get_device(identifiers, connections)
|
device = self.async_get_device(identifiers, connections)
|
||||||
|
|
||||||
|
@ -169,9 +232,9 @@ class DeviceRegistry:
|
||||||
if deleted_device is None:
|
if deleted_device is None:
|
||||||
device = DeviceEntry(is_new=True)
|
device = DeviceEntry(is_new=True)
|
||||||
else:
|
else:
|
||||||
self.deleted_devices.pop(deleted_device.id)
|
self._remove_device(deleted_device)
|
||||||
device = deleted_device.to_device_entry()
|
device = deleted_device.to_device_entry()
|
||||||
self.devices[device.id] = device
|
self._add_device(device)
|
||||||
|
|
||||||
if via_device is not None:
|
if via_device is not None:
|
||||||
via = self.async_get_device({via_device}, set())
|
via = self.async_get_device({via_device}, set())
|
||||||
|
@ -301,7 +364,8 @@ class DeviceRegistry:
|
||||||
if not changes:
|
if not changes:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
new = self.devices[device_id] = attr.evolve(old, **changes)
|
new = attr.evolve(old, **changes)
|
||||||
|
self._update_device(old, new)
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
|
@ -317,13 +381,16 @@ class DeviceRegistry:
|
||||||
@callback
|
@callback
|
||||||
def async_remove_device(self, device_id: str) -> None:
|
def async_remove_device(self, device_id: str) -> None:
|
||||||
"""Remove a device from the device registry."""
|
"""Remove a device from the device registry."""
|
||||||
device = self.devices.pop(device_id)
|
device = self.devices[device_id]
|
||||||
self.deleted_devices[device_id] = DeletedDeviceEntry(
|
self._remove_device(device)
|
||||||
|
self._add_device(
|
||||||
|
DeletedDeviceEntry(
|
||||||
config_entries=device.config_entries,
|
config_entries=device.config_entries,
|
||||||
connections=device.connections,
|
connections=device.connections,
|
||||||
identifiers=device.identifiers,
|
identifiers=device.identifiers,
|
||||||
id=device.id,
|
id=device.id,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id}
|
EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id}
|
||||||
)
|
)
|
||||||
|
@ -371,6 +438,7 @@ class DeviceRegistry:
|
||||||
|
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.deleted_devices = deleted_devices
|
self.deleted_devices = deleted_devices
|
||||||
|
self._rebuild_index()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_save(self) -> None:
|
def async_schedule_save(self) -> None:
|
||||||
|
@ -422,9 +490,11 @@ class DeviceRegistry:
|
||||||
continue
|
continue
|
||||||
if config_entries == {config_entry_id}:
|
if config_entries == {config_entry_id}:
|
||||||
# Permanently remove the device from the device registry.
|
# Permanently remove the device from the device registry.
|
||||||
del self.deleted_devices[deleted_device.id]
|
self._remove_device(deleted_device)
|
||||||
else:
|
else:
|
||||||
config_entries = config_entries - {config_entry_id}
|
config_entries = config_entries - {config_entry_id}
|
||||||
|
# No need to reindex here since we currently
|
||||||
|
# do not have a lookup by config entry
|
||||||
self.deleted_devices[deleted_device.id] = attr.evolve(
|
self.deleted_devices[deleted_device.id] = attr.evolve(
|
||||||
deleted_device, config_entries=config_entries
|
deleted_device, config_entries=config_entries
|
||||||
)
|
)
|
||||||
|
@ -536,3 +606,33 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
|
||||||
await debounced_cleanup.async_call()
|
await debounced_cleanup.async_call()
|
||||||
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_connections(connections: set) -> set:
|
||||||
|
"""Normalize connections to ensure we can match mac addresses."""
|
||||||
|
return {
|
||||||
|
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
|
||||||
|
for key, value in connections
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _add_device_to_index(
|
||||||
|
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
|
||||||
|
) -> None:
|
||||||
|
"""Add a device to the index."""
|
||||||
|
for identifier in device.identifiers:
|
||||||
|
devices_index[IDX_IDENTIFIERS][identifier] = device.id
|
||||||
|
for connection in device.connections:
|
||||||
|
devices_index[IDX_CONNECTIONS][connection] = device.id
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_device_from_index(
|
||||||
|
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
|
||||||
|
) -> None:
|
||||||
|
"""Remove a device from the index."""
|
||||||
|
for identifier in device.identifiers:
|
||||||
|
if identifier in devices_index[IDX_IDENTIFIERS]:
|
||||||
|
del devices_index[IDX_IDENTIFIERS][identifier]
|
||||||
|
for connection in device.connections:
|
||||||
|
if connection in devices_index[IDX_CONNECTIONS]:
|
||||||
|
del devices_index[IDX_CONNECTIONS][connection]
|
||||||
|
|
|
@ -371,6 +371,7 @@ def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
|
||||||
registry = device_registry.DeviceRegistry(hass)
|
registry = device_registry.DeviceRegistry(hass)
|
||||||
registry.devices = mock_entries or OrderedDict()
|
registry.devices = mock_entries or OrderedDict()
|
||||||
registry.deleted_devices = mock_deleted_entries or OrderedDict()
|
registry.deleted_devices = mock_deleted_entries or OrderedDict()
|
||||||
|
registry._rebuild_index()
|
||||||
|
|
||||||
hass.data[device_registry.DATA_REGISTRY] = registry
|
hass.data[device_registry.DATA_REGISTRY] = registry
|
||||||
return registry
|
return registry
|
||||||
|
|
|
@ -562,6 +562,21 @@ async def test_update(registry):
|
||||||
assert updated_entry.identifiers == new_identifiers
|
assert updated_entry.identifiers == new_identifiers
|
||||||
assert updated_entry.via_device_id == "98765B"
|
assert updated_entry.via_device_id == "98765B"
|
||||||
|
|
||||||
|
assert registry.async_get_device({("hue", "456")}, {}) is None
|
||||||
|
assert registry.async_get_device({("bla", "123")}, {}) is None
|
||||||
|
|
||||||
|
assert registry.async_get_device({("hue", "654")}, {}) == updated_entry
|
||||||
|
assert registry.async_get_device({("bla", "321")}, {}) == updated_entry
|
||||||
|
|
||||||
|
assert (
|
||||||
|
registry.async_get_device(
|
||||||
|
{}, {(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}
|
||||||
|
)
|
||||||
|
== updated_entry
|
||||||
|
)
|
||||||
|
|
||||||
|
assert registry.async_get(updated_entry.id) is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_update_remove_config_entries(hass, registry, update_events):
|
async def test_update_remove_config_entries(hass, registry, update_events):
|
||||||
"""Make sure we do not get duplicate entries."""
|
"""Make sure we do not get duplicate entries."""
|
||||||
|
|
Loading…
Reference in New Issue