Add an index for devices and config entries to the entity registry (#107516)

* Add an index for devices and config entries to the entity registry

* fixes

* tweak

* use a list for now since the tests check order
pull/105955/head
J. Nick Koston 2024-01-13 09:49:41 -10:00 committed by GitHub
parent 5d3e069655
commit d7910841ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 15 deletions

View File

@ -436,9 +436,11 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry.
Maintains two additional indexes:
Maintains four additional indexes:
- id -> entry
- (domain, platform, unique_id) -> entity_id
- config_entry_id -> list[key]
- device_id -> list[key]
"""
def __init__(self) -> None:
@ -446,6 +448,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
super().__init__()
self._entry_ids: dict[str, RegistryEntry] = {}
self._index: dict[tuple[str, str, str], str] = {}
self._config_entry_id_index: dict[str, list[str]] = {}
self._device_id_index: dict[str, list[str]] = {}
def values(self) -> ValuesView[RegistryEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
@ -455,18 +459,34 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Add an item."""
data = self.data
if key in data:
old_entry = data[key]
del self._entry_ids[old_entry.id]
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
self._unindex_entry(key)
data[key] = entry
self._entry_ids[entry.id] = entry
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
if (config_entry_id := entry.config_entry_id) is not None:
self._config_entry_id_index.setdefault(config_entry_id, []).append(key)
if (device_id := entry.device_id) is not None:
self._device_id_index.setdefault(device_id, []).append(key)
def _unindex_entry(self, key: str) -> None:
"""Unindex an entry."""
entry = self.data[key]
del self._entry_ids[entry.id]
del self._index[(entry.domain, entry.platform, entry.unique_id)]
if (config_entry_id := entry.config_entry_id) is not None:
entries = self._config_entry_id_index[config_entry_id]
entries.remove(key)
if not entries:
del self._config_entry_id_index[config_entry_id]
if (device_id := entry.device_id) is not None:
entries = self._device_id_index[device_id]
entries.remove(key)
if not entries:
del self._device_id_index[device_id]
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
del self._entry_ids[entry.id]
del self._index[(entry.domain, entry.platform, entry.unique_id)]
self._unindex_entry(key)
super().__delitem__(key)
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
@ -477,6 +497,19 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
"""Get entry from id."""
return self._entry_ids.get(key)
def get_entries_for_device_id(self, device_id: str) -> list[RegistryEntry]:
"""Get entries for device."""
return [self.data[key] for key in self._device_id_index.get(device_id, ())]
def get_entries_for_config_entry_id(
self, config_entry_id: str
) -> list[RegistryEntry]:
"""Get entries for config entry."""
return [
self.data[key]
for key in self._config_entry_id_index.get(config_entry_id, ())
]
class EntityRegistry:
"""Class to hold a registry of entities."""
@ -1217,9 +1250,8 @@ def async_entries_for_device(
"""Return entries that match a device."""
return [
entry
for entry in registry.entities.values()
if entry.device_id == device_id
and (not entry.disabled_by or include_disabled_entities)
for entry in registry.entities.get_entries_for_device_id(device_id)
if (not entry.disabled_by or include_disabled_entities)
]
@ -1236,11 +1268,7 @@ def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str
) -> list[RegistryEntry]:
"""Return entries that match a config entry."""
return [
entry
for entry in registry.entities.values()
if entry.config_entry_id == config_entry_id
]
return registry.entities.get_entries_for_config_entry_id(config_entry_id)
@callback