Improve typing [helpers.entity_registry] (#63767)
parent
2d2944d186
commit
8460c2f66d
|
@ -165,7 +165,7 @@ class EntityRegistryStore(storage.Store):
|
||||||
return await _async_migrate(old_major_version, old_minor_version, old_data)
|
return await _async_migrate(old_major_version, old_minor_version, old_data)
|
||||||
|
|
||||||
|
|
||||||
class EntityRegistryItems(UserDict):
|
class EntityRegistryItems(UserDict[str, "RegistryEntry"]):
|
||||||
"""Container for entity registry items, maps entity_id -> entry.
|
"""Container for entity registry items, maps entity_id -> entry.
|
||||||
|
|
||||||
Maintains two additional indexes:
|
Maintains two additional indexes:
|
||||||
|
@ -196,10 +196,6 @@ class EntityRegistryItems(UserDict):
|
||||||
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
|
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
|
||||||
super().__delitem__(key)
|
super().__delitem__(key)
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> RegistryEntry:
|
|
||||||
"""Get an item."""
|
|
||||||
return cast(RegistryEntry, super().__getitem__(key))
|
|
||||||
|
|
||||||
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
|
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
|
||||||
"""Get entity_id from (domain, platform, unique_id)."""
|
"""Get entity_id from (domain, platform, unique_id)."""
|
||||||
return self._index.get(key)
|
return self._index.get(key)
|
||||||
|
@ -212,10 +208,11 @@ class EntityRegistryItems(UserDict):
|
||||||
class EntityRegistry:
|
class EntityRegistry:
|
||||||
"""Class to hold a registry of entities."""
|
"""Class to hold a registry of entities."""
|
||||||
|
|
||||||
|
entities: EntityRegistryItems
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the registry."""
|
"""Initialize the registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.entities: EntityRegistryItems
|
|
||||||
self._store = EntityRegistryStore(
|
self._store = EntityRegistryStore(
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
|
@ -230,13 +227,13 @@ class EntityRegistry:
|
||||||
@callback
|
@callback
|
||||||
def async_get_device_class_lookup(
|
def async_get_device_class_lookup(
|
||||||
self, domain_device_classes: set[tuple[str, str | None]]
|
self, domain_device_classes: set[tuple[str, str | None]]
|
||||||
) -> dict:
|
) -> dict[str, dict[tuple[str, str | None], str]]:
|
||||||
"""Return a lookup of entity ids for devices which have matching entities.
|
"""Return a lookup of entity ids for devices which have matching entities.
|
||||||
|
|
||||||
Entities must match a set of (domain, device_class) tuples.
|
Entities must match a set of (domain, device_class) tuples.
|
||||||
The result is indexed by device_id, then by the matching (domain, device_class)
|
The result is indexed by device_id, then by the matching (domain, device_class)
|
||||||
"""
|
"""
|
||||||
lookup: dict[str, dict[tuple[Any, Any], str]] = {}
|
lookup: dict[str, dict[tuple[str, str | None], str]] = {}
|
||||||
for entity in self.entities.values():
|
for entity in self.entities.values():
|
||||||
if not entity.device_id:
|
if not entity.device_id:
|
||||||
continue
|
continue
|
||||||
|
@ -483,8 +480,8 @@ class EntityRegistry:
|
||||||
"""Private facing update properties method."""
|
"""Private facing update properties method."""
|
||||||
old = self.entities[entity_id]
|
old = self.entities[entity_id]
|
||||||
|
|
||||||
new_values = {} # Dict with new key/value pairs
|
new_values: dict[str, Any] = {} # Dict with new key/value pairs
|
||||||
old_values = {} # Dict with old key/value pairs
|
old_values: dict[str, Any] = {} # Dict with old key/value pairs
|
||||||
|
|
||||||
if isinstance(disabled_by, str) and not isinstance(
|
if isinstance(disabled_by, str) and not isinstance(
|
||||||
disabled_by, RegistryEntryDisabler
|
disabled_by, RegistryEntryDisabler
|
||||||
|
@ -550,7 +547,11 @@ class EntityRegistry:
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
data = {"action": "update", "entity_id": entity_id, "changes": old_values}
|
data: dict[str, str | dict[str, Any]] = {
|
||||||
|
"action": "update",
|
||||||
|
"entity_id": entity_id,
|
||||||
|
"changes": old_values,
|
||||||
|
}
|
||||||
|
|
||||||
if old.entity_id != entity_id:
|
if old.entity_id != entity_id:
|
||||||
data["old_entity_id"] = old.entity_id
|
data["old_entity_id"] = old.entity_id
|
||||||
|
@ -613,7 +614,7 @@ class EntityRegistry:
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> dict[str, Any]:
|
def _data_to_save(self) -> dict[str, Any]:
|
||||||
"""Return data of entity registry to store in a file."""
|
"""Return data of entity registry to store in a file."""
|
||||||
data = {}
|
data: dict[str, Any] = {}
|
||||||
|
|
||||||
data["entities"] = [
|
data["entities"] = [
|
||||||
{
|
{
|
||||||
|
@ -841,7 +842,7 @@ def async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) ->
|
||||||
async def async_migrate_entries(
|
async def async_migrate_entries(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry_id: str,
|
config_entry_id: str,
|
||||||
entry_callback: Callable[[RegistryEntry], dict | None],
|
entry_callback: Callable[[RegistryEntry], dict[str, Any] | None],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Migrator of unique IDs."""
|
"""Migrator of unique IDs."""
|
||||||
ent_reg = await async_get_registry(hass)
|
ent_reg = await async_get_registry(hass)
|
||||||
|
|
Loading…
Reference in New Issue