Index the entity registry (#37994)
parent
41421b56a4
commit
890562e3ae
|
@ -124,6 +124,7 @@ class EntityRegistry:
|
|||
"""Initialize the registry."""
|
||||
self.hass = hass
|
||||
self.entities: Dict[str, RegistryEntry]
|
||||
self._index: Dict[Tuple[str, str, str], str] = {}
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
self.hass.bus.async_listen(
|
||||
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
|
||||
|
@ -160,14 +161,7 @@ class EntityRegistry:
|
|||
self, domain: str, platform: str, unique_id: str
|
||||
) -> Optional[str]:
|
||||
"""Check if an entity_id is currently registered."""
|
||||
for entity in self.entities.values():
|
||||
if (
|
||||
entity.domain == domain
|
||||
and entity.platform == platform
|
||||
and entity.unique_id == unique_id
|
||||
):
|
||||
return entity.entity_id
|
||||
return None
|
||||
return self._index.get((domain, platform, unique_id))
|
||||
|
||||
@callback
|
||||
def async_generate_entity_id(
|
||||
|
@ -270,7 +264,7 @@ class EntityRegistry:
|
|||
original_name=original_name,
|
||||
original_icon=original_icon,
|
||||
)
|
||||
self.entities[entity_id] = entity
|
||||
self._register_entry(entity)
|
||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||
self.async_schedule_save()
|
||||
|
||||
|
@ -283,7 +277,7 @@ class EntityRegistry:
|
|||
@callback
|
||||
def async_remove(self, entity_id: str) -> None:
|
||||
"""Remove an entity from registry."""
|
||||
self.entities.pop(entity_id)
|
||||
self._unregister_entry(self.entities[entity_id])
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
||||
)
|
||||
|
@ -380,27 +374,22 @@ class EntityRegistry:
|
|||
entity_id = changes["entity_id"] = new_entity_id
|
||||
|
||||
if new_unique_id is not _UNDEF:
|
||||
conflict = next(
|
||||
(
|
||||
entity
|
||||
for entity in self.entities.values()
|
||||
if entity.unique_id == new_unique_id
|
||||
and entity.domain == old.domain
|
||||
and entity.platform == old.platform
|
||||
),
|
||||
None,
|
||||
conflict_entity_id = self.async_get_entity_id(
|
||||
old.domain, old.platform, new_unique_id
|
||||
)
|
||||
if conflict:
|
||||
if conflict_entity_id:
|
||||
raise ValueError(
|
||||
f"Unique id '{new_unique_id}' is already in use by "
|
||||
f"'{conflict.entity_id}'"
|
||||
f"'{conflict_entity_id}'"
|
||||
)
|
||||
changes["unique_id"] = new_unique_id
|
||||
|
||||
if not changes:
|
||||
return old
|
||||
|
||||
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
||||
self._remove_index(old)
|
||||
new = attr.evolve(old, **changes)
|
||||
self._register_entry(new)
|
||||
|
||||
self.async_schedule_save()
|
||||
|
||||
|
@ -451,6 +440,7 @@ class EntityRegistry:
|
|||
)
|
||||
|
||||
self.entities = entities
|
||||
self._rebuild_index()
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
|
@ -494,6 +484,25 @@ class EntityRegistry:
|
|||
]:
|
||||
self.async_remove(entity_id)
|
||||
|
||||
def _register_entry(self, entry: RegistryEntry) -> None:
|
||||
self.entities[entry.entity_id] = entry
|
||||
self._add_index(entry)
|
||||
|
||||
def _add_index(self, entry: RegistryEntry) -> None:
|
||||
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
||||
|
||||
def _unregister_entry(self, entry: RegistryEntry) -> None:
|
||||
self._remove_index(entry)
|
||||
del self.entities[entry.entity_id]
|
||||
|
||||
def _remove_index(self, entry: RegistryEntry) -> None:
|
||||
del self._index[(entry.domain, entry.platform, entry.unique_id)]
|
||||
|
||||
def _rebuild_index(self) -> None:
|
||||
self._index = {}
|
||||
for entry in self.entities.values():
|
||||
self._add_index(entry)
|
||||
|
||||
|
||||
@singleton(DATA_REGISTRY)
|
||||
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
|
||||
|
|
|
@ -351,6 +351,7 @@ def mock_registry(hass, mock_entries=None):
|
|||
"""Mock the Entity Registry."""
|
||||
registry = entity_registry.EntityRegistry(hass)
|
||||
registry.entities = mock_entries or OrderedDict()
|
||||
registry._rebuild_index()
|
||||
|
||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||
return registry
|
||||
|
|
|
@ -428,6 +428,8 @@ async def test_update_entity_unique_id(registry):
|
|||
entry = registry.async_get_or_create(
|
||||
"light", "hue", "5678", config_entry=mock_config
|
||||
)
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||
|
||||
new_unique_id = "1234"
|
||||
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
|
||||
updated_entry = registry.async_update_entity(
|
||||
|
@ -437,6 +439,9 @@ async def test_update_entity_unique_id(registry):
|
|||
assert updated_entry.unique_id == new_unique_id
|
||||
assert mock_schedule_save.call_count == 1
|
||||
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") is None
|
||||
assert registry.async_get_entity_id("light", "hue", "1234") == entry.entity_id
|
||||
|
||||
|
||||
async def test_update_entity_unique_id_conflict(registry):
|
||||
"""Test migration raises when unique_id already in use."""
|
||||
|
@ -452,6 +457,8 @@ async def test_update_entity_unique_id_conflict(registry):
|
|||
) as mock_schedule_save, pytest.raises(ValueError):
|
||||
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
|
||||
assert mock_schedule_save.call_count == 0
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||
|
||||
|
||||
async def test_update_entity(registry):
|
||||
|
@ -473,6 +480,10 @@ async def test_update_entity(registry):
|
|||
assert getattr(updated_entry, attr_name) == new_value
|
||||
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
|
||||
|
||||
assert (
|
||||
registry.async_get_entity_id("light", "hue", "5678")
|
||||
== updated_entry.entity_id
|
||||
)
|
||||
entry = updated_entry
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue