Allow renaming entities in entity registry (#12636)
* Allow renaming entities in entity registry * Lintpull/12506/merge
parent
2821820281
commit
6d431c3fc3
|
@ -13,7 +13,8 @@ from homeassistant.util.yaml import load_yaml, dump
|
|||
|
||||
DOMAIN = 'config'
|
||||
DEPENDENCIES = ['http']
|
||||
SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script')
|
||||
SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script',
|
||||
'entity_registry')
|
||||
ON_DEMAND = ('zwave',)
|
||||
FEATURE_FLAGS = ('config_entries',)
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
"""HTTP views to interact with the entity registry."""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.helpers.entity_registry import async_get_registry
|
||||
|
||||
|
||||
async def async_setup(hass):
|
||||
"""Enable the Entity Registry views."""
|
||||
hass.http.register_view(ConfigManagerEntityView)
|
||||
return True
|
||||
|
||||
|
||||
class ConfigManagerEntityView(HomeAssistantView):
|
||||
"""View to interact with an entity registry entry."""
|
||||
|
||||
url = '/api/config/entity_registry/{entity_id}'
|
||||
name = 'api:config:entity_registry:entity'
|
||||
|
||||
async def get(self, request, entity_id):
|
||||
"""Get the entity registry settings for an entity."""
|
||||
hass = request.app['hass']
|
||||
registry = await async_get_registry(hass)
|
||||
entry = registry.entities.get(entity_id)
|
||||
|
||||
if entry is None:
|
||||
return self.json_message('Entry not found', 404)
|
||||
|
||||
return self.json(_entry_dict(entry))
|
||||
|
||||
@RequestDataValidator(vol.Schema({
|
||||
# If passed in, we update value. Passing None will remove old value.
|
||||
vol.Optional('name'): vol.Any(str, None),
|
||||
}))
|
||||
async def post(self, request, entity_id, data):
|
||||
"""Update the entity registry settings for an entity."""
|
||||
hass = request.app['hass']
|
||||
registry = await async_get_registry(hass)
|
||||
|
||||
if entity_id not in registry.entities:
|
||||
return self.json_message('Entry not found', 404)
|
||||
|
||||
entry = registry.async_update_entity(entity_id, **data)
|
||||
return self.json(_entry_dict(entry))
|
||||
|
||||
|
||||
@callback
|
||||
def _entry_dict(entry):
|
||||
"""Helper to convert entry to API format."""
|
||||
return {
|
||||
'entity_id': entry.entity_id,
|
||||
'name': entry.name
|
||||
}
|
|
@ -28,11 +28,11 @@ SUPPORT_DEMO = (SUPPORT_BRIGHTNESS | SUPPORT_COLOR_TEMP | SUPPORT_EFFECT |
|
|||
def setup_platform(hass, config, add_devices_callback, discovery_info=None):
|
||||
"""Set up the demo light platform."""
|
||||
add_devices_callback([
|
||||
DemoLight("Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST,
|
||||
DemoLight(1, "Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST,
|
||||
effect=LIGHT_EFFECT_LIST[0]),
|
||||
DemoLight("Ceiling Lights", True, True,
|
||||
DemoLight(2, "Ceiling Lights", True, True,
|
||||
LIGHT_COLORS[0], LIGHT_TEMPS[1]),
|
||||
DemoLight("Kitchen Lights", True, True,
|
||||
DemoLight(3, "Kitchen Lights", True, True,
|
||||
LIGHT_COLORS[1], LIGHT_TEMPS[0])
|
||||
])
|
||||
|
||||
|
@ -40,10 +40,11 @@ def setup_platform(hass, config, add_devices_callback, discovery_info=None):
|
|||
class DemoLight(Light):
|
||||
"""Representation of a demo light."""
|
||||
|
||||
def __init__(self, name, state, available=False, rgb=None, ct=None,
|
||||
brightness=180, xy_color=(.5, .5), white=200,
|
||||
def __init__(self, unique_id, name, state, available=False, rgb=None,
|
||||
ct=None, brightness=180, xy_color=(.5, .5), white=200,
|
||||
effect_list=None, effect=None):
|
||||
"""Initialize the light."""
|
||||
self._unique_id = unique_id
|
||||
self._name = name
|
||||
self._state = state
|
||||
self._rgb = rgb
|
||||
|
@ -64,6 +65,11 @@ class DemoLight(Light):
|
|||
"""Return the name of the light if any."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return unique ID for light."""
|
||||
return self._unique_id
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return availability."""
|
||||
|
|
|
@ -340,6 +340,12 @@ class Entity(object):
|
|||
else:
|
||||
self.hass.states.async_remove(self.entity_id)
|
||||
|
||||
@callback
|
||||
def async_registry_updated(self, old, new):
|
||||
"""Called when the entity registry has been updated."""
|
||||
self.registry_name = new.name
|
||||
self.async_schedule_update_ha_state()
|
||||
|
||||
def __eq__(self, other):
|
||||
"""Return the comparison."""
|
||||
if not isinstance(other, self.__class__):
|
||||
|
|
|
@ -10,12 +10,11 @@ from homeassistant.util.async import (
|
|||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .event import async_track_time_interval, async_track_point_in_time
|
||||
from .entity_registry import EntityRegistry
|
||||
from .entity_registry import async_get_registry
|
||||
|
||||
SLOW_SETUP_WARNING = 10
|
||||
SLOW_SETUP_MAX_WAIT = 60
|
||||
PLATFORM_NOT_READY_RETRIES = 10
|
||||
DATA_REGISTRY = 'entity_registry'
|
||||
|
||||
|
||||
class EntityPlatform(object):
|
||||
|
@ -156,12 +155,7 @@ class EntityPlatform(object):
|
|||
hass = self.hass
|
||||
component_entities = set(hass.states.async_entity_ids(self.domain))
|
||||
|
||||
registry = hass.data.get(DATA_REGISTRY)
|
||||
|
||||
if registry is None:
|
||||
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
|
||||
|
||||
yield from registry.async_ensure_loaded()
|
||||
registry = yield from async_get_registry(hass)
|
||||
|
||||
tasks = [
|
||||
self._async_add_entity(entity, update_before_add,
|
||||
|
@ -226,6 +220,7 @@ class EntityPlatform(object):
|
|||
|
||||
entity.entity_id = entry.entity_id
|
||||
entity.registry_name = entry.name
|
||||
entry.add_update_listener(entity)
|
||||
|
||||
# We won't generate an entity ID if the platform has already set one
|
||||
# We will however make sure that platform cannot pick a registered ID
|
||||
|
|
|
@ -15,17 +15,20 @@ from collections import OrderedDict
|
|||
from itertools import chain
|
||||
import logging
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
||||
from ..core import callback, split_entity_id
|
||||
from ..loader import bind_hass
|
||||
from ..util import ensure_unique_string, slugify
|
||||
from ..util.yaml import load_yaml, save_yaml
|
||||
|
||||
PATH_REGISTRY = 'entity_registry.yaml'
|
||||
DATA_REGISTRY = 'entity_registry'
|
||||
SAVE_DELAY = 10
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_UNDEF = object()
|
||||
DISABLED_HASS = 'hass'
|
||||
DISABLED_USER = 'user'
|
||||
|
||||
|
@ -34,6 +37,8 @@ DISABLED_USER = 'user'
|
|||
class RegistryEntry:
|
||||
"""Entity Registry Entry."""
|
||||
|
||||
# pylint: disable=no-member
|
||||
|
||||
entity_id = attr.ib(type=str)
|
||||
unique_id = attr.ib(type=str)
|
||||
platform = attr.ib(type=str)
|
||||
|
@ -41,17 +46,27 @@ class RegistryEntry:
|
|||
disabled_by = attr.ib(
|
||||
type=str, default=None,
|
||||
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
|
||||
domain = attr.ib(type=str, default=None, init=False, repr=False)
|
||||
update_listeners = attr.ib(type=list, default=attr.Factory(list),
|
||||
repr=False)
|
||||
domain = attr.ib(type=str, init=False, repr=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
"""Computed properties."""
|
||||
object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0])
|
||||
@domain.default
|
||||
def _domain_default(self):
|
||||
"""Compute domain value."""
|
||||
return split_entity_id(self.entity_id)[0]
|
||||
|
||||
@property
|
||||
def disabled(self):
|
||||
"""Return if entry is disabled."""
|
||||
return self.disabled_by is not None
|
||||
|
||||
def add_update_listener(self, listener):
|
||||
"""Listen for when entry is updated.
|
||||
|
||||
Listener: Callback function(old_entry, new_entry)
|
||||
"""
|
||||
self.update_listeners.append(weakref.ref(listener))
|
||||
|
||||
|
||||
class EntityRegistry:
|
||||
"""Class to hold a registry of entities."""
|
||||
|
@ -102,6 +117,39 @@ class EntityRegistry:
|
|||
self.async_schedule_save()
|
||||
return entity
|
||||
|
||||
@callback
|
||||
def async_update_entity(self, entity_id, *, name=_UNDEF):
|
||||
"""Update properties of an entity."""
|
||||
old = self.entities[entity_id]
|
||||
|
||||
changes = {}
|
||||
|
||||
if name is not _UNDEF and name != old.name:
|
||||
changes['name'] = name
|
||||
|
||||
if not changes:
|
||||
return old
|
||||
|
||||
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
||||
|
||||
to_remove = []
|
||||
for listener_ref in new.update_listeners:
|
||||
listener = listener_ref()
|
||||
if listener is None:
|
||||
to_remove.append(listener)
|
||||
else:
|
||||
try:
|
||||
listener.async_registry_updated(old, new)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error calling update listener')
|
||||
|
||||
for ref in to_remove:
|
||||
new.update_listeners.remove(ref)
|
||||
|
||||
self.async_schedule_save()
|
||||
|
||||
return new
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_ensure_loaded(self):
|
||||
"""Load the registry from disk."""
|
||||
|
@ -154,7 +202,20 @@ class EntityRegistry:
|
|||
data[entry.entity_id] = {
|
||||
'unique_id': entry.unique_id,
|
||||
'platform': entry.platform,
|
||||
'name': entry.name,
|
||||
}
|
||||
|
||||
yield from self.hass.async_add_job(
|
||||
save_yaml, self.hass.config.path(PATH_REGISTRY), data)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_registry(hass) -> EntityRegistry:
|
||||
"""Return entity registry instance."""
|
||||
registry = hass.data.get(DATA_REGISTRY)
|
||||
|
||||
if registry is None:
|
||||
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
|
||||
|
||||
await registry.async_ensure_loaded()
|
||||
return registry
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Test the helper method for writing tests."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import functools as ft
|
||||
import os
|
||||
import sys
|
||||
|
@ -298,7 +299,7 @@ def mock_registry(hass, mock_entries=None):
|
|||
"""Mock the Entity Registry."""
|
||||
registry = entity_registry.EntityRegistry(hass)
|
||||
registry.entities = mock_entries or {}
|
||||
hass.data[entity_platform.DATA_REGISTRY] = registry
|
||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||
return registry
|
||||
|
||||
|
||||
|
@ -361,6 +362,32 @@ class MockPlatform(object):
|
|||
self.async_setup_platform = mock_coro_func()
|
||||
|
||||
|
||||
class MockEntityPlatform(entity_platform.EntityPlatform):
|
||||
"""Mock class with some mock defaults."""
|
||||
|
||||
def __init__(
|
||||
self, hass,
|
||||
logger=None,
|
||||
domain='test_domain',
|
||||
platform_name='test_platform',
|
||||
scan_interval=timedelta(seconds=15),
|
||||
parallel_updates=0,
|
||||
entity_namespace=None,
|
||||
async_entities_added_callback=lambda: None
|
||||
):
|
||||
"""Initialize a mock entity platform."""
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
logger=logger,
|
||||
domain=domain,
|
||||
platform_name=platform_name,
|
||||
scan_interval=scan_interval,
|
||||
parallel_updates=parallel_updates,
|
||||
entity_namespace=entity_namespace,
|
||||
async_entities_added_callback=async_entities_added_callback,
|
||||
)
|
||||
|
||||
|
||||
class MockToggleDevice(entity.ToggleEntity):
|
||||
"""Provide a mock toggle device."""
|
||||
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
"""Test entity_registry API."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||
from homeassistant.components.config import entity_registry
|
||||
from tests.common import mock_registry, MockEntity, MockEntityPlatform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(hass, test_client):
|
||||
"""Fixture that can interact with the config manager API."""
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||
hass.loop.run_until_complete(entity_registry.async_setup(hass))
|
||||
yield hass.loop.run_until_complete(test_client(hass.http.app))
|
||||
|
||||
|
||||
async def test_get_entity(hass, client):
|
||||
"""Test get entry."""
|
||||
mock_registry(hass, {
|
||||
'test_domain.name': RegistryEntry(
|
||||
entity_id='test_domain.name',
|
||||
unique_id='1234',
|
||||
platform='test_platform',
|
||||
name='Hello World'
|
||||
),
|
||||
'test_domain.no_name': RegistryEntry(
|
||||
entity_id='test_domain.no_name',
|
||||
unique_id='6789',
|
||||
platform='test_platform',
|
||||
),
|
||||
})
|
||||
|
||||
resp = await client.get(
|
||||
'/api/config/entity_registry/test_domain.name')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data == {
|
||||
'entity_id': 'test_domain.name',
|
||||
'name': 'Hello World'
|
||||
}
|
||||
|
||||
resp = await client.get(
|
||||
'/api/config/entity_registry/test_domain.no_name')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data == {
|
||||
'entity_id': 'test_domain.no_name',
|
||||
'name': None
|
||||
}
|
||||
|
||||
|
||||
async def test_update_entity(hass, client):
|
||||
"""Test get entry."""
|
||||
mock_registry(hass, {
|
||||
'test_domain.world': RegistryEntry(
|
||||
entity_id='test_domain.world',
|
||||
unique_id='1234',
|
||||
# Using component.async_add_entities is equal to platform "domain"
|
||||
platform='test_platform',
|
||||
name='before update'
|
||||
)
|
||||
})
|
||||
platform = MockEntityPlatform(hass)
|
||||
entity = MockEntity(unique_id='1234')
|
||||
await platform.async_add_entities([entity])
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state is not None
|
||||
assert state.name == 'before update'
|
||||
|
||||
resp = await client.post(
|
||||
'/api/config/entity_registry/test_domain.world', json={
|
||||
'name': 'after update'
|
||||
})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data == {
|
||||
'entity_id': 'test_domain.world',
|
||||
'name': 'after update'
|
||||
}
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state.name == 'after update'
|
||||
|
||||
|
||||
async def test_update_entity_no_changes(hass, client):
|
||||
"""Test get entry."""
|
||||
mock_registry(hass, {
|
||||
'test_domain.world': RegistryEntry(
|
||||
entity_id='test_domain.world',
|
||||
unique_id='1234',
|
||||
# Using component.async_add_entities is equal to platform "domain"
|
||||
platform='test_platform',
|
||||
name='name of entity'
|
||||
)
|
||||
})
|
||||
platform = MockEntityPlatform(hass)
|
||||
entity = MockEntity(unique_id='1234')
|
||||
await platform.async_add_entities([entity])
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state is not None
|
||||
assert state.name == 'name of entity'
|
||||
|
||||
resp = await client.post(
|
||||
'/api/config/entity_registry/test_domain.world', json={
|
||||
'name': 'name of entity'
|
||||
})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data == {
|
||||
'entity_id': 'test_domain.world',
|
||||
'name': 'name of entity'
|
||||
}
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state.name == 'name of entity'
|
||||
|
||||
|
||||
async def test_get_nonexisting_entity(client):
|
||||
"""Test get entry."""
|
||||
resp = await client.get(
|
||||
'/api/config/entity_registry/test_domain.non_existing')
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
async def test_update_nonexisting_entity(client):
|
||||
"""Test get entry."""
|
||||
resp = await client.post(
|
||||
'/api/config/entity_registry/test_domain.non_existing', json={
|
||||
'name': 'some name'
|
||||
})
|
||||
assert resp.status == 404
|
|
@ -15,39 +15,13 @@ import homeassistant.util.dt as dt_util
|
|||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
|
||||
MockEntity)
|
||||
MockEntity, MockEntityPlatform)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DOMAIN = "test_domain"
|
||||
PLATFORM = 'test_platform'
|
||||
|
||||
|
||||
class MockEntityPlatform(entity_platform.EntityPlatform):
|
||||
"""Mock class with some mock defaults."""
|
||||
|
||||
def __init__(
|
||||
self, hass,
|
||||
logger=None,
|
||||
domain=DOMAIN,
|
||||
platform_name=PLATFORM,
|
||||
scan_interval=timedelta(seconds=15),
|
||||
parallel_updates=0,
|
||||
entity_namespace=None,
|
||||
async_entities_added_callback=lambda: None
|
||||
):
|
||||
"""Initialize a mock entity platform."""
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
logger=logger,
|
||||
domain=domain,
|
||||
platform_name=platform_name,
|
||||
scan_interval=scan_interval,
|
||||
parallel_updates=parallel_updates,
|
||||
entity_namespace=entity_namespace,
|
||||
async_entities_added_callback=async_entities_added_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestHelpersEntityPlatform(unittest.TestCase):
|
||||
"""Test homeassistant.helpers.entity_component module."""
|
||||
|
||||
|
@ -510,3 +484,30 @@ def test_registry_respect_entity_disabled(hass):
|
|||
yield from platform.async_add_entities([entity])
|
||||
assert entity.entity_id is None
|
||||
assert hass.states.async_entity_ids() == []
|
||||
|
||||
|
||||
async def test_entity_registry_updates(hass):
|
||||
"""Test that updates on the entity registry update platform entities."""
|
||||
registry = mock_registry(hass, {
|
||||
'test_domain.world': entity_registry.RegistryEntry(
|
||||
entity_id='test_domain.world',
|
||||
unique_id='1234',
|
||||
# Using component.async_add_entities is equal to platform "domain"
|
||||
platform='test_platform',
|
||||
name='before update'
|
||||
)
|
||||
})
|
||||
platform = MockEntityPlatform(hass)
|
||||
entity = MockEntity(unique_id='1234')
|
||||
await platform.async_add_entities([entity])
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state is not None
|
||||
assert state.name == 'before update'
|
||||
|
||||
registry.async_update_entity('test_domain.world', name='after update')
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state.name == 'after update'
|
||||
|
|
Loading…
Reference in New Issue