Allow changing entity ID (#15637)

* Allow changing entity ID

* Add support to websocket command

* Address comments

* Error handling
pull/15155/merge
Paulus Schoutsen 2018-07-24 14:12:53 +02:00 committed by GitHub
parent fbeaa57604
commit d9cf8fcfe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 206 additions and 22 deletions

View File

@ -20,6 +20,7 @@ SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('entity_id'): cv.entity_id,
# If passed in, we update value. Passing None will remove old value.
vol.Optional('name'): vol.Any(str, None),
vol.Optional('new_entity_id'): str,
})
@ -74,13 +75,28 @@ def websocket_update_entity(hass, connection, msg):
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
return
entry = registry.async_update_entity(
msg['entity_id'], name=msg['name'])
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
changes = {}
hass.async_add_job(update_entity())
if 'name' in msg:
changes['name'] = msg['name']
if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id']
try:
if changes:
entry = registry.async_update_entity(
msg['entity_id'], **changes)
except ValueError as err:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err)
))
else:
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
hass.async_create_task(update_entity())
@callback

View File

@ -82,6 +82,9 @@ class Entity:
# Name in the entity registry
registry_name = None
# Hold list for functions to call on remove.
_on_remove = None
@property
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.
@ -324,8 +327,19 @@ class Entity:
if self.parallel_updates:
self.parallel_updates.release()
@callback
def async_on_remove(self, func):
"""Add a function to call when entity removed."""
if self._on_remove is None:
self._on_remove = []
self._on_remove.append(func)
async def async_remove(self):
"""Remove entity from Home Assistant."""
if self._on_remove is not None:
while self._on_remove:
self._on_remove.pop()()
if self.platform is not None:
await self.platform.async_remove_entity(self.entity_id)
else:
@ -335,7 +349,17 @@ class Entity:
def async_registry_updated(self, old, new):
"""Called when the entity registry has been updated."""
self.registry_name = new.name
self.async_schedule_update_ha_state()
if new.entity_id == self.entity_id:
self.async_schedule_update_ha_state()
return
async def readd():
"""Remove and add entity again."""
await self.async_remove()
await self.platform.async_add_entities([self])
self.hass.async_create_task(readd())
def __eq__(self, other):
"""Return the comparison."""

View File

@ -283,7 +283,7 @@ class EntityPlatform:
entity.entity_id = entry.entity_id
entity.registry_name = entry.name
entry.add_update_listener(entity)
entity.async_on_remove(entry.add_update_listener(entity))
# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID

View File

@ -19,10 +19,10 @@ import weakref
import attr
from ..core import callback, split_entity_id
from ..loader import bind_hass
from ..util import ensure_unique_string, slugify
from ..util.yaml import load_yaml, save_yaml
from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.loader import bind_hass
from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.yaml import load_yaml, save_yaml
PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry'
@ -63,8 +63,13 @@ class RegistryEntry:
"""Listen for when entry is updated.
Listener: Callback function(old_entry, new_entry)
Returns function to unlisten.
"""
self.update_listeners.append(weakref.ref(listener))
weak_listener = weakref.ref(listener)
self.update_listeners.append(weak_listener)
return lambda: self.update_listeners.remove(weak_listener)
class EntityRegistry:
@ -133,13 +138,18 @@ class EntityRegistry:
return entity
@callback
def async_update_entity(self, entity_id, *, name=_UNDEF):
def async_update_entity(self, entity_id, *, name=_UNDEF,
new_entity_id=_UNDEF):
"""Update properties of an entity."""
return self._async_update_entity(entity_id, name=name)
return self._async_update_entity(
entity_id,
name=name,
new_entity_id=new_entity_id
)
@callback
def _async_update_entity(self, entity_id, *, name=_UNDEF,
config_entry_id=_UNDEF):
config_entry_id=_UNDEF, new_entity_id=_UNDEF):
"""Private facing update properties method."""
old = self.entities[entity_id]
@ -152,6 +162,20 @@ class EntityRegistry:
config_entry_id != old.config_entry_id):
changes['config_entry_id'] = config_entry_id
if new_entity_id is not _UNDEF and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id):
raise ValueError('Entity is already registered')
if not valid_entity_id(new_entity_id):
raise ValueError('Invalid entity ID')
if (split_entity_id(new_entity_id)[0] !=
split_entity_id(entity_id)[0]):
raise ValueError('New entity ID should be same domain')
self.entities.pop(entity_id)
entity_id = changes['entity_id'] = new_entity_id
if not changes:
return old

View File

@ -54,8 +54,8 @@ async def test_get_entity(hass, client):
}
async def test_update_entity(hass, client):
"""Test get entry."""
async def test_update_entity_name(hass, client):
"""Test updating entity name."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
@ -92,7 +92,7 @@ async def test_update_entity(hass, client):
async def test_update_entity_no_changes(hass, client):
"""Test get entry."""
"""Test update entity with no changes."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
@ -129,7 +129,7 @@ async def test_update_entity_no_changes(hass, client):
async def test_get_nonexisting_entity(client):
"""Test get entry."""
"""Test get entry with nonexisting entity."""
await client.send_json({
'id': 6,
'type': 'config/entity_registry/get',
@ -141,7 +141,7 @@ async def test_get_nonexisting_entity(client):
async def test_update_nonexisting_entity(client):
"""Test get entry."""
"""Test update a nonexisting entity."""
await client.send_json({
'id': 6,
'type': 'config/entity_registry/update',
@ -151,3 +151,37 @@ async def test_update_nonexisting_entity(client):
msg = await client.receive_json()
assert not msg['success']
async def test_update_entity_id(hass, client):
"""Test update entity id."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
assert hass.states.get('test_domain.world') is not None
await client.send_json({
'id': 6,
'type': 'config/entity_registry/update',
'entity_id': 'test_domain.world',
'new_entity_id': 'test_domain.planet',
})
msg = await client.receive_json()
assert msg['result'] == {
'entity_id': 'test_domain.planet',
'name': None
}
assert hass.states.get('test_domain.world') is None
assert hass.states.get('test_domain.planet') is not None

View File

@ -400,3 +400,15 @@ def test_async_remove_no_platform(hass):
assert len(hass.states.async_entity_ids()) == 1
yield from ent.async_remove()
assert len(hass.states.async_entity_ids()) == 0
async def test_async_remove_runs_callbacks(hass):
"""Test async_remove method when no platform set."""
result = []
ent = entity.Entity()
ent.hass = hass
ent.entity_id = 'test.test'
ent.async_on_remove(lambda: result.append(1))
await ent.async_remove()
assert len(result) == 1

View File

@ -5,6 +5,8 @@ import unittest
from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta
import pytest
from homeassistant.exceptions import PlatformNotReady
import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id
@ -487,7 +489,7 @@ def test_registry_respect_entity_disabled(hass):
assert hass.states.async_entity_ids() == []
async def test_entity_registry_updates(hass):
async def test_entity_registry_updates_name(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
@ -602,3 +604,75 @@ def test_not_fails_with_adding_empty_entities_(hass):
yield from component.async_add_entities([])
assert len(hass.states.async_entity_ids()) == 0
async def test_entity_registry_updates_entity_id(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='Some name'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'Some name'
registry.async_update_entity('test_domain.world',
new_entity_id='test_domain.planet')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get('test_domain.world') is None
assert hass.states.get('test_domain.planet') is not None
async def test_entity_registry_updates_invalid_entity_id(hass):
"""Test that we can't update to an invalid entity id."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='Some name'
),
'test_domain.existing': entity_registry.RegistryEntry(
entity_id='test_domain.existing',
unique_id='5678',
platform='test_platform',
),
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'Some name'
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='test_domain.existing')
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='invalid_entity_id')
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='diff_domain.world')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get('test_domain.world') is not None
assert hass.states.get('invalid_entity_id') is None
assert hass.states.get('diff_domain.world') is None