Allow disabling entities in the registry (#12360)

pull/12373/head
Paulus Schoutsen 2018-02-13 04:33:15 -08:00 committed by Pascal Vizeli
parent a4b88fc31b
commit d2cea84254
4 changed files with 64 additions and 6 deletions

View File

@ -216,6 +216,14 @@ class EntityPlatform(object):
entry = registry.async_get_or_create( entry = registry.async_get_or_create(
self.domain, self.platform_name, entity.unique_id, self.domain, self.platform_name, entity.unique_id,
suggested_object_id=suggested_object_id) suggested_object_id=suggested_object_id)
if entry.disabled:
self.logger.info(
"Not adding entity %s because it's disabled",
entry.name or entity.name or
'"{} {}"'.format(self.platform_name, entity.unique_id))
return
entity.entity_id = entry.entity_id entity.entity_id = entry.entity_id
entity.registry_name = entry.name entity.registry_name = entry.name

View File

@ -26,6 +26,9 @@ PATH_REGISTRY = 'entity_registry.yaml'
SAVE_DELAY = 10 SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DISABLED_HASS = 'hass'
DISABLED_USER = 'user'
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class RegistryEntry: class RegistryEntry:
@ -35,12 +38,20 @@ class RegistryEntry:
unique_id = attr.ib(type=str) unique_id = attr.ib(type=str)
platform = attr.ib(type=str) platform = attr.ib(type=str)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
disabled_by = attr.ib(
type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
domain = attr.ib(type=str, default=None, init=False, repr=False) domain = attr.ib(type=str, default=None, init=False, repr=False)
def __attrs_post_init__(self): def __attrs_post_init__(self):
"""Computed properties.""" """Computed properties."""
object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0]) object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0])
@property
def disabled(self):
"""Return if entry is disabled."""
return self.disabled_by is not None
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
@ -116,7 +127,8 @@ class EntityRegistry:
entity_id=entity_id, entity_id=entity_id,
unique_id=info['unique_id'], unique_id=info['unique_id'],
platform=info['platform'], platform=info['platform'],
name=info.get('name') name=info.get('name'),
disabled_by=info.get('disabled_by')
) )
self.entities = entities self.entities = entities

View File

@ -19,16 +19,17 @@ from tests.common import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain" DOMAIN = "test_domain"
PLATFORM = 'test_platform'
class MockEntityPlatform(entity_platform.EntityPlatform): class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults.""" """Mock class with some mock defaults."""
def __init__( def __init__(
self, *, hass, self, hass,
logger=None, logger=None,
domain='test', domain=DOMAIN,
platform_name='test_platform', platform_name=PLATFORM,
scan_interval=timedelta(seconds=15), scan_interval=timedelta(seconds=15),
parallel_updates=0, parallel_updates=0,
entity_namespace=None, entity_namespace=None,
@ -486,7 +487,26 @@ def test_overriding_name_from_registry(hass):
def test_registry_respect_entity_namespace(hass): def test_registry_respect_entity_namespace(hass):
"""Test that the registry respects entity namespace.""" """Test that the registry respects entity namespace."""
mock_registry(hass) mock_registry(hass)
platform = MockEntityPlatform(hass=hass, entity_namespace='ns') platform = MockEntityPlatform(hass, entity_namespace='ns')
entity = MockEntity(unique_id='1234', name='Device Name') entity = MockEntity(unique_id='1234', name='Device Name')
yield from platform.async_add_entities([entity]) yield from platform.async_add_entities([entity])
assert entity.entity_id == 'test.ns_device_name' assert entity.entity_id == 'test_domain.ns_device_name'
@asyncio.coroutine
def test_registry_respect_entity_disabled(hass):
"""Test that the registry respects entity disabled."""
mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
disabled_by=entity_registry.DISABLED_USER
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
yield from platform.async_add_entities([entity])
assert entity.entity_id is None
assert hass.states.async_entity_ids() == []

View File

@ -148,6 +148,14 @@ test.named:
test.no_name: test.no_name:
platform: super_platform platform: super_platform
unique_id: without-name unique_id: without-name
test.disabled_user:
platform: super_platform
unique_id: disabled-user
disabled_by: user
test.disabled_hass:
platform: super_platform
unique_id: disabled-hass
disabled_by: hass
""" """
registry = entity_registry.EntityRegistry(hass) registry = entity_registry.EntityRegistry(hass)
@ -162,3 +170,13 @@ test.no_name:
'test', 'super_platform', 'without-name') 'test', 'super_platform', 'without-name')
assert entry_with_name.name == 'registry override' assert entry_with_name.name == 'registry override'
assert entry_without_name.name is None assert entry_without_name.name is None
assert not entry_with_name.disabled
entry_disabled_hass = registry.async_get_or_create(
'test', 'super_platform', 'disabled-hass')
entry_disabled_user = registry.async_get_or_create(
'test', 'super_platform', 'disabled-user')
assert entry_disabled_hass.disabled
assert entry_disabled_hass.disabled_by == entity_registry.DISABLED_HASS
assert entry_disabled_user.disabled
assert entry_disabled_user.disabled_by == entity_registry.DISABLED_USER