Add index for area/config_entry/label to the device registry (#114776)

* Add index for area/config_entry/label to the device registry

* use it for services

* naming

* naming

* tweak
pull/114777/head
J. Nick Koston 2024-04-03 16:52:17 -10:00 committed by GitHub
parent 841d3940d1
commit 3f76d1f056
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 76 additions and 18 deletions

View File

@ -491,10 +491,71 @@ class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT]):
return None
class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]):
"""Container for active (non-deleted) device registry entries."""
def __init__(self) -> None:
"""Initialize the container.
Maintains three additional indexes:
- area_id -> dict[key, True]
- config_entry_id -> dict[key, True]
- label -> dict[key, True]
"""
super().__init__()
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: DeviceEntry) -> None:
"""Index an entry."""
super()._index_entry(key, entry)
if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
for config_entry_id in entry.config_entries:
self._config_entry_id_index.setdefault(config_entry_id, {})[key] = True
def _unindex_entry(
self, key: str, replacement_entry: DeviceEntry | None = None
) -> None:
"""Unindex an entry."""
entry = self.data[key]
if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index)
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
for config_entry_id in entry.config_entries:
self._unindex_entry_value(key, config_entry_id, self._config_entry_id_index)
super()._unindex_entry(key, replacement_entry)
def get_devices_for_area_id(self, area_id: str) -> list[DeviceEntry]:
"""Get devices for area."""
data = self.data
return [data[key] for key in self._area_id_index.get(area_id, ())]
def get_devices_for_label(self, label: str) -> list[DeviceEntry]:
"""Get devices for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
def get_devices_for_config_entry_id(
self, config_entry_id: str
) -> list[DeviceEntry]:
"""Get devices for config entry."""
data = self.data
return [
data[key] for key in self._config_entry_id_index.get(config_entry_id, ())
]
class DeviceRegistry(BaseRegistry):
"""Class to hold a registry of devices."""
devices: DeviceRegistryItems[DeviceEntry]
devices: ActiveDeviceRegistryItems
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry]
_device_data: dict[str, DeviceEntry]
@ -884,7 +945,7 @@ class DeviceRegistry(BaseRegistry):
data = await self._store.async_load()
devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems()
devices = ActiveDeviceRegistryItems()
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems()
if data is not None:
@ -1018,7 +1079,7 @@ async def async_load(hass: HomeAssistant) -> None:
@callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> list[DeviceEntry]:
"""Return entries that match an area."""
return [device for device in registry.devices.values() if device.area_id == area_id]
return registry.devices.get_devices_for_area_id(area_id)
@callback
@ -1026,7 +1087,7 @@ def async_entries_for_label(
registry: DeviceRegistry, label_id: str
) -> list[DeviceEntry]:
"""Return entries that match a label."""
return [device for device in registry.devices.values() if label_id in device.labels]
return registry.devices.get_devices_for_label(label_id)
@callback
@ -1034,11 +1095,7 @@ def async_entries_for_config_entry(
registry: DeviceRegistry, config_entry_id: str
) -> list[DeviceEntry]:
"""Return entries that match a config entry."""
return [
device
for device in registry.devices.values()
if config_entry_id in device.config_entries
]
return registry.devices.get_devices_for_config_entry_id(config_entry_id)
@callback

View File

@ -534,15 +534,14 @@ def async_extract_referenced_entity_ids( # noqa: C901
):
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find areas, devices & entities for targeted labels
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
# Find areas for targeted labels
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
selected.referenced_areas.add(area_entry.id)
for device_entry in dev_reg.devices.values():
if device_entry.labels.intersection(selector.label_ids):
selected.referenced_devices.add(device_entry.id)
# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
@ -554,9 +553,11 @@ def async_extract_referenced_entity_ids( # noqa: C901
selected.referenced_areas.update(selector.area_ids)
if selected.referenced_areas:
for device_entry in dev_reg.devices.values():
if device_entry.area_id in selected.referenced_areas:
selected.referenced_devices.add(device_entry.id)
for area_id in selected.referenced_areas:
selected.referenced_devices.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
if not selected.referenced_areas and not selected.referenced_devices:
return selected

View File

@ -671,7 +671,7 @@ def mock_device_registry(
fixture instead.
"""
registry = dr.DeviceRegistry(hass)
registry.devices = dr.DeviceRegistryItems()
registry.devices = dr.ActiveDeviceRegistryItems()
registry._device_data = registry.devices.data
if mock_entries is None:
mock_entries = {}