Allow renaming entities in entity registry (#12636)

* Allow renaming entities in entity registry

* Lint
pull/12506/merge
Paulus Schoutsen 2018-02-24 10:53:59 -08:00 committed by GitHub
parent 2821820281
commit 6d431c3fc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 333 additions and 47 deletions

View File

@ -13,7 +13,8 @@ from homeassistant.util.yaml import load_yaml, dump
DOMAIN = 'config'
DEPENDENCIES = ['http']
SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script')
SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script',
'entity_registry')
ON_DEMAND = ('zwave',)
FEATURE_FLAGS = ('config_entries',)

View File

@ -0,0 +1,55 @@
"""HTTP views to interact with the entity registry."""
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.helpers.entity_registry import async_get_registry
async def async_setup(hass):
"""Enable the Entity Registry views."""
hass.http.register_view(ConfigManagerEntityView)
return True
class ConfigManagerEntityView(HomeAssistantView):
"""View to interact with an entity registry entry."""
url = '/api/config/entity_registry/{entity_id}'
name = 'api:config:entity_registry:entity'
async def get(self, request, entity_id):
"""Get the entity registry settings for an entity."""
hass = request.app['hass']
registry = await async_get_registry(hass)
entry = registry.entities.get(entity_id)
if entry is None:
return self.json_message('Entry not found', 404)
return self.json(_entry_dict(entry))
@RequestDataValidator(vol.Schema({
# If passed in, we update value. Passing None will remove old value.
vol.Optional('name'): vol.Any(str, None),
}))
async def post(self, request, entity_id, data):
"""Update the entity registry settings for an entity."""
hass = request.app['hass']
registry = await async_get_registry(hass)
if entity_id not in registry.entities:
return self.json_message('Entry not found', 404)
entry = registry.async_update_entity(entity_id, **data)
return self.json(_entry_dict(entry))
@callback
def _entry_dict(entry):
"""Helper to convert entry to API format."""
return {
'entity_id': entry.entity_id,
'name': entry.name
}

View File

@ -28,11 +28,11 @@ SUPPORT_DEMO = (SUPPORT_BRIGHTNESS | SUPPORT_COLOR_TEMP | SUPPORT_EFFECT |
def setup_platform(hass, config, add_devices_callback, discovery_info=None):
"""Set up the demo light platform."""
add_devices_callback([
DemoLight("Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST,
DemoLight(1, "Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST,
effect=LIGHT_EFFECT_LIST[0]),
DemoLight("Ceiling Lights", True, True,
DemoLight(2, "Ceiling Lights", True, True,
LIGHT_COLORS[0], LIGHT_TEMPS[1]),
DemoLight("Kitchen Lights", True, True,
DemoLight(3, "Kitchen Lights", True, True,
LIGHT_COLORS[1], LIGHT_TEMPS[0])
])
@ -40,10 +40,11 @@ def setup_platform(hass, config, add_devices_callback, discovery_info=None):
class DemoLight(Light):
"""Representation of a demo light."""
def __init__(self, name, state, available=False, rgb=None, ct=None,
brightness=180, xy_color=(.5, .5), white=200,
def __init__(self, unique_id, name, state, available=False, rgb=None,
ct=None, brightness=180, xy_color=(.5, .5), white=200,
effect_list=None, effect=None):
"""Initialize the light."""
self._unique_id = unique_id
self._name = name
self._state = state
self._rgb = rgb
@ -64,6 +65,11 @@ class DemoLight(Light):
"""Return the name of the light if any."""
return self._name
@property
def unique_id(self):
"""Return unique ID for light."""
return self._unique_id
@property
def available(self) -> bool:
"""Return availability."""

View File

@ -340,6 +340,12 @@ class Entity(object):
else:
self.hass.states.async_remove(self.entity_id)
@callback
def async_registry_updated(self, old, new):
"""Called when the entity registry has been updated."""
self.registry_name = new.name
self.async_schedule_update_ha_state()
def __eq__(self, other):
"""Return the comparison."""
if not isinstance(other, self.__class__):

View File

@ -10,12 +10,11 @@ from homeassistant.util.async import (
import homeassistant.util.dt as dt_util
from .event import async_track_time_interval, async_track_point_in_time
from .entity_registry import EntityRegistry
from .entity_registry import async_get_registry
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10
DATA_REGISTRY = 'entity_registry'
class EntityPlatform(object):
@ -156,12 +155,7 @@ class EntityPlatform(object):
hass = self.hass
component_entities = set(hass.states.async_entity_ids(self.domain))
registry = hass.data.get(DATA_REGISTRY)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
yield from registry.async_ensure_loaded()
registry = yield from async_get_registry(hass)
tasks = [
self._async_add_entity(entity, update_before_add,
@ -226,6 +220,7 @@ class EntityPlatform(object):
entity.entity_id = entry.entity_id
entity.registry_name = entry.name
entry.add_update_listener(entity)
# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID

View File

@ -15,17 +15,20 @@ from collections import OrderedDict
from itertools import chain
import logging
import os
import weakref
import attr
from ..core import callback, split_entity_id
from ..loader import bind_hass
from ..util import ensure_unique_string, slugify
from ..util.yaml import load_yaml, save_yaml
PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry'
SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DISABLED_HASS = 'hass'
DISABLED_USER = 'user'
@ -34,6 +37,8 @@ DISABLED_USER = 'user'
class RegistryEntry:
"""Entity Registry Entry."""
# pylint: disable=no-member
entity_id = attr.ib(type=str)
unique_id = attr.ib(type=str)
platform = attr.ib(type=str)
@ -41,17 +46,27 @@ class RegistryEntry:
disabled_by = attr.ib(
type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
domain = attr.ib(type=str, default=None, init=False, repr=False)
update_listeners = attr.ib(type=list, default=attr.Factory(list),
repr=False)
domain = attr.ib(type=str, init=False, repr=False)
def __attrs_post_init__(self):
"""Computed properties."""
object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0])
@domain.default
def _domain_default(self):
"""Compute domain value."""
return split_entity_id(self.entity_id)[0]
@property
def disabled(self):
"""Return if entry is disabled."""
return self.disabled_by is not None
def add_update_listener(self, listener):
"""Listen for when entry is updated.
Listener: Callback function(old_entry, new_entry)
"""
self.update_listeners.append(weakref.ref(listener))
class EntityRegistry:
"""Class to hold a registry of entities."""
@ -102,6 +117,39 @@ class EntityRegistry:
self.async_schedule_save()
return entity
@callback
def async_update_entity(self, entity_id, *, name=_UNDEF):
"""Update properties of an entity."""
old = self.entities[entity_id]
changes = {}
if name is not _UNDEF and name != old.name:
changes['name'] = name
if not changes:
return old
new = self.entities[entity_id] = attr.evolve(old, **changes)
to_remove = []
for listener_ref in new.update_listeners:
listener = listener_ref()
if listener is None:
to_remove.append(listener)
else:
try:
listener.async_registry_updated(old, new)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error calling update listener')
for ref in to_remove:
new.update_listeners.remove(ref)
self.async_schedule_save()
return new
@asyncio.coroutine
def async_ensure_loaded(self):
"""Load the registry from disk."""
@ -154,7 +202,20 @@ class EntityRegistry:
data[entry.entity_id] = {
'unique_id': entry.unique_id,
'platform': entry.platform,
'name': entry.name,
}
yield from self.hass.async_add_job(
save_yaml, self.hass.config.path(PATH_REGISTRY), data)
@bind_hass
async def async_get_registry(hass) -> EntityRegistry:
"""Return entity registry instance."""
registry = hass.data.get(DATA_REGISTRY)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
await registry.async_ensure_loaded()
return registry

View File

@ -1,5 +1,6 @@
"""Test the helper method for writing tests."""
import asyncio
from datetime import timedelta
import functools as ft
import os
import sys
@ -298,7 +299,7 @@ def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or {}
hass.data[entity_platform.DATA_REGISTRY] = registry
hass.data[entity_registry.DATA_REGISTRY] = registry
return registry
@ -361,6 +362,32 @@ class MockPlatform(object):
self.async_setup_platform = mock_coro_func()
class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults."""
def __init__(
self, hass,
logger=None,
domain='test_domain',
platform_name='test_platform',
scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=lambda: None
):
"""Initialize a mock entity platform."""
super().__init__(
hass=hass,
logger=logger,
domain=domain,
platform_name=platform_name,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=async_entities_added_callback,
)
class MockToggleDevice(entity.ToggleEntity):
"""Provide a mock toggle device."""

View File

@ -0,0 +1,134 @@
"""Test entity_registry API."""
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.components.config import entity_registry
from tests.common import mock_registry, MockEntity, MockEntityPlatform
@pytest.fixture
def client(hass, test_client):
"""Fixture that can interact with the config manager API."""
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
hass.loop.run_until_complete(entity_registry.async_setup(hass))
yield hass.loop.run_until_complete(test_client(hass.http.app))
async def test_get_entity(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.name': RegistryEntry(
entity_id='test_domain.name',
unique_id='1234',
platform='test_platform',
name='Hello World'
),
'test_domain.no_name': RegistryEntry(
entity_id='test_domain.no_name',
unique_id='6789',
platform='test_platform',
),
})
resp = await client.get(
'/api/config/entity_registry/test_domain.name')
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.name',
'name': 'Hello World'
}
resp = await client.get(
'/api/config/entity_registry/test_domain.no_name')
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.no_name',
'name': None
}
async def test_update_entity(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='before update'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'before update'
resp = await client.post(
'/api/config/entity_registry/test_domain.world', json={
'name': 'after update'
})
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.world',
'name': 'after update'
}
state = hass.states.get('test_domain.world')
assert state.name == 'after update'
async def test_update_entity_no_changes(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='name of entity'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'name of entity'
resp = await client.post(
'/api/config/entity_registry/test_domain.world', json={
'name': 'name of entity'
})
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.world',
'name': 'name of entity'
}
state = hass.states.get('test_domain.world')
assert state.name == 'name of entity'
async def test_get_nonexisting_entity(client):
"""Test get entry."""
resp = await client.get(
'/api/config/entity_registry/test_domain.non_existing')
assert resp.status == 404
async def test_update_nonexisting_entity(client):
"""Test get entry."""
resp = await client.post(
'/api/config/entity_registry/test_domain.non_existing', json={
'name': 'some name'
})
assert resp.status == 404

View File

@ -15,39 +15,13 @@ import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity)
MockEntity, MockEntityPlatform)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
PLATFORM = 'test_platform'
class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults."""
def __init__(
self, hass,
logger=None,
domain=DOMAIN,
platform_name=PLATFORM,
scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=lambda: None
):
"""Initialize a mock entity platform."""
super().__init__(
hass=hass,
logger=logger,
domain=domain,
platform_name=platform_name,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=async_entities_added_callback,
)
class TestHelpersEntityPlatform(unittest.TestCase):
"""Test homeassistant.helpers.entity_component module."""
@ -510,3 +484,30 @@ def test_registry_respect_entity_disabled(hass):
yield from platform.async_add_entities([entity])
assert entity.entity_id is None
assert hass.states.async_entity_ids() == []
async def test_entity_registry_updates(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='before update'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'before update'
registry.async_update_entity('test_domain.world', name='after update')
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get('test_domain.world')
assert state.name == 'after update'