Index the entity registry (#37994)

pull/38005/head
J. Nick Koston 2020-07-19 19:52:41 -10:00 committed by GitHub
parent 41421b56a4
commit 890562e3ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 22 deletions

View File

@ -124,6 +124,7 @@ class EntityRegistry:
"""Initialize the registry."""
self.hass = hass
self.entities: Dict[str, RegistryEntry]
self._index: Dict[Tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
@ -160,14 +161,7 @@ class EntityRegistry:
self, domain: str, platform: str, unique_id: str
) -> Optional[str]:
"""Check if an entity_id is currently registered."""
for entity in self.entities.values():
if (
entity.domain == domain
and entity.platform == platform
and entity.unique_id == unique_id
):
return entity.entity_id
return None
return self._index.get((domain, platform, unique_id))
@callback
def async_generate_entity_id(
@ -270,7 +264,7 @@ class EntityRegistry:
original_name=original_name,
original_icon=original_icon,
)
self.entities[entity_id] = entity
self._register_entry(entity)
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save()
@ -283,7 +277,7 @@ class EntityRegistry:
@callback
def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry."""
self.entities.pop(entity_id)
self._unregister_entry(self.entities[entity_id])
self.hass.bus.async_fire(
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
)
@ -380,27 +374,22 @@ class EntityRegistry:
entity_id = changes["entity_id"] = new_entity_id
if new_unique_id is not _UNDEF:
conflict = next(
(
entity
for entity in self.entities.values()
if entity.unique_id == new_unique_id
and entity.domain == old.domain
and entity.platform == old.platform
),
None,
conflict_entity_id = self.async_get_entity_id(
old.domain, old.platform, new_unique_id
)
if conflict:
if conflict_entity_id:
raise ValueError(
f"Unique id '{new_unique_id}' is already in use by "
f"'{conflict.entity_id}'"
f"'{conflict_entity_id}'"
)
changes["unique_id"] = new_unique_id
if not changes:
return old
new = self.entities[entity_id] = attr.evolve(old, **changes)
self._remove_index(old)
new = attr.evolve(old, **changes)
self._register_entry(new)
self.async_schedule_save()
@ -451,6 +440,7 @@ class EntityRegistry:
)
self.entities = entities
self._rebuild_index()
@callback
def async_schedule_save(self) -> None:
@ -494,6 +484,25 @@ class EntityRegistry:
]:
self.async_remove(entity_id)
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry
self._add_index(entry)
def _add_index(self, entry: RegistryEntry) -> None:
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def _unregister_entry(self, entry: RegistryEntry) -> None:
self._remove_index(entry)
del self.entities[entry.entity_id]
def _remove_index(self, entry: RegistryEntry) -> None:
del self._index[(entry.domain, entry.platform, entry.unique_id)]
def _rebuild_index(self) -> None:
self._index = {}
for entry in self.entities.values():
self._add_index(entry)
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:

View File

@ -351,6 +351,7 @@ def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or OrderedDict()
registry._rebuild_index()
hass.data[entity_registry.DATA_REGISTRY] = registry
return registry

View File

@ -428,6 +428,8 @@ async def test_update_entity_unique_id(registry):
entry = registry.async_get_or_create(
"light", "hue", "5678", config_entry=mock_config
)
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
new_unique_id = "1234"
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
updated_entry = registry.async_update_entity(
@ -437,6 +439,9 @@ async def test_update_entity_unique_id(registry):
assert updated_entry.unique_id == new_unique_id
assert mock_schedule_save.call_count == 1
assert registry.async_get_entity_id("light", "hue", "5678") is None
assert registry.async_get_entity_id("light", "hue", "1234") == entry.entity_id
async def test_update_entity_unique_id_conflict(registry):
"""Test migration raises when unique_id already in use."""
@ -452,6 +457,8 @@ async def test_update_entity_unique_id_conflict(registry):
) as mock_schedule_save, pytest.raises(ValueError):
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
assert mock_schedule_save.call_count == 0
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
async def test_update_entity(registry):
@ -473,6 +480,10 @@ async def test_update_entity(registry):
assert getattr(updated_entry, attr_name) == new_value
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
assert (
registry.async_get_entity_id("light", "hue", "5678")
== updated_entry.entity_id
)
entry = updated_entry