# core/homeassistant/helpers/entity_registry.py
# (source-listing metadata: Python, 669 lines, 23 KiB)

"""Provide a registry to track entity IDs.
The Entity Registry keeps a registry of entities. Entities are uniquely
identified by their domain, platform and a unique id provided by that platform.
The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the
timer.
"""
from collections import OrderedDict
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
import attr
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_FRIENDLY_NAME,
ATTR_ICON,
ATTR_RESTORED,
ATTR_SUPPORTED_FEATURES,
ATTR_UNIT_OF_MEASUREMENT,
EVENT_HOMEASSISTANT_START,
STATE_UNAVAILABLE,
)
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING:
    # Imported for type annotations only (avoids a runtime import cycle).
    from homeassistant.config_entries import ConfigEntry  # noqa: F401

# Legacy YAML file; only read by async_load when migrating to the storage helper.
PATH_REGISTRY = "entity_registry.yaml"
# Key passed to the ``singleton`` decorator on ``async_get_registry``.
DATA_REGISTRY = "entity_registry"
# Event fired with action "create", "update" or "remove" on registry changes.
EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated"
# Delay (seconds) passed to the store's delayed save.
SAVE_DELAY = 10

_LOGGER = logging.getLogger(__name__)

# Allowed values for RegistryEntry.disabled_by (None means enabled).
DISABLED_CONFIG_ENTRY = "config_entry"
# Set on entities whose device was disabled in the device registry.
DISABLED_DEVICE = "device"
DISABLED_HASS = "hass"
# Set on new entities when the config entry disables new entities.
DISABLED_INTEGRATION = "integration"
DISABLED_USER = "user"

# Version and key of the storage helper file the registry is persisted to.
STORAGE_VERSION = 1
STORAGE_KEY = "core.entity_registry"

# Attributes relevant to describing entity
# to external services.
ENTITY_DESCRIBING_ATTRIBUTES = {
    "entity_id",
    "name",
    "original_name",
    "capabilities",
    "supported_features",
    "device_class",
    "unit_of_measurement",
}
@attr.s(slots=True, frozen=True)
class RegistryEntry:
    """Entity Registry Entry.

    Instances are frozen; changing an entry means building a new instance
    (the registry does this with ``attr.evolve``). The declaration order of
    the fields below is the positional order of the generated ``__init__``.
    """

    entity_id: str = attr.ib()
    unique_id: str = attr.ib()
    platform: str = attr.ib()
    # User-facing overrides. When restored states are written, name/icon take
    # precedence over original_name/original_icon.
    name: Optional[str] = attr.ib(default=None)
    icon: Optional[str] = attr.ib(default=None)
    device_id: Optional[str] = attr.ib(default=None)
    area_id: Optional[str] = attr.ib(default=None)
    config_entry_id: Optional[str] = attr.ib(default=None)
    # Why the entity is disabled; None means enabled. Any other value is
    # rejected by the validator.
    disabled_by: Optional[str] = attr.ib(
        default=None,
        validator=attr.validators.in_(
            (
                DISABLED_CONFIG_ENTRY,
                DISABLED_DEVICE,
                DISABLED_HASS,
                DISABLED_INTEGRATION,
                DISABLED_USER,
                None,
            )
        ),
    )
    capabilities: Optional[Dict[str, Any]] = attr.ib(default=None)
    supported_features: int = attr.ib(default=0)
    device_class: Optional[str] = attr.ib(default=None)
    unit_of_measurement: Optional[str] = attr.ib(default=None)
    # As set by integration
    original_name: Optional[str] = attr.ib(default=None)
    original_icon: Optional[str] = attr.ib(default=None)
    # Derived from entity_id; not a constructor argument and hidden from repr.
    domain: str = attr.ib(init=False, repr=False)

    @domain.default
    def _domain_default(self) -> str:
        """Compute domain value (first element of split_entity_id(entity_id))."""
        return split_entity_id(self.entity_id)[0]

    @property
    def disabled(self) -> bool:
        """Return if entry is disabled (i.e. disabled_by is set)."""
        return self.disabled_by is not None
class EntityRegistry:
    """Class to hold a registry of entities."""

    def __init__(self, hass: HomeAssistantType):
        """Initialize the registry.

        ``entities`` is only declared here; it is assigned by ``async_load``.
        """
        self.hass = hass
        self.entities: Dict[str, RegistryEntry]
        # (domain, platform, unique_id) -> entity_id
        self._index: Dict[Tuple[str, str, str], str] = {}
        self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
        self.hass.bus.async_listen(
            EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
        )

    @callback
    def async_get_device_class_lookup(self, domain_device_classes: set) -> dict:
        """Return a lookup for the device class by domain.

        The result maps ``device_id -> {(domain, device_class): entity_id}``
        for every device-attached entity whose ``(domain, device_class)`` pair
        is in ``domain_device_classes``. If several entities of one device
        share a pair, the last one iterated wins.
        """
        lookup: Dict[str, Dict[Tuple[Any, Any], str]] = {}
        for entity in self.entities.values():
            if not entity.device_id:
                continue
            domain_device_class = (entity.domain, entity.device_class)
            if domain_device_class not in domain_device_classes:
                continue
            if entity.device_id not in lookup:
                lookup[entity.device_id] = {domain_device_class: entity.entity_id}
            else:
                lookup[entity.device_id][domain_device_class] = entity.entity_id
        return lookup

    @callback
    def async_is_registered(self, entity_id: str) -> bool:
        """Check if an entity_id is currently registered."""
        return entity_id in self.entities

    @callback
    def async_get(self, entity_id: str) -> Optional[RegistryEntry]:
        """Return the RegistryEntry for an entity_id, or None if unknown."""
        return self.entities.get(entity_id)

    @callback
    def async_get_entity_id(
        self, domain: str, platform: str, unique_id: str
    ) -> Optional[str]:
        """Return the entity_id registered for (domain, platform, unique_id).

        Returns None when no entry matches.
        """
        return self._index.get((domain, platform, unique_id))

    @callback
    def async_generate_entity_id(
        self,
        domain: str,
        suggested_object_id: str,
        known_object_ids: Optional[Iterable[str]] = None,
    ) -> str:
        """Generate an entity ID that does not conflict.

        Conflicts checked against registered and currently existing entities,
        and against ``known_object_ids`` (entity IDs the caller has reserved).
        If the preferred ID is taken, ``_2``, ``_3``, ... suffixes are tried.
        """
        preferred_string = f"{domain}.{slugify(suggested_object_id)}"
        # Materialize once: the argument may be a one-shot iterator (which the
        # first ``in`` test would exhaust), and a set makes each test O(1).
        known = set(known_object_ids or ())

        test_string = preferred_string
        tries = 1
        while (
            test_string in self.entities
            or test_string in known
            or not self.hass.states.async_available(test_string)
        ):
            tries += 1
            test_string = f"{preferred_string}_{tries}"

        return test_string

    @callback
    def async_get_or_create(
        self,
        domain: str,
        platform: str,
        unique_id: str,
        *,
        # To influence entity ID generation
        suggested_object_id: Optional[str] = None,
        known_object_ids: Optional[Iterable[str]] = None,
        # To disable an entity if it gets created
        disabled_by: Optional[str] = None,
        # Data that we want entry to have
        config_entry: Optional["ConfigEntry"] = None,
        device_id: Optional[str] = None,
        area_id: Optional[str] = None,
        capabilities: Optional[Dict[str, Any]] = None,
        supported_features: Optional[int] = None,
        device_class: Optional[str] = None,
        unit_of_measurement: Optional[str] = None,
        original_name: Optional[str] = None,
        original_icon: Optional[str] = None,
    ) -> RegistryEntry:
        """Get entity. Create if it doesn't exist.

        If (domain, platform, unique_id) is already registered, the existing
        entry is updated with every *truthy* value passed here and returned.
        Falsy values (None, 0, "", {}) are ignored for existing entries.

        Otherwise a new entry is created under a generated entity_id, a save
        is scheduled and a "create" registry event is fired. A new entry is
        disabled by the integration when no ``disabled_by`` is given and the
        config entry's system options disable new entities.
        """
        config_entry_id = None
        if config_entry:
            config_entry_id = config_entry.entry_id

        entity_id = self.async_get_entity_id(domain, platform, unique_id)

        if entity_id:
            return self._async_update_entity(
                entity_id,
                config_entry_id=config_entry_id or UNDEFINED,
                device_id=device_id or UNDEFINED,
                area_id=area_id or UNDEFINED,
                capabilities=capabilities or UNDEFINED,
                supported_features=supported_features or UNDEFINED,
                device_class=device_class or UNDEFINED,
                unit_of_measurement=unit_of_measurement or UNDEFINED,
                original_name=original_name or UNDEFINED,
                original_icon=original_icon or UNDEFINED,
                # When we changed our slugify algorithm, we invalidated some
                # stored entity IDs with either a __ or ending in _.
                # Fix introduced in 0.86 (Jan 23, 2019). Next line can be
                # removed when we release 1.0 or in 2020.
                new_entity_id=".".join(
                    slugify(part) for part in entity_id.split(".", 1)
                ),
            )

        entity_id = self.async_generate_entity_id(
            domain, suggested_object_id or f"{platform}_{unique_id}", known_object_ids
        )

        if (
            disabled_by is None
            and config_entry
            and config_entry.system_options.disable_new_entities
        ):
            disabled_by = DISABLED_INTEGRATION

        entity = RegistryEntry(
            entity_id=entity_id,
            config_entry_id=config_entry_id,
            device_id=device_id,
            area_id=area_id,
            unique_id=unique_id,
            platform=platform,
            disabled_by=disabled_by,
            capabilities=capabilities,
            supported_features=supported_features or 0,
            device_class=device_class,
            unit_of_measurement=unit_of_measurement,
            original_name=original_name,
            original_icon=original_icon,
        )
        self._register_entry(entity)
        _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
        self.async_schedule_save()

        self.hass.bus.async_fire(
            EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
        )

        return entity

    @callback
    def async_remove(self, entity_id: str) -> None:
        """Remove an entity from registry.

        Raises KeyError if ``entity_id`` is not registered.
        """
        self._unregister_entry(self.entities[entity_id])
        self.hass.bus.async_fire(
            EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
        )
        self.async_schedule_save()

    async def async_device_modified(self, event: Event) -> None:
        """Handle the removal or update of a device.

        Remove entities from the registry that are associated to a device when
        the device is removed.

        Disable entities in the registry that are associated to a device when
        the device is disabled. When the device is no longer disabled, only
        entities that were disabled because of the device are re-enabled.
        """
        if event.data["action"] == "remove":
            entities = async_entries_for_device(
                self, event.data["device_id"], include_disabled_entities=True
            )
            for entity in entities:
                self.async_remove(entity.entity_id)
            return

        if event.data["action"] != "update":
            return

        device_registry = await self.hass.helpers.device_registry.async_get_registry()
        device = device_registry.async_get(event.data["device_id"])
        if not device.disabled:
            entities = async_entries_for_device(
                self, event.data["device_id"], include_disabled_entities=True
            )
            for entity in entities:
                if entity.disabled_by != DISABLED_DEVICE:
                    continue
                self.async_update_entity(entity.entity_id, disabled_by=None)
            return

        entities = async_entries_for_device(self, event.data["device_id"])
        for entity in entities:
            self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE)

    @callback
    def async_update_entity(
        self,
        entity_id: str,
        *,
        name: Union[str, None, UndefinedType] = UNDEFINED,
        icon: Union[str, None, UndefinedType] = UNDEFINED,
        area_id: Union[str, None, UndefinedType] = UNDEFINED,
        new_entity_id: Union[str, UndefinedType] = UNDEFINED,
        new_unique_id: Union[str, UndefinedType] = UNDEFINED,
        disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
    ) -> RegistryEntry:
        """Update the user-editable properties of an entity.

        Arguments left as UNDEFINED are not changed. Raises ValueError for an
        invalid or conflicting ``new_entity_id`` / ``new_unique_id``.
        """
        return self._async_update_entity(
            entity_id,
            name=name,
            icon=icon,
            area_id=area_id,
            new_entity_id=new_entity_id,
            new_unique_id=new_unique_id,
            disabled_by=disabled_by,
        )

    @callback
    def _async_update_entity(
        self,
        entity_id: str,
        *,
        name: Union[str, None, UndefinedType] = UNDEFINED,
        icon: Union[str, None, UndefinedType] = UNDEFINED,
        config_entry_id: Union[str, None, UndefinedType] = UNDEFINED,
        new_entity_id: Union[str, UndefinedType] = UNDEFINED,
        device_id: Union[str, None, UndefinedType] = UNDEFINED,
        area_id: Union[str, None, UndefinedType] = UNDEFINED,
        new_unique_id: Union[str, UndefinedType] = UNDEFINED,
        disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
        capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED,
        supported_features: Union[int, UndefinedType] = UNDEFINED,
        device_class: Union[str, None, UndefinedType] = UNDEFINED,
        unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED,
        original_name: Union[str, None, UndefinedType] = UNDEFINED,
        original_icon: Union[str, None, UndefinedType] = UNDEFINED,
    ) -> RegistryEntry:
        """Private facing update properties method.

        Only values that are not UNDEFINED and differ from the current entry
        count as changes. If nothing changes, the existing entry is returned
        unchanged. Otherwise a new entry replaces it, a save is scheduled and
        an "update" registry event is fired.

        Raises ValueError if the new entity ID is taken, invalid or in another
        domain, or if the new unique ID is used by a different entity. All
        checks run before any state is mutated, so a rejected update leaves
        the registry untouched.
        """
        old = self.entities[entity_id]

        changes = {}

        for attr_name, value in (
            ("name", name),
            ("icon", icon),
            ("config_entry_id", config_entry_id),
            ("device_id", device_id),
            ("area_id", area_id),
            ("disabled_by", disabled_by),
            ("capabilities", capabilities),
            ("supported_features", supported_features),
            ("device_class", device_class),
            ("unit_of_measurement", unit_of_measurement),
            ("original_name", original_name),
            ("original_icon", original_icon),
        ):
            if value is not UNDEFINED and value != getattr(old, attr_name):
                changes[attr_name] = value

        # --- Validation phase (no mutation) ---
        renamed_to: Optional[str] = None
        if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
            if self.async_is_registered(new_entity_id):
                raise ValueError("Entity is already registered")

            if not valid_entity_id(new_entity_id):
                raise ValueError("Invalid entity ID")

            if split_entity_id(new_entity_id)[0] != split_entity_id(entity_id)[0]:
                raise ValueError("New entity ID should be same domain")

            renamed_to = new_entity_id

        changed_unique_id: Optional[str] = None
        # Re-submitting the entry's current unique_id is a no-op, not a
        # conflict with the entry itself.
        if new_unique_id is not UNDEFINED and new_unique_id != old.unique_id:
            conflict_entity_id = self.async_get_entity_id(
                old.domain, old.platform, new_unique_id
            )
            if conflict_entity_id:
                raise ValueError(
                    f"Unique id '{new_unique_id}' is already in use by "
                    f"'{conflict_entity_id}'"
                )
            changed_unique_id = new_unique_id

        # --- Mutation phase ---
        # Key order of ``changes`` (attributes, then entity_id, then
        # unique_id) determines the "changes" list in the fired event.
        if renamed_to is not None:
            self.entities.pop(entity_id)
            entity_id = changes["entity_id"] = renamed_to

        if changed_unique_id is not None:
            changes["unique_id"] = changed_unique_id

        if not changes:
            return old

        self._remove_index(old)
        new = attr.evolve(old, **changes)
        self._register_entry(new)

        self.async_schedule_save()

        data = {"action": "update", "entity_id": entity_id, "changes": list(changes)}

        if old.entity_id != entity_id:
            data["old_entity_id"] = old.entity_id

        self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)

        return new

    async def async_load(self) -> None:
        """Load the entity registry.

        Sets up the restore mechanism, then loads stored data. The legacy
        YAML file is migrated to the storage helper if present. Finally the
        entity dict and the unique-id index are rebuilt.
        """
        async_setup_entity_restore(self.hass, self)

        data = await self.hass.helpers.storage.async_migrator(
            self.hass.config.path(PATH_REGISTRY),
            self._store,
            old_conf_load_func=load_yaml,
            old_conf_migrate_func=_async_migrate,
        )
        entities: Dict[str, RegistryEntry] = OrderedDict()

        if data is not None:
            for entity in data["entities"]:
                # Some old installations can have some bad entities.
                # Filter them out as they cause errors down the line.
                # Can be removed in Jan 2021
                if not valid_entity_id(entity["entity_id"]):
                    continue

                entities[entity["entity_id"]] = RegistryEntry(
                    entity_id=entity["entity_id"],
                    config_entry_id=entity.get("config_entry_id"),
                    device_id=entity.get("device_id"),
                    area_id=entity.get("area_id"),
                    unique_id=entity["unique_id"],
                    platform=entity["platform"],
                    name=entity.get("name"),
                    icon=entity.get("icon"),
                    disabled_by=entity.get("disabled_by"),
                    capabilities=entity.get("capabilities") or {},
                    supported_features=entity.get("supported_features", 0),
                    device_class=entity.get("device_class"),
                    unit_of_measurement=entity.get("unit_of_measurement"),
                    original_name=entity.get("original_name"),
                    original_icon=entity.get("original_icon"),
                )

        self.entities = entities
        self._rebuild_index()

    @callback
    def async_schedule_save(self) -> None:
        """Schedule saving the entity registry after SAVE_DELAY seconds."""
        self._store.async_delay_save(self._data_to_save, SAVE_DELAY)

    @callback
    def _data_to_save(self) -> Dict[str, Any]:
        """Return data of entity registry to store in a file."""
        return {
            "entities": [
                {
                    "entity_id": entry.entity_id,
                    "config_entry_id": entry.config_entry_id,
                    "device_id": entry.device_id,
                    "area_id": entry.area_id,
                    "unique_id": entry.unique_id,
                    "platform": entry.platform,
                    "name": entry.name,
                    "icon": entry.icon,
                    "disabled_by": entry.disabled_by,
                    "capabilities": entry.capabilities,
                    "supported_features": entry.supported_features,
                    "device_class": entry.device_class,
                    "unit_of_measurement": entry.unit_of_measurement,
                    "original_name": entry.original_name,
                    "original_icon": entry.original_icon,
                }
                for entry in self.entities.values()
            ]
        }

    @callback
    def async_clear_config_entry(self, config_entry: str) -> None:
        """Remove every registry entry that belongs to a config entry ID."""
        for entity_id in [
            entity_id
            for entity_id, entry in self.entities.items()
            if config_entry == entry.config_entry_id
        ]:
            self.async_remove(entity_id)

    @callback
    def async_clear_area_id(self, area_id: str) -> None:
        """Clear area id from registry entries."""
        # Iterate a snapshot: each update re-registers the entry in
        # ``self.entities`` while we loop.
        for entity_id, entry in list(self.entities.items()):
            if area_id == entry.area_id:
                self._async_update_entity(entity_id, area_id=None)

    def _register_entry(self, entry: RegistryEntry) -> None:
        """Store an entry and index it by (domain, platform, unique_id)."""
        self.entities[entry.entity_id] = entry
        self._add_index(entry)

    def _add_index(self, entry: RegistryEntry) -> None:
        """Map the entry's (domain, platform, unique_id) to its entity_id."""
        self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id

    def _unregister_entry(self, entry: RegistryEntry) -> None:
        """Drop an entry from both the index and the entity dict."""
        self._remove_index(entry)
        del self.entities[entry.entity_id]

    def _remove_index(self, entry: RegistryEntry) -> None:
        """Remove the entry's (domain, platform, unique_id) index key."""
        del self._index[(entry.domain, entry.platform, entry.unique_id)]

    def _rebuild_index(self) -> None:
        """Recreate the unique-id index from ``self.entities``."""
        self._index = {}
        for entry in self.entities.values():
            self._add_index(entry)
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
    """Build the entity registry and load its stored state before returning it.

    Wrapped by ``singleton(DATA_REGISTRY)``; presumably later calls reuse the
    first instance (see the singleton helper to confirm).
    """
    registry = EntityRegistry(hass)
    await registry.async_load()
    return registry
@callback
def async_entries_for_device(
    registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False
) -> List[RegistryEntry]:
    """Collect the registry entries attached to ``device_id``.

    Entries with a truthy ``disabled_by`` are skipped unless
    ``include_disabled_entities`` is True. Registry order is preserved.
    """
    matches: List[RegistryEntry] = []
    for candidate in registry.entities.values():
        if candidate.device_id != device_id:
            continue
        if candidate.disabled_by and not include_disabled_entities:
            continue
        matches.append(candidate)
    return matches
@callback
def async_entries_for_area(
    registry: EntityRegistry, area_id: str
) -> List[RegistryEntry]:
    """Collect the registry entries assigned to ``area_id``, in registry order."""
    in_area = filter(
        lambda candidate: candidate.area_id == area_id, registry.entities.values()
    )
    return list(in_area)
@callback
def async_entries_for_config_entry(
    registry: EntityRegistry, config_entry_id: str
) -> List[RegistryEntry]:
    """Collect the registry entries owned by ``config_entry_id``, in registry order."""
    owned: List[RegistryEntry] = []
    for candidate in registry.entities.values():
        if candidate.config_entry_id == config_entry_id:
            owned.append(candidate)
    return owned
async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
"""Migrate the YAML config file to storage helper format."""
return {
"entities": [
{"entity_id": entity_id, **info} for entity_id, info in entities.items()
]
}
@callback
def async_setup_entity_restore(
    hass: HomeAssistantType, registry: EntityRegistry
) -> None:
    """Wire up restoring of registry entries as placeholder states.

    A restored placeholder state (``ATTR_RESTORED`` set) is removed when its
    entity is removed from the registry. If Home Assistant is not yet running,
    a listener is also registered for start-up. At start-up it writes an
    ``unavailable`` state for every enabled registry entry that has no state.
    """

    @callback
    def cleanup_restored_states(event: Event) -> None:
        """Drop a restored placeholder state once its entity is removed."""
        if event.data["action"] == "remove":
            removed_id = event.data["entity_id"]
            current = hass.states.get(removed_id)
            if current is not None and current.attributes.get(ATTR_RESTORED):
                hass.states.async_remove(removed_id, context=event.context)

    hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states)

    # Already started: the start-up listener below would never fire.
    if hass.is_running:
        return

    @callback
    def _write_unavailable_states(_: Event) -> None:
        """Make sure state machine contains entry for each registered entity."""
        state_machine = hass.states
        already_present = set(state_machine.async_entity_ids())

        for reg_entry in registry.entities.values():
            if reg_entry.disabled or reg_entry.entity_id in already_present:
                continue

            attributes: Dict[str, Any] = {ATTR_RESTORED: True}
            if reg_entry.capabilities is not None:
                attributes.update(reg_entry.capabilities)

            # Each optional attribute is only written when it is not None;
            # user overrides (name/icon) win over the integration's originals.
            optional_attributes = (
                (ATTR_SUPPORTED_FEATURES, reg_entry.supported_features),
                (ATTR_DEVICE_CLASS, reg_entry.device_class),
                (ATTR_UNIT_OF_MEASUREMENT, reg_entry.unit_of_measurement),
                (ATTR_FRIENDLY_NAME, reg_entry.name or reg_entry.original_name),
                (ATTR_ICON, reg_entry.icon or reg_entry.original_icon),
            )
            for attr_key, attr_value in optional_attributes:
                if attr_value is not None:
                    attributes[attr_key] = attr_value

            state_machine.async_set(reg_entry.entity_id, STATE_UNAVAILABLE, attributes)

    hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states)
async def async_migrate_entries(
    hass: HomeAssistantType,
    config_entry_id: str,
    entry_callback: Callable[[RegistryEntry], Optional[dict]],
) -> None:
    """Migrate the registry entries of a config entry via a callback.

    ``entry_callback`` is called once for every registry entry whose
    ``config_entry_id`` matches. If it returns a dict, the dict is passed as
    keyword arguments to ``EntityRegistry.async_update_entity``, for example
    ``{"new_unique_id": ...}`` or ``{"new_entity_id": ...}``. Returning
    ``None`` leaves the entry untouched.
    """
    ent_reg = await async_get_registry(hass)

    # Snapshot the matching entries before calling back. Renaming an entity
    # pops and re-inserts its key in ``ent_reg.entities``. Doing that while
    # iterating the live dict can raise RuntimeError or hand the callback the
    # renamed entry a second time.
    entries = [
        entry
        for entry in ent_reg.entities.values()
        if entry.config_entry_id == config_entry_id
    ]

    for entry in entries:
        updates = entry_callback(entry)

        if updates is not None:
            ent_reg.async_update_entity(entry.entity_id, **updates)