diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index f024f146a60..6e4fa30e205 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -71,6 +71,7 @@ async def websocket_get_entity(hass, connection, msg): # If passed in, we update value. Passing None will remove old value. vol.Optional("name"): vol.Any(str, None), vol.Optional("icon"): vol.Any(str, None), + vol.Optional("area_id"): vol.Any(str, None), vol.Optional("new_entity_id"): str, # We only allow setting disabled_by user via API. vol.Optional("disabled_by"): vol.Any("user", None), @@ -91,7 +92,7 @@ async def websocket_update_entity(hass, connection, msg): changes = {} - for key in ("name", "icon", "disabled_by"): + for key in ("name", "icon", "area_id", "disabled_by"): if key in msg: changes[key] = msg[key] @@ -149,6 +150,7 @@ def _entry_dict(entry): return { "config_entry_id": entry.config_entry_id, "device_id": entry.device_id, + "area_id": entry.area_id, "disabled_by": entry.disabled_by, "entity_id": entry.entity_id, "name": entry.name, diff --git a/homeassistant/components/search/__init__.py b/homeassistant/components/search/__init__.py index a3bbd3844aa..81e33aa24b5 100644 --- a/homeassistant/components/search/__init__.py +++ b/homeassistant/components/search/__init__.py @@ -122,6 +122,10 @@ class Searcher: """Resolve an area.""" for device in device_registry.async_entries_for_area(self._device_reg, area_id): self._add_or_resolve("device", device.id) + for entity_entry in entity_registry.async_entries_for_area( + self._entity_reg, area_id + ): + self._add_or_resolve("entity", entity_entry.entity_id) @callback def _resolve_device(self, device_id) -> None: diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 347e552a012..b8f7952cd5a 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -1,5 +1,5 @@ """Provide a way to connect devices to one physical location.""" -from asyncio import Event +from asyncio import Event, gather from collections import OrderedDict from typing import Dict, Iterable, List, MutableMapping, Optional, cast @@ -64,8 +64,12 @@ class AreaRegistry: async def async_delete(self, area_id: str) -> None: """Delete area.""" - device_registry = await self.hass.helpers.device_registry.async_get_registry() + device_registry, entity_registry = await gather( + self.hass.helpers.device_registry.async_get_registry(), + self.hass.helpers.entity_registry.async_get_registry(), + ) device_registry.async_clear_area_id(area_id) + entity_registry.async_clear_area_id(area_id) del self.areas[area_id] diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 14dce1d6d2c..872d87e732f 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -83,6 +83,7 @@ class RegistryEntry: name: Optional[str] = attr.ib(default=None) icon: Optional[str] = attr.ib(default=None) device_id: Optional[str] = attr.ib(default=None) + area_id: Optional[str] = attr.ib(default=None) config_entry_id: Optional[str] = attr.ib(default=None) disabled_by: Optional[str] = attr.ib( default=None, @@ -204,6 +205,7 @@ class EntityRegistry: # Data that we want entry to have config_entry: Optional["ConfigEntry"] = None, device_id: Optional[str] = None, + area_id: Optional[str] = None, capabilities: Optional[Dict[str, Any]] = None, supported_features: Optional[int] = None, device_class: Optional[str] = None, @@ -223,6 +225,7 @@ class EntityRegistry: entity_id, config_entry_id=config_entry_id or _UNDEF, device_id=device_id or _UNDEF, + area_id=area_id or _UNDEF, capabilities=capabilities or _UNDEF, supported_features=supported_features or _UNDEF, device_class=device_class or _UNDEF, @@ -253,6 +256,7 @@ class EntityRegistry: entity_id=entity_id, config_entry_id=config_entry_id, device_id=device_id, + area_id=area_id, unique_id=unique_id, platform=platform, disabled_by=disabled_by, @@ -302,6 +306,7 @@ class EntityRegistry: *, name=_UNDEF, icon=_UNDEF, + area_id=_UNDEF, new_entity_id=_UNDEF, new_unique_id=_UNDEF, disabled_by=_UNDEF, @@ -313,6 +318,7 @@ class EntityRegistry: entity_id, name=name, icon=icon, + area_id=area_id, new_entity_id=new_entity_id, new_unique_id=new_unique_id, disabled_by=disabled_by, @@ -329,6 +335,7 @@ class EntityRegistry: config_entry_id=_UNDEF, new_entity_id=_UNDEF, device_id=_UNDEF, + area_id=_UNDEF, new_unique_id=_UNDEF, disabled_by=_UNDEF, capabilities=_UNDEF, @@ -348,6 +355,7 @@ class EntityRegistry: ("icon", icon), ("config_entry_id", config_entry_id), ("device_id", device_id), + ("area_id", area_id), ("disabled_by", disabled_by), ("capabilities", capabilities), ("supported_features", supported_features), @@ -425,6 +433,7 @@ class EntityRegistry: entity_id=entity["entity_id"], config_entry_id=entity.get("config_entry_id"), device_id=entity.get("device_id"), + area_id=entity.get("area_id"), unique_id=entity["unique_id"], platform=entity["platform"], name=entity.get("name"), @@ -456,6 +465,7 @@ class EntityRegistry: "entity_id": entry.entity_id, "config_entry_id": entry.config_entry_id, "device_id": entry.device_id, + "area_id": entry.area_id, "unique_id": entry.unique_id, "platform": entry.platform, "name": entry.name, @@ -483,6 +493,13 @@ class EntityRegistry: ]: self.async_remove(entity_id) + @callback + def async_clear_area_id(self, area_id: str) -> None: + """Clear area id from registry entries.""" + for entity_id, entry in self.entities.items(): + if area_id == entry.area_id: + self._async_update_entity(entity_id, area_id=None) # type: ignore + def _register_entry(self, entry: RegistryEntry) -> None: self.entities[entry.entity_id] = entry self._add_index(entry) @@ -521,6 +538,14 @@ def async_entries_for_device( ] +@callback +def async_entries_for_area( + registry: EntityRegistry, area_id: str +) -> List[RegistryEntry]: + """Return entries that match an area.""" + return [entry for entry in registry.entities.values() if entry.area_id == area_id] + + @callback def async_entries_for_config_entry( registry: EntityRegistry, config_entry_id: str diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index d03fb8c91c6..0993f490537 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -234,6 +234,15 @@ async def async_extract_entity_ids( hass.helpers.device_registry.async_get_registry(), hass.helpers.entity_registry.async_get_registry(), ) + + extracted.update( + entry.entity_id + for area_id in area_ids + for entry in hass.helpers.entity_registry.async_entries_for_area( + ent_reg, area_id + ) + ) + devices = [ device for area_id in area_ids @@ -247,6 +256,7 @@ async def async_extract_entity_ids( for entry in hass.helpers.entity_registry.async_entries_for_device( ent_reg, device.id ) + if not entry.area_id ) return extracted diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index d63d10437cc..84a646ed2ef 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -39,6 +39,7 @@ async def test_list_entities(hass, client): { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "entity_id": "test_domain.name", "name": "Hello World", @@ -48,6 +49,7 @@ async def test_list_entities(hass, client): { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "entity_id": "test_domain.no_name", "name": None, @@ -84,6 +86,7 @@ async def test_get_entity(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.name", @@ -107,6 +110,7 @@ async def test_get_entity(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.no_name", @@ -143,7 +147,7 @@ async def test_update_entity(hass, client): assert state.name == "before update" assert state.attributes[ATTR_ICON] == "icon:before update" - # UPDATE NAME & ICON + # UPDATE NAME & ICON & AREA await client.send_json( { "id": 6, @@ -151,6 +155,7 @@ async def test_update_entity(hass, client): "entity_id": "test_domain.world", "name": "after update", "icon": "icon:after update", + "area_id": "mock-area-id", } ) @@ -159,6 +164,7 @@ async def test_update_entity(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": "mock-area-id", "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.world", @@ -204,6 +210,7 @@ async def test_update_entity(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": "mock-area-id", "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.world", @@ -252,6 +259,7 @@ async def test_update_entity_no_changes(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.world", @@ -329,6 +337,7 @@ async def test_update_entity_id(hass, client): assert msg["result"] == { "config_entry_id": None, "device_id": None, + "area_id": None, "disabled_by": None, "platform": "test_platform", "entity_id": "test_domain.planet", diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 8f5f8fc501b..336329396cc 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -154,6 +154,7 @@ async def test_loading_saving_data(hass, registry): "hue", "5678", device_id="mock-dev-id", + area_id="mock-area-id", config_entry=mock_config, capabilities={"max": 100}, supported_features=5, @@ -182,6 +183,7 @@ async def test_loading_saving_data(hass, registry): assert orig_entry2 == new_entry2 assert new_entry2.device_id == "mock-dev-id" + assert new_entry2.area_id == "mock-area-id" assert new_entry2.disabled_by == entity_registry.DISABLED_HASS assert new_entry2.capabilities == {"max": 100} assert new_entry2.supported_features == 5 @@ -330,6 +332,19 @@ async def test_removing_config_entry_id(hass, registry, update_events): assert update_events[1]["entity_id"] == entry.entity_id +async def test_removing_area_id(registry): + """Make sure we can clear area id.""" + entry = registry.async_get_or_create("light", "hue", "5678") + + entry_w_area = registry.async_update_entity(entry.entity_id, area_id="12345A") + + registry.async_clear_area_id("12345A") + entry_wo_area = registry.async_get(entry.entity_id) + + assert not entry_wo_area.area_id + assert entry_w_area != entry_wo_area + + async def test_migration(hass): """Test migration from old data to new.""" mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id") diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 929df2a32e0..6f2cd4ba130 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -105,12 +105,32 @@ def area_mock(hass): }, ) + entity_in_own_area = ent_reg.RegistryEntry( + entity_id="light.in_own_area", + unique_id="in-own-area-id", + platform="test", + area_id="own-area", + ) entity_in_area = ent_reg.RegistryEntry( entity_id="light.in_area", unique_id="in-area-id", platform="test", device_id=device_in_area.id, ) + entity_in_other_area = ent_reg.RegistryEntry( + entity_id="light.in_other_area", + unique_id="in-other-area-id", + platform="test", + device_id=device_in_area.id, + area_id="other-area", + ) + entity_assigned_to_area = ent_reg.RegistryEntry( + entity_id="light.assigned_to_area", + unique_id="assigned-area-id", + platform="test", + device_id=device_in_area.id, + area_id="test-area", + ) entity_no_area = ent_reg.RegistryEntry( entity_id="light.no_area", unique_id="no-area-id", @@ -126,7 +146,10 @@ def area_mock(hass): mock_registry( hass, { + entity_in_own_area.entity_id: entity_in_own_area, entity_in_area.entity_id: entity_in_area, + entity_in_other_area.entity_id: entity_in_other_area, + entity_assigned_to_area.entity_id: entity_assigned_to_area, entity_no_area.entity_id: entity_no_area, entity_diff_area.entity_id: entity_diff_area, }, @@ -298,15 +321,25 @@ async def test_extract_entity_ids(hass): async def test_extract_entity_ids_from_area(hass, area_mock): """Test extract_entity_ids method with areas.""" + call = ha.ServiceCall("light", "turn_on", {"area_id": "own-area"}) + + assert { + "light.in_own_area", + } == await service.async_extract_entity_ids(hass, call) + call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"}) - assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call) + assert { + "light.in_area", + "light.assigned_to_area", + } == await service.async_extract_entity_ids(hass, call) call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]}) assert { "light.in_area", "light.diff_area", + "light.assigned_to_area", } == await service.async_extract_entity_ids(hass, call) assert (