Reduce some linear searches to cleanup the device registry (#112277)

Some of the data we had to search for was already available
in a dict or underlying data structure. Make it available
instead of having to build it every time.

There are more places these can be used, but I only did
the device registry cleanup for now
pull/112291/head
J. Nick Koston 2024-03-04 15:59:12 -10:00 committed by GitHub
parent e357c4d5e5
commit 2c179dc5fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 14 deletions

View File

@ -1462,6 +1462,11 @@ class ConfigEntries:
"""Return entry with matching entry_id.""" """Return entry with matching entry_id."""
return self._entries.data.get(entry_id) return self._entries.data.get(entry_id)
@callback
def async_entry_ids(self) -> list[str]:
"""Return entry ids."""
return list(self._entries.data)
@callback @callback
def async_entries( def async_entries(
self, self,

View File

@ -1090,7 +1090,7 @@ def async_cleanup(
) -> None: ) -> None:
"""Clean up device registry.""" """Clean up device registry."""
# Find all devices that are referenced by a config_entry. # Find all devices that are referenced by a config_entry.
config_entry_ids = {entry.entry_id for entry in hass.config_entries.async_entries()} config_entry_ids = set(hass.config_entries.async_entry_ids())
references_config_entries = { references_config_entries = {
device.id device.id
for device in dev_reg.devices.values() for device in dev_reg.devices.values()
@ -1099,9 +1099,13 @@ def async_cleanup(
} }
# Find all devices that are referenced in the entity registry. # Find all devices that are referenced in the entity registry.
references_entities = {entry.device_id for entry in ent_reg.entities.values()} device_ids_referenced_by_entities = set(ent_reg.entities.get_device_ids())
orphan = set(dev_reg.devices) - references_entities - references_config_entries orphan = (
set(dev_reg.devices)
- device_ids_referenced_by_entities
- references_config_entries
)
for dev_id in orphan: for dev_id in orphan:
dev_reg.async_remove_device(dev_id) dev_reg.async_remove_device(dev_id)

View File

@ -10,7 +10,7 @@ timer.
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections import UserDict
from collections.abc import Callable, Iterable, Mapping, ValuesView from collections.abc import Callable, Iterable, KeysView, Mapping, ValuesView
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import StrEnum from enum import StrEnum
import logging import logging
@ -511,6 +511,14 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
self._unindex_entry(key) self._unindex_entry(key)
super().__delitem__(key) super().__delitem__(key)
def get_entity_ids(self) -> ValuesView[str]:
"""Return entity ids."""
return self._index.values()
def get_device_ids(self) -> KeysView[str]:
"""Return device ids."""
return self._device_id_index.keys()
def get_entity_id(self, key: tuple[str, str, str]) -> str | None: def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
"""Get entity_id from (domain, platform, unique_id).""" """Get entity_id from (domain, platform, unique_id)."""
return self._index.get(key) return self._index.get(key)
@ -612,6 +620,16 @@ class EntityRegistry:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
return self.entities.get_entity_id((domain, platform, unique_id)) return self.entities.get_entity_id((domain, platform, unique_id))
@callback
def async_entity_ids(self) -> list[str]:
"""Return entity ids."""
return list(self.entities.get_entity_ids())
@callback
def async_device_ids(self) -> list[str]:
"""Return known device ids."""
return list(self.entities.get_device_ids())
def _entity_id_available( def _entity_id_available(
self, entity_id: str, known_object_ids: Iterable[str] | None self, entity_id: str, known_object_ids: Iterable[str] | None
) -> bool: ) -> bool:

View File

@ -90,6 +90,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None:
unit_of_measurement="initial-unit_of_measurement", unit_of_measurement="initial-unit_of_measurement",
) )
assert set(entity_registry.async_device_ids()) == {"mock-dev-id"}
assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"}
assert orig_entry == er.RegistryEntry( assert orig_entry == er.RegistryEntry(
"light.hue_5678", "light.hue_5678",
"5678", "5678",
@ -159,6 +162,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None:
unit_of_measurement="updated-unit_of_measurement", unit_of_measurement="updated-unit_of_measurement",
) )
assert set(entity_registry.async_device_ids()) == {"new-mock-dev-id"}
assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"}
new_entry = entity_registry.async_get_or_create( new_entry = entity_registry.async_get_or_create(
"light", "light",
"hue", "hue",
@ -203,6 +209,9 @@ def test_get_or_create_updates_data(entity_registry: er.EntityRegistry) -> None:
unit_of_measurement=None, unit_of_measurement=None,
) )
assert set(entity_registry.async_device_ids()) == set()
assert set(entity_registry.async_entity_ids()) == {"light.hue_5678"}
def test_get_or_create_suggested_object_id_conflict_register( def test_get_or_create_suggested_object_id_conflict_register(
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -446,6 +455,8 @@ def test_async_get_entity_id(entity_registry: er.EntityRegistry) -> None:
) )
assert entity_registry.async_get_entity_id("light", "hue", "123") is None assert entity_registry.async_get_entity_id("light", "hue", "123") is None
assert set(entity_registry.async_entity_ids()) == {"light.hue_1234"}
async def test_updating_config_entry_id( async def test_updating_config_entry_id(
hass: HomeAssistant, entity_registry: er.EntityRegistry hass: HomeAssistant, entity_registry: er.EntityRegistry
@ -1469,6 +1480,7 @@ def test_entity_registry_items() -> None:
entities = er.EntityRegistryItems() entities = er.EntityRegistryItems()
assert entities.get_entity_id(("a", "b", "c")) is None assert entities.get_entity_id(("a", "b", "c")) is None
assert entities.get_entry("abc") is None assert entities.get_entry("abc") is None
assert set(entities.get_entity_ids()) == set()
entry1 = er.RegistryEntry("test.entity1", "1234", "hue") entry1 = er.RegistryEntry("test.entity1", "1234", "hue")
entry2 = er.RegistryEntry("test.entity2", "2345", "hue") entry2 = er.RegistryEntry("test.entity2", "2345", "hue")
@ -1482,6 +1494,7 @@ def test_entity_registry_items() -> None:
assert entities.get_entry(entry1.id) is entry1 assert entities.get_entry(entry1.id) is entry1
assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id
assert entities.get_entry(entry2.id) is entry2 assert entities.get_entry(entry2.id) is entry2
assert set(entities.get_entity_ids()) == {"test.entity2", "test.entity1"}
entities.pop("test.entity1") entities.pop("test.entity1")
del entities["test.entity2"] del entities["test.entity2"]
@ -1491,6 +1504,8 @@ def test_entity_registry_items() -> None:
assert entities.get_entity_id(("test", "hue", "2345")) is None assert entities.get_entity_id(("test", "hue", "2345")) is None
assert entities.get_entry(entry2.id) is None assert entities.get_entry(entry2.id) is None
assert set(entities.get_entity_ids()) == set()
async def test_disabled_by_str_not_allowed( async def test_disabled_by_str_not_allowed(
hass: HomeAssistant, entity_registry: er.EntityRegistry hass: HomeAssistant, entity_registry: er.EntityRegistry

View File

@ -378,7 +378,7 @@ async def test_remove_entry(
MockConfigEntry(domain="test_other", entry_id="test3").add_to_manager(manager) MockConfigEntry(domain="test_other", entry_id="test3").add_to_manager(manager)
# Check all config entries exist # Check all config entries exist
assert [item.entry_id for item in manager.async_entries()] == [ assert manager.async_entry_ids() == [
"test1", "test1",
"test2", "test2",
"test3", "test3",
@ -408,7 +408,7 @@ async def test_remove_entry(
assert mock_remove_entry.call_count == 1 assert mock_remove_entry.call_count == 1
# Check that config entry was removed. # Check that config entry was removed.
assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] assert manager.async_entry_ids() == ["test1", "test3"]
# Check that entity state has been removed # Check that entity state has been removed
assert hass.states.get("light.test_entity") is None assert hass.states.get("light.test_entity") is None
@ -469,7 +469,7 @@ async def test_remove_entry_handles_callback_error(
entry = MockConfigEntry(domain="test", entry_id="test1") entry = MockConfigEntry(domain="test", entry_id="test1")
entry.add_to_manager(manager) entry.add_to_manager(manager)
# Check all config entries exist # Check all config entries exist
assert [item.entry_id for item in manager.async_entries()] == ["test1"] assert manager.async_entry_ids() == ["test1"]
# Setup entry # Setup entry
await entry.async_setup(hass) await entry.async_setup(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -482,7 +482,7 @@ async def test_remove_entry_handles_callback_error(
# Check the remove callback was invoked. # Check the remove callback was invoked.
assert mock_remove_entry.call_count == 1 assert mock_remove_entry.call_count == 1
# Check that config entry was removed. # Check that config entry was removed.
assert [item.entry_id for item in manager.async_entries()] == [] assert manager.async_entry_ids() == []
async def test_remove_entry_raises( async def test_remove_entry_raises(
@ -502,7 +502,7 @@ async def test_remove_entry_raises(
).add_to_manager(manager) ).add_to_manager(manager)
MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager)
assert [item.entry_id for item in manager.async_entries()] == [ assert manager.async_entry_ids() == [
"test1", "test1",
"test2", "test2",
"test3", "test3",
@ -511,7 +511,7 @@ async def test_remove_entry_raises(
result = await manager.async_remove("test2") result = await manager.async_remove("test2")
assert result == {"require_restart": True} assert result == {"require_restart": True}
assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] assert manager.async_entry_ids() == ["test1", "test3"]
async def test_remove_entry_if_not_loaded( async def test_remove_entry_if_not_loaded(
@ -526,7 +526,7 @@ async def test_remove_entry_if_not_loaded(
MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager) MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager)
MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager)
assert [item.entry_id for item in manager.async_entries()] == [ assert manager.async_entry_ids() == [
"test1", "test1",
"test2", "test2",
"test3", "test3",
@ -535,7 +535,7 @@ async def test_remove_entry_if_not_loaded(
result = await manager.async_remove("test2") result = await manager.async_remove("test2")
assert result == {"require_restart": False} assert result == {"require_restart": False}
assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] assert manager.async_entry_ids() == ["test1", "test3"]
assert len(mock_unload_entry.mock_calls) == 0 assert len(mock_unload_entry.mock_calls) == 0
@ -550,7 +550,7 @@ async def test_remove_entry_if_integration_deleted(
MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager) MockConfigEntry(domain="comp", entry_id="test2").add_to_manager(manager)
MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager) MockConfigEntry(domain="test", entry_id="test3").add_to_manager(manager)
assert [item.entry_id for item in manager.async_entries()] == [ assert manager.async_entry_ids() == [
"test1", "test1",
"test2", "test2",
"test3", "test3",
@ -559,7 +559,7 @@ async def test_remove_entry_if_integration_deleted(
result = await manager.async_remove("test2") result = await manager.async_remove("test2")
assert result == {"require_restart": False} assert result == {"require_restart": False}
assert [item.entry_id for item in manager.async_entries()] == ["test1", "test3"] assert manager.async_entry_ids() == ["test1", "test3"]
assert len(mock_unload_entry.mock_calls) == 0 assert len(mock_unload_entry.mock_calls) == 0