"""Provide a registry to track entity IDs. The Entity Registry keeps a registry of entities. Entities are uniquely identified by their domain, platform and a unique id provided by that platform. The Entity Registry will persist itself 10 seconds after a new entity is registered. Registering a new entity while a timer is in progress resets the timer. """ from asyncio import Event from collections import OrderedDict from itertools import chain import logging from typing import List, Optional, cast import weakref import attr from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.loader import bind_hass from homeassistant.util import ensure_unique_string, slugify from homeassistant.util.yaml import load_yaml from .typing import HomeAssistantType PATH_REGISTRY = 'entity_registry.yaml' DATA_REGISTRY = 'entity_registry' EVENT_ENTITY_REGISTRY_UPDATED = 'entity_registry_updated' SAVE_DELAY = 10 _LOGGER = logging.getLogger(__name__) _UNDEF = object() DISABLED_HASS = 'hass' DISABLED_USER = 'user' STORAGE_VERSION = 1 STORAGE_KEY = 'core.entity_registry' @attr.s(slots=True, frozen=True) class RegistryEntry: """Entity Registry Entry.""" entity_id = attr.ib(type=str) unique_id = attr.ib(type=str) platform = attr.ib(type=str) name = attr.ib(type=str, default=None) device_id = attr.ib(type=str, default=None) config_entry_id = attr.ib(type=str, default=None) disabled_by = attr.ib( type=str, default=None, validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None))) update_listeners = attr.ib(type=list, default=attr.Factory(list), repr=False) domain = attr.ib(type=str, init=False, repr=False) @domain.default def _domain_default(self): """Compute domain value.""" return split_entity_id(self.entity_id)[0] @property def disabled(self): """Return if entry is disabled.""" return self.disabled_by is not None def add_update_listener(self, listener): """Listen for when entry is updated. Listener: Callback function(old_entry, new_entry) Returns function to unlisten. """ weak_listener = weakref.ref(listener) self.update_listeners.append(weak_listener) return lambda: self.update_listeners.remove(weak_listener) class EntityRegistry: """Class to hold a registry of entities.""" def __init__(self, hass): """Initialize the registry.""" self.hass = hass self.entities = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback def async_is_registered(self, entity_id): """Check if an entity_id is currently registered.""" return entity_id in self.entities @callback def async_get(self, entity_id: str) -> Optional[RegistryEntry]: """Get EntityEntry for an entity_id.""" return self.entities.get(entity_id) @callback def async_get_entity_id(self, domain: str, platform: str, unique_id: str): """Check if an entity_id is currently registered.""" for entity in self.entities.values(): if entity.domain == domain and entity.platform == platform and \ entity.unique_id == unique_id: return entity.entity_id return None @callback def async_generate_entity_id(self, domain, suggested_object_id, known_object_ids=None): """Generate an entity ID that does not conflict. Conflicts checked against registered and currently existing entities. """ return ensure_unique_string( '{}.{}'.format(domain, slugify(suggested_object_id)), chain(self.entities.keys(), self.hass.states.async_entity_ids(domain), known_object_ids if known_object_ids else []) ) @callback def async_get_or_create(self, domain, platform, unique_id, *, suggested_object_id=None, config_entry_id=None, device_id=None, known_object_ids=None): """Get entity. Create if it doesn't exist.""" entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: return self._async_update_entity( entity_id, config_entry_id=config_entry_id, device_id=device_id, # When we changed our slugify algorithm, we invalidated some # stored entity IDs with either a __ or ending in _. # Fix introduced in 0.86 (Jan 23, 2019). Next line can be # removed when we release 1.0 or in 2020. new_entity_id='.'.join(slugify(part) for part in entity_id.split('.', 1))) entity_id = self.async_generate_entity_id( domain, suggested_object_id or '{}_{}'.format(platform, unique_id), known_object_ids) entity = RegistryEntry( entity_id=entity_id, config_entry_id=config_entry_id, device_id=device_id, unique_id=unique_id, platform=platform, ) self.entities[entity_id] = entity _LOGGER.info('Registered new %s.%s entity: %s', domain, platform, entity_id) self.async_schedule_save() self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, { 'action': 'create', 'entity_id': entity_id }) return entity @callback def async_remove(self, entity_id): """Remove an entity from registry.""" self.entities.pop(entity_id) self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, { 'action': 'remove', 'entity_id': entity_id }) self.async_schedule_save() @callback def async_update_entity(self, entity_id, *, name=_UNDEF, new_entity_id=_UNDEF, new_unique_id=_UNDEF): """Update properties of an entity.""" return self._async_update_entity( entity_id, name=name, new_entity_id=new_entity_id, new_unique_id=new_unique_id ) @callback def _async_update_entity(self, entity_id, *, name=_UNDEF, config_entry_id=_UNDEF, new_entity_id=_UNDEF, device_id=_UNDEF, new_unique_id=_UNDEF): """Private facing update properties method.""" old = self.entities[entity_id] changes = {} if name is not _UNDEF and name != old.name: changes['name'] = name if (config_entry_id is not _UNDEF and config_entry_id != old.config_entry_id): changes['config_entry_id'] = config_entry_id if (device_id is not _UNDEF and device_id != old.device_id): changes['device_id'] = device_id if new_entity_id is not _UNDEF and new_entity_id != old.entity_id: if self.async_is_registered(new_entity_id): raise ValueError('Entity is already registered') if not valid_entity_id(new_entity_id): raise ValueError('Invalid entity ID') if (split_entity_id(new_entity_id)[0] != split_entity_id(entity_id)[0]): raise ValueError('New entity ID should be same domain') self.entities.pop(entity_id) entity_id = changes['entity_id'] = new_entity_id if new_unique_id is not _UNDEF: conflict = next((entity for entity in self.entities.values() if entity.unique_id == new_unique_id and entity.domain == old.domain and entity.platform == old.platform), None) if conflict: raise ValueError( "Unique id '{}' is already in use by '{}'".format( new_unique_id, conflict.entity_id)) changes['unique_id'] = new_unique_id if not changes: return old new = self.entities[entity_id] = attr.evolve(old, **changes) to_remove = [] for listener_ref in new.update_listeners: listener = listener_ref() if listener is None: to_remove.append(listener_ref) else: try: listener.async_registry_updated(old, new) except Exception: # pylint: disable=broad-except _LOGGER.exception('Error calling update listener') for ref in to_remove: new.update_listeners.remove(ref) self.async_schedule_save() self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, { 'action': 'update', 'entity_id': entity_id }) return new async def async_load(self): """Load the entity registry.""" data = await self.hass.helpers.storage.async_migrator( self.hass.config.path(PATH_REGISTRY), self._store, old_conf_load_func=load_yaml, old_conf_migrate_func=_async_migrate ) entities = OrderedDict() if data is not None: for entity in data['entities']: entities[entity['entity_id']] = RegistryEntry( entity_id=entity['entity_id'], config_entry_id=entity.get('config_entry_id'), device_id=entity.get('device_id'), unique_id=entity['unique_id'], platform=entity['platform'], name=entity.get('name'), disabled_by=entity.get('disabled_by') ) self.entities = entities @callback def async_schedule_save(self): """Schedule saving the entity registry.""" self._store.async_delay_save(self._data_to_save, SAVE_DELAY) @callback def _data_to_save(self): """Return data of entity registry to store in a file.""" data = {} data['entities'] = [ { 'entity_id': entry.entity_id, 'config_entry_id': entry.config_entry_id, 'device_id': entry.device_id, 'unique_id': entry.unique_id, 'platform': entry.platform, 'name': entry.name, 'disabled_by': entry.disabled_by, } for entry in self.entities.values() ] return data @callback def async_clear_config_entry(self, config_entry): """Clear config entry from registry entries.""" for entity_id in [ entity_id for entity_id, entry in self.entities.items() if config_entry == entry.config_entry_id]: self.async_remove(entity_id) @bind_hass async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: """Return entity registry instance.""" reg_or_evt = hass.data.get(DATA_REGISTRY) if not reg_or_evt: evt = hass.data[DATA_REGISTRY] = Event() reg = EntityRegistry(hass) await reg.async_load() hass.data[DATA_REGISTRY] = reg evt.set() return reg if isinstance(reg_or_evt, Event): evt = reg_or_evt await evt.wait() return cast(EntityRegistry, hass.data.get(DATA_REGISTRY)) return cast(EntityRegistry, reg_or_evt) @callback def async_entries_for_device(registry: EntityRegistry, device_id: str) \ -> List[RegistryEntry]: """Return entries that match a device.""" return [entry for entry in registry.entities.values() if entry.device_id == device_id] async def _async_migrate(entities): """Migrate the YAML config file to storage helper format.""" return { 'entities': [ {'entity_id': entity_id, **info} for entity_id, info in entities.items() ] }