Add area id to entity registry (#42221)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/42312/head
Robert Svensson 2020-10-24 21:25:28 +02:00 committed by GitHub
parent b54dde10ca
commit e06c8009e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 107 additions and 5 deletions

View File

@ -71,6 +71,7 @@ async def websocket_get_entity(hass, connection, msg):
# If passed in, we update value. Passing None will remove old value. # If passed in, we update value. Passing None will remove old value.
vol.Optional("name"): vol.Any(str, None), vol.Optional("name"): vol.Any(str, None),
vol.Optional("icon"): vol.Any(str, None), vol.Optional("icon"): vol.Any(str, None),
vol.Optional("area_id"): vol.Any(str, None),
vol.Optional("new_entity_id"): str, vol.Optional("new_entity_id"): str,
# We only allow setting disabled_by user via API. # We only allow setting disabled_by user via API.
vol.Optional("disabled_by"): vol.Any("user", None), vol.Optional("disabled_by"): vol.Any("user", None),
@ -91,7 +92,7 @@ async def websocket_update_entity(hass, connection, msg):
changes = {} changes = {}
for key in ("name", "icon", "disabled_by"): for key in ("name", "icon", "area_id", "disabled_by"):
if key in msg: if key in msg:
changes[key] = msg[key] changes[key] = msg[key]
@ -149,6 +150,7 @@ def _entry_dict(entry):
return { return {
"config_entry_id": entry.config_entry_id, "config_entry_id": entry.config_entry_id,
"device_id": entry.device_id, "device_id": entry.device_id,
"area_id": entry.area_id,
"disabled_by": entry.disabled_by, "disabled_by": entry.disabled_by,
"entity_id": entry.entity_id, "entity_id": entry.entity_id,
"name": entry.name, "name": entry.name,

View File

@ -122,6 +122,10 @@ class Searcher:
"""Resolve an area.""" """Resolve an area."""
for device in device_registry.async_entries_for_area(self._device_reg, area_id): for device in device_registry.async_entries_for_area(self._device_reg, area_id):
self._add_or_resolve("device", device.id) self._add_or_resolve("device", device.id)
for entity_entry in entity_registry.async_entries_for_area(
self._entity_reg, area_id
):
self._add_or_resolve("entity", entity_entry.entity_id)
@callback @callback
def _resolve_device(self, device_id) -> None: def _resolve_device(self, device_id) -> None:

View File

@ -1,5 +1,5 @@
"""Provide a way to connect devices to one physical location.""" """Provide a way to connect devices to one physical location."""
from asyncio import Event from asyncio import Event, gather
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Iterable, List, MutableMapping, Optional, cast from typing import Dict, Iterable, List, MutableMapping, Optional, cast
@ -64,8 +64,12 @@ class AreaRegistry:
async def async_delete(self, area_id: str) -> None: async def async_delete(self, area_id: str) -> None:
"""Delete area.""" """Delete area."""
device_registry = await self.hass.helpers.device_registry.async_get_registry() device_registry, entity_registry = await gather(
self.hass.helpers.device_registry.async_get_registry(),
self.hass.helpers.entity_registry.async_get_registry(),
)
device_registry.async_clear_area_id(area_id) device_registry.async_clear_area_id(area_id)
entity_registry.async_clear_area_id(area_id)
del self.areas[area_id] del self.areas[area_id]

View File

@ -83,6 +83,7 @@ class RegistryEntry:
name: Optional[str] = attr.ib(default=None) name: Optional[str] = attr.ib(default=None)
icon: Optional[str] = attr.ib(default=None) icon: Optional[str] = attr.ib(default=None)
device_id: Optional[str] = attr.ib(default=None) device_id: Optional[str] = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None) config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by: Optional[str] = attr.ib( disabled_by: Optional[str] = attr.ib(
default=None, default=None,
@ -204,6 +205,7 @@ class EntityRegistry:
# Data that we want entry to have # Data that we want entry to have
config_entry: Optional["ConfigEntry"] = None, config_entry: Optional["ConfigEntry"] = None,
device_id: Optional[str] = None, device_id: Optional[str] = None,
area_id: Optional[str] = None,
capabilities: Optional[Dict[str, Any]] = None, capabilities: Optional[Dict[str, Any]] = None,
supported_features: Optional[int] = None, supported_features: Optional[int] = None,
device_class: Optional[str] = None, device_class: Optional[str] = None,
@ -223,6 +225,7 @@ class EntityRegistry:
entity_id, entity_id,
config_entry_id=config_entry_id or _UNDEF, config_entry_id=config_entry_id or _UNDEF,
device_id=device_id or _UNDEF, device_id=device_id or _UNDEF,
area_id=area_id or _UNDEF,
capabilities=capabilities or _UNDEF, capabilities=capabilities or _UNDEF,
supported_features=supported_features or _UNDEF, supported_features=supported_features or _UNDEF,
device_class=device_class or _UNDEF, device_class=device_class or _UNDEF,
@ -253,6 +256,7 @@ class EntityRegistry:
entity_id=entity_id, entity_id=entity_id,
config_entry_id=config_entry_id, config_entry_id=config_entry_id,
device_id=device_id, device_id=device_id,
area_id=area_id,
unique_id=unique_id, unique_id=unique_id,
platform=platform, platform=platform,
disabled_by=disabled_by, disabled_by=disabled_by,
@ -302,6 +306,7 @@ class EntityRegistry:
*, *,
name=_UNDEF, name=_UNDEF,
icon=_UNDEF, icon=_UNDEF,
area_id=_UNDEF,
new_entity_id=_UNDEF, new_entity_id=_UNDEF,
new_unique_id=_UNDEF, new_unique_id=_UNDEF,
disabled_by=_UNDEF, disabled_by=_UNDEF,
@ -313,6 +318,7 @@ class EntityRegistry:
entity_id, entity_id,
name=name, name=name,
icon=icon, icon=icon,
area_id=area_id,
new_entity_id=new_entity_id, new_entity_id=new_entity_id,
new_unique_id=new_unique_id, new_unique_id=new_unique_id,
disabled_by=disabled_by, disabled_by=disabled_by,
@ -329,6 +335,7 @@ class EntityRegistry:
config_entry_id=_UNDEF, config_entry_id=_UNDEF,
new_entity_id=_UNDEF, new_entity_id=_UNDEF,
device_id=_UNDEF, device_id=_UNDEF,
area_id=_UNDEF,
new_unique_id=_UNDEF, new_unique_id=_UNDEF,
disabled_by=_UNDEF, disabled_by=_UNDEF,
capabilities=_UNDEF, capabilities=_UNDEF,
@ -348,6 +355,7 @@ class EntityRegistry:
("icon", icon), ("icon", icon),
("config_entry_id", config_entry_id), ("config_entry_id", config_entry_id),
("device_id", device_id), ("device_id", device_id),
("area_id", area_id),
("disabled_by", disabled_by), ("disabled_by", disabled_by),
("capabilities", capabilities), ("capabilities", capabilities),
("supported_features", supported_features), ("supported_features", supported_features),
@ -425,6 +433,7 @@ class EntityRegistry:
entity_id=entity["entity_id"], entity_id=entity["entity_id"],
config_entry_id=entity.get("config_entry_id"), config_entry_id=entity.get("config_entry_id"),
device_id=entity.get("device_id"), device_id=entity.get("device_id"),
area_id=entity.get("area_id"),
unique_id=entity["unique_id"], unique_id=entity["unique_id"],
platform=entity["platform"], platform=entity["platform"],
name=entity.get("name"), name=entity.get("name"),
@ -456,6 +465,7 @@ class EntityRegistry:
"entity_id": entry.entity_id, "entity_id": entry.entity_id,
"config_entry_id": entry.config_entry_id, "config_entry_id": entry.config_entry_id,
"device_id": entry.device_id, "device_id": entry.device_id,
"area_id": entry.area_id,
"unique_id": entry.unique_id, "unique_id": entry.unique_id,
"platform": entry.platform, "platform": entry.platform,
"name": entry.name, "name": entry.name,
@ -483,6 +493,13 @@ class EntityRegistry:
]: ]:
self.async_remove(entity_id) self.async_remove(entity_id)
@callback
def async_clear_area_id(self, area_id: str) -> None:
"""Clear area id from registry entries."""
for entity_id, entry in self.entities.items():
if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) # type: ignore
def _register_entry(self, entry: RegistryEntry) -> None: def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry self.entities[entry.entity_id] = entry
self._add_index(entry) self._add_index(entry)
@ -521,6 +538,14 @@ def async_entries_for_device(
] ]
@callback
def async_entries_for_area(
registry: EntityRegistry, area_id: str
) -> List[RegistryEntry]:
"""Return entries that match an area."""
return [entry for entry in registry.entities.values() if entry.area_id == area_id]
@callback @callback
def async_entries_for_config_entry( def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str registry: EntityRegistry, config_entry_id: str

View File

@ -234,6 +234,15 @@ async def async_extract_entity_ids(
hass.helpers.device_registry.async_get_registry(), hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(), hass.helpers.entity_registry.async_get_registry(),
) )
extracted.update(
entry.entity_id
for area_id in area_ids
for entry in hass.helpers.entity_registry.async_entries_for_area(
ent_reg, area_id
)
)
devices = [ devices = [
device device
for area_id in area_ids for area_id in area_ids
@ -247,6 +256,7 @@ async def async_extract_entity_ids(
for entry in hass.helpers.entity_registry.async_entries_for_device( for entry in hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id ent_reg, device.id
) )
if not entry.area_id
) )
return extracted return extracted

View File

@ -39,6 +39,7 @@ async def test_list_entities(hass, client):
{ {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"entity_id": "test_domain.name", "entity_id": "test_domain.name",
"name": "Hello World", "name": "Hello World",
@ -48,6 +49,7 @@ async def test_list_entities(hass, client):
{ {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"entity_id": "test_domain.no_name", "entity_id": "test_domain.no_name",
"name": None, "name": None,
@ -84,6 +86,7 @@ async def test_get_entity(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.name", "entity_id": "test_domain.name",
@ -107,6 +110,7 @@ async def test_get_entity(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.no_name", "entity_id": "test_domain.no_name",
@ -143,7 +147,7 @@ async def test_update_entity(hass, client):
assert state.name == "before update" assert state.name == "before update"
assert state.attributes[ATTR_ICON] == "icon:before update" assert state.attributes[ATTR_ICON] == "icon:before update"
# UPDATE NAME & ICON # UPDATE NAME & ICON & AREA
await client.send_json( await client.send_json(
{ {
"id": 6, "id": 6,
@ -151,6 +155,7 @@ async def test_update_entity(hass, client):
"entity_id": "test_domain.world", "entity_id": "test_domain.world",
"name": "after update", "name": "after update",
"icon": "icon:after update", "icon": "icon:after update",
"area_id": "mock-area-id",
} }
) )
@ -159,6 +164,7 @@ async def test_update_entity(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": "mock-area-id",
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.world", "entity_id": "test_domain.world",
@ -204,6 +210,7 @@ async def test_update_entity(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": "mock-area-id",
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.world", "entity_id": "test_domain.world",
@ -252,6 +259,7 @@ async def test_update_entity_no_changes(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.world", "entity_id": "test_domain.world",
@ -329,6 +337,7 @@ async def test_update_entity_id(hass, client):
assert msg["result"] == { assert msg["result"] == {
"config_entry_id": None, "config_entry_id": None,
"device_id": None, "device_id": None,
"area_id": None,
"disabled_by": None, "disabled_by": None,
"platform": "test_platform", "platform": "test_platform",
"entity_id": "test_domain.planet", "entity_id": "test_domain.planet",

View File

@ -154,6 +154,7 @@ async def test_loading_saving_data(hass, registry):
"hue", "hue",
"5678", "5678",
device_id="mock-dev-id", device_id="mock-dev-id",
area_id="mock-area-id",
config_entry=mock_config, config_entry=mock_config,
capabilities={"max": 100}, capabilities={"max": 100},
supported_features=5, supported_features=5,
@ -182,6 +183,7 @@ async def test_loading_saving_data(hass, registry):
assert orig_entry2 == new_entry2 assert orig_entry2 == new_entry2
assert new_entry2.device_id == "mock-dev-id" assert new_entry2.device_id == "mock-dev-id"
assert new_entry2.area_id == "mock-area-id"
assert new_entry2.disabled_by == entity_registry.DISABLED_HASS assert new_entry2.disabled_by == entity_registry.DISABLED_HASS
assert new_entry2.capabilities == {"max": 100} assert new_entry2.capabilities == {"max": 100}
assert new_entry2.supported_features == 5 assert new_entry2.supported_features == 5
@ -330,6 +332,19 @@ async def test_removing_config_entry_id(hass, registry, update_events):
assert update_events[1]["entity_id"] == entry.entity_id assert update_events[1]["entity_id"] == entry.entity_id
async def test_removing_area_id(registry):
"""Make sure we can clear area id."""
entry = registry.async_get_or_create("light", "hue", "5678")
entry_w_area = registry.async_update_entity(entry.entity_id, area_id="12345A")
registry.async_clear_area_id("12345A")
entry_wo_area = registry.async_get(entry.entity_id)
assert not entry_wo_area.area_id
assert entry_w_area != entry_wo_area
async def test_migration(hass): async def test_migration(hass):
"""Test migration from old data to new.""" """Test migration from old data to new."""
mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id") mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id")

View File

@ -105,12 +105,32 @@ def area_mock(hass):
}, },
) )
entity_in_own_area = ent_reg.RegistryEntry(
entity_id="light.in_own_area",
unique_id="in-own-area-id",
platform="test",
area_id="own-area",
)
entity_in_area = ent_reg.RegistryEntry( entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area", entity_id="light.in_area",
unique_id="in-area-id", unique_id="in-area-id",
platform="test", platform="test",
device_id=device_in_area.id, device_id=device_in_area.id,
) )
entity_in_other_area = ent_reg.RegistryEntry(
entity_id="light.in_other_area",
unique_id="in-other-area-id",
platform="test",
device_id=device_in_area.id,
area_id="other-area",
)
entity_assigned_to_area = ent_reg.RegistryEntry(
entity_id="light.assigned_to_area",
unique_id="assigned-area-id",
platform="test",
device_id=device_in_area.id,
area_id="test-area",
)
entity_no_area = ent_reg.RegistryEntry( entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area", entity_id="light.no_area",
unique_id="no-area-id", unique_id="no-area-id",
@ -126,7 +146,10 @@ def area_mock(hass):
mock_registry( mock_registry(
hass, hass,
{ {
entity_in_own_area.entity_id: entity_in_own_area,
entity_in_area.entity_id: entity_in_area, entity_in_area.entity_id: entity_in_area,
entity_in_other_area.entity_id: entity_in_other_area,
entity_assigned_to_area.entity_id: entity_assigned_to_area,
entity_no_area.entity_id: entity_no_area, entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area, entity_diff_area.entity_id: entity_diff_area,
}, },
@ -298,15 +321,25 @@ async def test_extract_entity_ids(hass):
async def test_extract_entity_ids_from_area(hass, area_mock): async def test_extract_entity_ids_from_area(hass, area_mock):
"""Test extract_entity_ids method with areas.""" """Test extract_entity_ids method with areas."""
call = ha.ServiceCall("light", "turn_on", {"area_id": "own-area"})
assert {
"light.in_own_area",
} == await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"}) call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call) assert {
"light.in_area",
"light.assigned_to_area",
} == await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]}) call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]})
assert { assert {
"light.in_area", "light.in_area",
"light.diff_area", "light.diff_area",
"light.assigned_to_area",
} == await service.async_extract_entity_ids(hass, call) } == await service.async_extract_entity_ids(hass, call)
assert ( assert (