From 6d431c3fc3f6a8c9bf73c73396ddd0c47dd7aeff Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 24 Feb 2018 10:53:59 -0800 Subject: [PATCH] Allow renaming entities in entity registry (#12636) * Allow renaming entities in entity registry * Lint --- homeassistant/components/config/__init__.py | 3 +- .../components/config/entity_registry.py | 55 +++++++ homeassistant/components/light/demo.py | 16 ++- homeassistant/helpers/entity.py | 6 + homeassistant/helpers/entity_platform.py | 11 +- homeassistant/helpers/entity_registry.py | 71 +++++++++- tests/common.py | 29 +++- .../components/config/test_entity_registry.py | 134 ++++++++++++++++++ tests/helpers/test_entity_platform.py | 55 +++---- 9 files changed, 333 insertions(+), 47 deletions(-) create mode 100644 homeassistant/components/config/entity_registry.py create mode 100644 tests/components/config/test_entity_registry.py diff --git a/homeassistant/components/config/__init__.py b/homeassistant/components/config/__init__.py index 39c35205619..601b12ffe4a 100644 --- a/homeassistant/components/config/__init__.py +++ b/homeassistant/components/config/__init__.py @@ -13,7 +13,8 @@ from homeassistant.util.yaml import load_yaml, dump DOMAIN = 'config' DEPENDENCIES = ['http'] -SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script') +SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script', + 'entity_registry') ON_DEMAND = ('zwave',) FEATURE_FLAGS = ('config_entries',) diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py new file mode 100644 index 00000000000..4b9a2c89da0 --- /dev/null +++ b/homeassistant/components/config/entity_registry.py @@ -0,0 +1,55 @@ +"""HTTP views to interact with the entity registry.""" +import voluptuous as vol + +from homeassistant.core import callback +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.data_validator import RequestDataValidator +from homeassistant.helpers.entity_registry import async_get_registry + + +async def async_setup(hass): + """Enable the Entity Registry views.""" + hass.http.register_view(ConfigManagerEntityView) + return True + + +class ConfigManagerEntityView(HomeAssistantView): + """View to interact with an entity registry entry.""" + + url = '/api/config/entity_registry/{entity_id}' + name = 'api:config:entity_registry:entity' + + async def get(self, request, entity_id): + """Get the entity registry settings for an entity.""" + hass = request.app['hass'] + registry = await async_get_registry(hass) + entry = registry.entities.get(entity_id) + + if entry is None: + return self.json_message('Entry not found', 404) + + return self.json(_entry_dict(entry)) + + @RequestDataValidator(vol.Schema({ + # If passed in, we update value. Passing None will remove old value. + vol.Optional('name'): vol.Any(str, None), + })) + async def post(self, request, entity_id, data): + """Update the entity registry settings for an entity.""" + hass = request.app['hass'] + registry = await async_get_registry(hass) + + if entity_id not in registry.entities: + return self.json_message('Entry not found', 404) + + entry = registry.async_update_entity(entity_id, **data) + return self.json(_entry_dict(entry)) + + +@callback +def _entry_dict(entry): + """Helper to convert entry to API format.""" + return { + 'entity_id': entry.entity_id, + 'name': entry.name + } diff --git a/homeassistant/components/light/demo.py b/homeassistant/components/light/demo.py index d01611716eb..acc70a57ff4 100644 --- a/homeassistant/components/light/demo.py +++ b/homeassistant/components/light/demo.py @@ -28,11 +28,11 @@ SUPPORT_DEMO = (SUPPORT_BRIGHTNESS | SUPPORT_COLOR_TEMP | SUPPORT_EFFECT | def setup_platform(hass, config, add_devices_callback, discovery_info=None): """Set up the demo light platform.""" add_devices_callback([ - DemoLight("Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST, + DemoLight(1, "Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST, effect=LIGHT_EFFECT_LIST[0]), - DemoLight("Ceiling Lights", True, True, + DemoLight(2, "Ceiling Lights", True, True, LIGHT_COLORS[0], LIGHT_TEMPS[1]), - DemoLight("Kitchen Lights", True, True, + DemoLight(3, "Kitchen Lights", True, True, LIGHT_COLORS[1], LIGHT_TEMPS[0]) ]) @@ -40,10 +40,11 @@ def setup_platform(hass, config, add_devices_callback, discovery_info=None): class DemoLight(Light): """Representation of a demo light.""" - def __init__(self, name, state, available=False, rgb=None, ct=None, - brightness=180, xy_color=(.5, .5), white=200, + def __init__(self, unique_id, name, state, available=False, rgb=None, + ct=None, brightness=180, xy_color=(.5, .5), white=200, effect_list=None, effect=None): """Initialize the light.""" + self._unique_id = unique_id self._name = name self._state = state self._rgb = rgb @@ -64,6 +65,11 @@ class DemoLight(Light): """Return the name of the light if any.""" return self._name + @property + def unique_id(self): + """Return unique ID for light.""" + return self._unique_id + @property def available(self) -> bool: """Return availability.""" diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 04719e89187..9168c459f74 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -340,6 +340,12 @@ class Entity(object): else: self.hass.states.async_remove(self.entity_id) + @callback + def async_registry_updated(self, old, new): + """Called when the entity registry has been updated.""" + self.registry_name = new.name + self.async_schedule_update_ha_state() + def __eq__(self, other): """Return the comparison.""" if not isinstance(other, self.__class__): diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index e17e178bcfb..f627ccd24b1 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -10,12 +10,11 @@ from homeassistant.util.async import ( import homeassistant.util.dt as dt_util from .event import async_track_time_interval, async_track_point_in_time -from .entity_registry import EntityRegistry +from .entity_registry import async_get_registry SLOW_SETUP_WARNING = 10 SLOW_SETUP_MAX_WAIT = 60 PLATFORM_NOT_READY_RETRIES = 10 -DATA_REGISTRY = 'entity_registry' class EntityPlatform(object): @@ -156,12 +155,7 @@ class EntityPlatform(object): hass = self.hass component_entities = set(hass.states.async_entity_ids(self.domain)) - registry = hass.data.get(DATA_REGISTRY) - - if registry is None: - registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) - - yield from registry.async_ensure_loaded() + registry = yield from async_get_registry(hass) tasks = [ self._async_add_entity(entity, update_before_add, @@ -226,6 +220,7 @@ class EntityPlatform(object): entity.entity_id = entry.entity_id entity.registry_name = entry.name + entry.add_update_listener(entity) # We won't generate an entity ID if the platform has already set one # We will however make sure that platform cannot pick a registered ID diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 89719b0b823..c6eafa91335 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -15,17 +15,20 @@ from collections import OrderedDict from itertools import chain import logging import os +import weakref import attr from ..core import callback, split_entity_id +from ..loader import bind_hass from ..util import ensure_unique_string, slugify from ..util.yaml import load_yaml, save_yaml PATH_REGISTRY = 'entity_registry.yaml' +DATA_REGISTRY = 'entity_registry' SAVE_DELAY = 10 _LOGGER = logging.getLogger(__name__) - +_UNDEF = object() DISABLED_HASS = 'hass' DISABLED_USER = 'user' @@ -34,6 +37,8 @@ DISABLED_USER = 'user' class RegistryEntry: """Entity Registry Entry.""" + # pylint: disable=no-member + entity_id = attr.ib(type=str) unique_id = attr.ib(type=str) platform = attr.ib(type=str) @@ -41,17 +46,27 @@ class RegistryEntry: disabled_by = attr.ib( type=str, default=None, validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None))) - domain = attr.ib(type=str, default=None, init=False, repr=False) + update_listeners = attr.ib(type=list, default=attr.Factory(list), + repr=False) + domain = attr.ib(type=str, init=False, repr=False) - def __attrs_post_init__(self): - """Computed properties.""" - object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0]) + @domain.default + def _domain_default(self): + """Compute domain value.""" + return split_entity_id(self.entity_id)[0] @property def disabled(self): """Return if entry is disabled.""" return self.disabled_by is not None + def add_update_listener(self, listener): + """Listen for when entry is updated. + + Listener: Callback function(old_entry, new_entry) + """ + self.update_listeners.append(weakref.ref(listener)) + class EntityRegistry: """Class to hold a registry of entities.""" @@ -102,6 +117,39 @@ class EntityRegistry: self.async_schedule_save() return entity + @callback + def async_update_entity(self, entity_id, *, name=_UNDEF): + """Update properties of an entity.""" + old = self.entities[entity_id] + + changes = {} + + if name is not _UNDEF and name != old.name: + changes['name'] = name + + if not changes: + return old + + new = self.entities[entity_id] = attr.evolve(old, **changes) + + to_remove = [] + for listener_ref in new.update_listeners: + listener = listener_ref() + if listener is None: + to_remove.append(listener) + else: + try: + listener.async_registry_updated(old, new) + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error calling update listener') + + for ref in to_remove: + new.update_listeners.remove(ref) + + self.async_schedule_save() + + return new + @asyncio.coroutine def async_ensure_loaded(self): """Load the registry from disk.""" @@ -154,7 +202,20 @@ class EntityRegistry: data[entry.entity_id] = { 'unique_id': entry.unique_id, 'platform': entry.platform, + 'name': entry.name, } yield from self.hass.async_add_job( save_yaml, self.hass.config.path(PATH_REGISTRY), data) + + +@bind_hass +async def async_get_registry(hass) -> EntityRegistry: + """Return entity registry instance.""" + registry = hass.data.get(DATA_REGISTRY) + + if registry is None: + registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) + + await registry.async_ensure_loaded() + return registry diff --git a/tests/common.py b/tests/common.py index 6fee7b1bec0..15ce80a9552 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,5 +1,6 @@ """Test the helper method for writing tests.""" import asyncio +from datetime import timedelta import functools as ft import os import sys @@ -298,7 +299,7 @@ def mock_registry(hass, mock_entries=None): """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) registry.entities = mock_entries or {} - hass.data[entity_platform.DATA_REGISTRY] = registry + hass.data[entity_registry.DATA_REGISTRY] = registry return registry @@ -361,6 +362,32 @@ class MockPlatform(object): self.async_setup_platform = mock_coro_func() +class MockEntityPlatform(entity_platform.EntityPlatform): + """Mock class with some mock defaults.""" + + def __init__( + self, hass, + logger=None, + domain='test_domain', + platform_name='test_platform', + scan_interval=timedelta(seconds=15), + parallel_updates=0, + entity_namespace=None, + async_entities_added_callback=lambda: None + ): + """Initialize a mock entity platform.""" + super().__init__( + hass=hass, + logger=logger, + domain=domain, + platform_name=platform_name, + scan_interval=scan_interval, + parallel_updates=parallel_updates, + entity_namespace=entity_namespace, + async_entities_added_callback=async_entities_added_callback, + ) + + class MockToggleDevice(entity.ToggleEntity): """Provide a mock toggle device.""" diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py new file mode 100644 index 00000000000..aa7a5ce5f0e --- /dev/null +++ b/tests/components/config/test_entity_registry.py @@ -0,0 +1,134 @@ +"""Test entity_registry API.""" +import pytest + +from homeassistant.setup import async_setup_component +from homeassistant.helpers.entity_registry import RegistryEntry +from homeassistant.components.config import entity_registry +from tests.common import mock_registry, MockEntity, MockEntityPlatform + + +@pytest.fixture +def client(hass, test_client): + """Fixture that can interact with the config manager API.""" + hass.loop.run_until_complete(async_setup_component(hass, 'http', {})) + hass.loop.run_until_complete(entity_registry.async_setup(hass)) + yield hass.loop.run_until_complete(test_client(hass.http.app)) + + +async def test_get_entity(hass, client): + """Test get entry.""" + mock_registry(hass, { + 'test_domain.name': RegistryEntry( + entity_id='test_domain.name', + unique_id='1234', + platform='test_platform', + name='Hello World' + ), + 'test_domain.no_name': RegistryEntry( + entity_id='test_domain.no_name', + unique_id='6789', + platform='test_platform', + ), + }) + + resp = await client.get( + '/api/config/entity_registry/test_domain.name') + assert resp.status == 200 + data = await resp.json() + assert data == { + 'entity_id': 'test_domain.name', + 'name': 'Hello World' + } + + resp = await client.get( + '/api/config/entity_registry/test_domain.no_name') + assert resp.status == 200 + data = await resp.json() + assert data == { + 'entity_id': 'test_domain.no_name', + 'name': None + } + + +async def test_update_entity(hass, client): + """Test get entry.""" + mock_registry(hass, { + 'test_domain.world': RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + name='before update' + ) + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + state = hass.states.get('test_domain.world') + assert state is not None + assert state.name == 'before update' + + resp = await client.post( + '/api/config/entity_registry/test_domain.world', json={ + 'name': 'after update' + }) + assert resp.status == 200 + data = await resp.json() + assert data == { + 'entity_id': 'test_domain.world', + 'name': 'after update' + } + + state = hass.states.get('test_domain.world') + assert state.name == 'after update' + + +async def test_update_entity_no_changes(hass, client): + """Test get entry.""" + mock_registry(hass, { + 'test_domain.world': RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + name='name of entity' + ) + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + state = hass.states.get('test_domain.world') + assert state is not None + assert state.name == 'name of entity' + + resp = await client.post( + '/api/config/entity_registry/test_domain.world', json={ + 'name': 'name of entity' + }) + assert resp.status == 200 + data = await resp.json() + assert data == { + 'entity_id': 'test_domain.world', + 'name': 'name of entity' + } + + state = hass.states.get('test_domain.world') + assert state.name == 'name of entity' + + +async def test_get_nonexisting_entity(client): + """Test get entry.""" + resp = await client.get( + '/api/config/entity_registry/test_domain.non_existing') + assert resp.status == 404 + + +async def test_update_nonexisting_entity(client): + """Test get entry.""" + resp = await client.post( + '/api/config/entity_registry/test_domain.non_existing', json={ + 'name': 'some name' + }) + assert resp.status == 404 diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 0681691ed67..8c085e4abb1 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -15,39 +15,13 @@ import homeassistant.util.dt as dt_util from tests.common import ( get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry, - MockEntity) + MockEntity, MockEntityPlatform) _LOGGER = logging.getLogger(__name__) DOMAIN = "test_domain" PLATFORM = 'test_platform' -class MockEntityPlatform(entity_platform.EntityPlatform): - """Mock class with some mock defaults.""" - - def __init__( - self, hass, - logger=None, - domain=DOMAIN, - platform_name=PLATFORM, - scan_interval=timedelta(seconds=15), - parallel_updates=0, - entity_namespace=None, - async_entities_added_callback=lambda: None - ): - """Initialize a mock entity platform.""" - super().__init__( - hass=hass, - logger=logger, - domain=domain, - platform_name=platform_name, - scan_interval=scan_interval, - parallel_updates=parallel_updates, - entity_namespace=entity_namespace, - async_entities_added_callback=async_entities_added_callback, - ) - - class TestHelpersEntityPlatform(unittest.TestCase): """Test homeassistant.helpers.entity_component module.""" @@ -510,3 +484,30 @@ def test_registry_respect_entity_disabled(hass): yield from platform.async_add_entities([entity]) assert entity.entity_id is None assert hass.states.async_entity_ids() == [] + + +async def test_entity_registry_updates(hass): + """Test that updates on the entity registry update platform entities.""" + registry = mock_registry(hass, { + 'test_domain.world': entity_registry.RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + name='before update' + ) + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + state = hass.states.get('test_domain.world') + assert state is not None + assert state.name == 'before update' + + registry.async_update_entity('test_domain.world', name='after update') + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get('test_domain.world') + assert state.name == 'after update'