diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 37d7f312d94..dea8deec715 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -34,6 +34,7 @@ import voluptuous_serialize from homeassistant.const import ( ATTR_AREA_ID, + ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_ABOVE, CONF_ALIAS, @@ -884,6 +885,9 @@ PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) ENTITY_SERVICE_FIELDS = { vol.Optional(ATTR_ENTITY_ID): comp_entity_ids, + vol.Optional(ATTR_DEVICE_ID): vol.Any( + ENTITY_MATCH_NONE, vol.All(ensure_list, [str]) + ), vol.Optional(ATTR_AREA_ID): vol.Any(ENTITY_MATCH_NONE, vol.All(ensure_list, [str])), } diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 7aa4ac8b013..47918f31514 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -14,6 +14,7 @@ from typing import ( Set, Tuple, Union, + cast, ) import voluptuous as vol @@ -21,6 +22,7 @@ import voluptuous as vol from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.const import ( ATTR_AREA_ID, + ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_SERVICE, CONF_SERVICE_TEMPLATE, @@ -35,9 +37,8 @@ from homeassistant.exceptions import ( Unauthorized, UnknownUser, ) -from homeassistant.helpers import template +from homeassistant.helpers import device_registry, entity_registry, template import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType from homeassistant.loader import ( MAX_LOAD_CONCURRENTLY, @@ -120,7 +121,7 @@ def async_prepare_call_from_config( else: domain_service = config[CONF_SERVICE_TEMPLATE] - if isinstance(domain_service, Template): + if isinstance(domain_service, template.Template): try: domain_service.hass = hass domain_service = domain_service.async_render(variables) @@ -217,17 +218,19 @@ async def async_extract_entity_ids( Will convert group entity ids to the entity ids it represents. """ entity_ids = service_call.data.get(ATTR_ENTITY_ID) + device_ids = service_call.data.get(ATTR_DEVICE_ID) area_ids = service_call.data.get(ATTR_AREA_ID) + selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE) + selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE) + selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE) + extracted: Set[str] = set() - if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in ( - None, - ENTITY_MATCH_NONE, - ): + if not selects_entity_ids and not selects_device_ids and not selects_area_ids: return extracted - if entity_ids and entity_ids != ENTITY_MATCH_NONE: + if selects_entity_ids: # Entity ID attr can be a list or a string if isinstance(entity_ids, str): entity_ids = [entity_ids] @@ -237,39 +240,55 @@ async def async_extract_entity_ids( extracted.update(entity_ids) - if area_ids and area_ids != ENTITY_MATCH_NONE: + if not selects_device_ids and not selects_area_ids: + return extracted + + dev_reg, ent_reg = cast( + Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry], + await asyncio.gather( + device_registry.async_get_registry(hass), + entity_registry.async_get_registry(hass), + ), + ) + + if not selects_device_ids: + picked_devices = set() + elif isinstance(device_ids, str): + picked_devices = {device_ids} + else: + assert isinstance(device_ids, list) + picked_devices = set(device_ids) + + if selects_area_ids: if isinstance(area_ids, str): area_ids = [area_ids] - dev_reg, ent_reg = await asyncio.gather( - hass.helpers.device_registry.async_get_registry(), - hass.helpers.entity_registry.async_get_registry(), - ) + assert isinstance(area_ids, list) + # Find entities tied to an area extracted.update( entry.entity_id for area_id in area_ids - for entry in hass.helpers.entity_registry.async_entries_for_area( - ent_reg, area_id - ) + for entry in entity_registry.async_entries_for_area(ent_reg, area_id) ) - devices = [ - device - for area_id in area_ids - for device in hass.helpers.device_registry.async_entries_for_area( - dev_reg, area_id - ) - ] - extracted.update( - entry.entity_id - for device in devices - for entry in hass.helpers.entity_registry.async_entries_for_device( - ent_reg, device.id, include_disabled_entities=True - ) - if not entry.area_id + picked_devices.update( + [ + device.id + for area_id in area_ids + for device in device_registry.async_entries_for_area(dev_reg, area_id) + ] ) + if not picked_devices: + return extracted + + extracted.update( + entity_entry.entity_id + for entity_entry in ent_reg.entities.values() + if not entity_entry.area_id and entity_entry.device_id in picked_devices + ) + return extracted diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 93125fba96d..f9b09b259ca 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -93,7 +93,7 @@ def area_mock(hass): hass.states.async_set("light.Kitchen", STATE_OFF) device_in_area = dev_reg.DeviceEntry(area_id="test-area") - device_no_area = dev_reg.DeviceEntry() + device_no_area = dev_reg.DeviceEntry(id="device-no-area-id") device_diff_area = dev_reg.DeviceEntry(area_id="diff-area") mock_device_registry( @@ -947,3 +947,16 @@ async def test_extract_from_service_area_id(hass, area_mock): "light.diff_area", "light.in_area", ] + + call = ha.ServiceCall( + "light", + "turn_on", + {"area_id": ["test-area", "diff-area"], "device_id": "device-no-area-id"}, + ) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 3 + assert sorted(ent.entity_id for ent in extracted) == [ + "light.diff_area", + "light.in_area", + "light.no_area", + ]