281 lines
9.2 KiB
Python
281 lines
9.2 KiB
Python
"""Provide a registry to track entity IDs.
|
|
|
|
The Entity Registry keeps a registry of entities. Entities are uniquely
|
|
identified by their domain, platform and a unique id provided by that platform.
|
|
|
|
The Entity Registry will persist itself 10 seconds after a new entity is
|
|
registered. Registering a new entity while a timer is in progress resets the
|
|
timer.
|
|
"""
|
|
from collections import OrderedDict
|
|
from itertools import chain
|
|
import logging
|
|
import weakref
|
|
|
|
import attr
|
|
|
|
from homeassistant.core import callback, split_entity_id, valid_entity_id
|
|
from homeassistant.loader import bind_hass
|
|
from homeassistant.util import ensure_unique_string, slugify
|
|
from homeassistant.util.yaml import load_yaml
|
|
|
|
PATH_REGISTRY = 'entity_registry.yaml'
|
|
DATA_REGISTRY = 'entity_registry'
|
|
SAVE_DELAY = 10
|
|
_LOGGER = logging.getLogger(__name__)
|
|
_UNDEF = object()
|
|
DISABLED_HASS = 'hass'
|
|
DISABLED_USER = 'user'
|
|
|
|
STORAGE_VERSION = 1
|
|
STORAGE_KEY = 'core.entity_registry'
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
|
|
class RegistryEntry:
|
|
"""Entity Registry Entry."""
|
|
|
|
entity_id = attr.ib(type=str)
|
|
unique_id = attr.ib(type=str)
|
|
platform = attr.ib(type=str)
|
|
name = attr.ib(type=str, default=None)
|
|
device_id = attr.ib(type=str, default=None)
|
|
config_entry_id = attr.ib(type=str, default=None)
|
|
disabled_by = attr.ib(
|
|
type=str, default=None,
|
|
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
|
|
update_listeners = attr.ib(type=list, default=attr.Factory(list),
|
|
repr=False)
|
|
domain = attr.ib(type=str, init=False, repr=False)
|
|
|
|
@domain.default
|
|
def _domain_default(self):
|
|
"""Compute domain value."""
|
|
return split_entity_id(self.entity_id)[0]
|
|
|
|
@property
|
|
def disabled(self):
|
|
"""Return if entry is disabled."""
|
|
return self.disabled_by is not None
|
|
|
|
def add_update_listener(self, listener):
|
|
"""Listen for when entry is updated.
|
|
|
|
Listener: Callback function(old_entry, new_entry)
|
|
|
|
Returns function to unlisten.
|
|
"""
|
|
weak_listener = weakref.ref(listener)
|
|
self.update_listeners.append(weak_listener)
|
|
|
|
return lambda: self.update_listeners.remove(weak_listener)
|
|
|
|
|
|
class EntityRegistry:
|
|
"""Class to hold a registry of entities."""
|
|
|
|
def __init__(self, hass):
|
|
"""Initialize the registry."""
|
|
self.hass = hass
|
|
self.entities = None
|
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
|
|
|
@callback
|
|
def async_is_registered(self, entity_id):
|
|
"""Check if an entity_id is currently registered."""
|
|
return entity_id in self.entities
|
|
|
|
@callback
|
|
def async_get_entity_id(self, domain: str, platform: str, unique_id: str):
|
|
"""Check if an entity_id is currently registered."""
|
|
for entity in self.entities.values():
|
|
if entity.domain == domain and entity.platform == platform and \
|
|
entity.unique_id == unique_id:
|
|
return entity.entity_id
|
|
return None
|
|
|
|
@callback
|
|
def async_generate_entity_id(self, domain, suggested_object_id):
|
|
"""Generate an entity ID that does not conflict.
|
|
|
|
Conflicts checked against registered and currently existing entities.
|
|
"""
|
|
return ensure_unique_string(
|
|
'{}.{}'.format(domain, slugify(suggested_object_id)),
|
|
chain(self.entities.keys(),
|
|
self.hass.states.async_entity_ids(domain))
|
|
)
|
|
|
|
@callback
|
|
def async_get_or_create(self, domain, platform, unique_id, *,
|
|
suggested_object_id=None, config_entry_id=None,
|
|
device_id=None):
|
|
"""Get entity. Create if it doesn't exist."""
|
|
entity_id = self.async_get_entity_id(domain, platform, unique_id)
|
|
if entity_id:
|
|
return self._async_update_entity(
|
|
entity_id, config_entry_id=config_entry_id,
|
|
device_id=device_id)
|
|
|
|
entity_id = self.async_generate_entity_id(
|
|
domain, suggested_object_id or '{}_{}'.format(platform, unique_id))
|
|
|
|
entity = RegistryEntry(
|
|
entity_id=entity_id,
|
|
config_entry_id=config_entry_id,
|
|
device_id=device_id,
|
|
unique_id=unique_id,
|
|
platform=platform,
|
|
)
|
|
self.entities[entity_id] = entity
|
|
_LOGGER.info('Registered new %s.%s entity: %s',
|
|
domain, platform, entity_id)
|
|
self.async_schedule_save()
|
|
return entity
|
|
|
|
@callback
|
|
def async_update_entity(self, entity_id, *, name=_UNDEF,
|
|
new_entity_id=_UNDEF):
|
|
"""Update properties of an entity."""
|
|
return self._async_update_entity(
|
|
entity_id,
|
|
name=name,
|
|
new_entity_id=new_entity_id
|
|
)
|
|
|
|
@callback
|
|
def _async_update_entity(self, entity_id, *, name=_UNDEF,
|
|
config_entry_id=_UNDEF, new_entity_id=_UNDEF,
|
|
device_id=_UNDEF):
|
|
"""Private facing update properties method."""
|
|
old = self.entities[entity_id]
|
|
|
|
changes = {}
|
|
|
|
if name is not _UNDEF and name != old.name:
|
|
changes['name'] = name
|
|
|
|
if (config_entry_id is not _UNDEF and
|
|
config_entry_id != old.config_entry_id):
|
|
changes['config_entry_id'] = config_entry_id
|
|
|
|
if (device_id is not _UNDEF and device_id != old.device_id):
|
|
changes['device_id'] = device_id
|
|
|
|
if new_entity_id is not _UNDEF and new_entity_id != old.entity_id:
|
|
if self.async_is_registered(new_entity_id):
|
|
raise ValueError('Entity is already registered')
|
|
|
|
if not valid_entity_id(new_entity_id):
|
|
raise ValueError('Invalid entity ID')
|
|
|
|
if (split_entity_id(new_entity_id)[0] !=
|
|
split_entity_id(entity_id)[0]):
|
|
raise ValueError('New entity ID should be same domain')
|
|
|
|
self.entities.pop(entity_id)
|
|
entity_id = changes['entity_id'] = new_entity_id
|
|
|
|
if not changes:
|
|
return old
|
|
|
|
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
|
|
|
to_remove = []
|
|
for listener_ref in new.update_listeners:
|
|
listener = listener_ref()
|
|
if listener is None:
|
|
to_remove.append(listener_ref)
|
|
else:
|
|
try:
|
|
listener.async_registry_updated(old, new)
|
|
except Exception: # pylint: disable=broad-except
|
|
_LOGGER.exception('Error calling update listener')
|
|
|
|
for ref in to_remove:
|
|
new.update_listeners.remove(ref)
|
|
|
|
self.async_schedule_save()
|
|
|
|
return new
|
|
|
|
async def async_load(self):
|
|
"""Load the entity registry."""
|
|
data = await self.hass.helpers.storage.async_migrator(
|
|
self.hass.config.path(PATH_REGISTRY), self._store,
|
|
old_conf_load_func=load_yaml,
|
|
old_conf_migrate_func=_async_migrate
|
|
)
|
|
entities = OrderedDict()
|
|
|
|
if data is not None:
|
|
for entity in data['entities']:
|
|
entities[entity['entity_id']] = RegistryEntry(
|
|
entity_id=entity['entity_id'],
|
|
config_entry_id=entity.get('config_entry_id'),
|
|
device_id=entity.get('device_id'),
|
|
unique_id=entity['unique_id'],
|
|
platform=entity['platform'],
|
|
name=entity.get('name'),
|
|
disabled_by=entity.get('disabled_by')
|
|
)
|
|
|
|
self.entities = entities
|
|
|
|
@callback
|
|
def async_schedule_save(self):
|
|
"""Schedule saving the entity registry."""
|
|
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
|
|
|
@callback
|
|
def _data_to_save(self):
|
|
"""Return data of entity registry to store in a file."""
|
|
data = {}
|
|
|
|
data['entities'] = [
|
|
{
|
|
'entity_id': entry.entity_id,
|
|
'config_entry_id': entry.config_entry_id,
|
|
'device_id': entry.device_id,
|
|
'unique_id': entry.unique_id,
|
|
'platform': entry.platform,
|
|
'name': entry.name,
|
|
'disabled_by': entry.disabled_by,
|
|
} for entry in self.entities.values()
|
|
]
|
|
|
|
return data
|
|
|
|
@callback
|
|
def async_clear_config_entry(self, config_entry):
|
|
"""Clear config entry from registry entries."""
|
|
for entity_id, entry in self.entities.items():
|
|
if config_entry == entry.config_entry_id:
|
|
self._async_update_entity(entity_id, config_entry_id=None)
|
|
|
|
|
|
@bind_hass
|
|
async def async_get_registry(hass) -> EntityRegistry:
|
|
"""Return entity registry instance."""
|
|
task = hass.data.get(DATA_REGISTRY)
|
|
|
|
if task is None:
|
|
async def _load_reg():
|
|
registry = EntityRegistry(hass)
|
|
await registry.async_load()
|
|
return registry
|
|
|
|
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
|
|
|
|
return await task
|
|
|
|
|
|
async def _async_migrate(entities):
|
|
"""Migrate the YAML config file to storage helper format."""
|
|
return {
|
|
'entities': [
|
|
{'entity_id': entity_id, **info}
|
|
for entity_id, info in entities.items()
|
|
]
|
|
}
|