Allow specifying device_id as target (#43767)
parent
e307e1315a
commit
0de9e8e952
|
@ -34,6 +34,7 @@ import voluptuous_serialize
|
|||
|
||||
from homeassistant.const import (
|
||||
ATTR_AREA_ID,
|
||||
ATTR_DEVICE_ID,
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_ABOVE,
|
||||
CONF_ALIAS,
|
||||
|
@ -884,6 +885,9 @@ PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
|||
|
||||
ENTITY_SERVICE_FIELDS = {
|
||||
vol.Optional(ATTR_ENTITY_ID): comp_entity_ids,
|
||||
vol.Optional(ATTR_DEVICE_ID): vol.Any(
|
||||
ENTITY_MATCH_NONE, vol.All(ensure_list, [str])
|
||||
),
|
||||
vol.Optional(ATTR_AREA_ID): vol.Any(ENTITY_MATCH_NONE, vol.All(ensure_list, [str])),
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from typing import (
|
|||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -21,6 +22,7 @@ import voluptuous as vol
|
|||
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
|
||||
from homeassistant.const import (
|
||||
ATTR_AREA_ID,
|
||||
ATTR_DEVICE_ID,
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_SERVICE,
|
||||
CONF_SERVICE_TEMPLATE,
|
||||
|
@ -35,9 +37,8 @@ from homeassistant.exceptions import (
|
|||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers import device_registry, entity_registry, template
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
|
||||
from homeassistant.loader import (
|
||||
MAX_LOAD_CONCURRENTLY,
|
||||
|
@ -120,7 +121,7 @@ def async_prepare_call_from_config(
|
|||
else:
|
||||
domain_service = config[CONF_SERVICE_TEMPLATE]
|
||||
|
||||
if isinstance(domain_service, Template):
|
||||
if isinstance(domain_service, template.Template):
|
||||
try:
|
||||
domain_service.hass = hass
|
||||
domain_service = domain_service.async_render(variables)
|
||||
|
@ -217,17 +218,19 @@ async def async_extract_entity_ids(
|
|||
Will convert group entity ids to the entity ids it represents.
|
||||
"""
|
||||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
||||
device_ids = service_call.data.get(ATTR_DEVICE_ID)
|
||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
||||
|
||||
selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE)
|
||||
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
|
||||
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
|
||||
|
||||
extracted: Set[str] = set()
|
||||
|
||||
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
|
||||
None,
|
||||
ENTITY_MATCH_NONE,
|
||||
):
|
||||
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
|
||||
return extracted
|
||||
|
||||
if entity_ids and entity_ids != ENTITY_MATCH_NONE:
|
||||
if selects_entity_ids:
|
||||
# Entity ID attr can be a list or a string
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
@ -237,39 +240,55 @@ async def async_extract_entity_ids(
|
|||
|
||||
extracted.update(entity_ids)
|
||||
|
||||
if area_ids and area_ids != ENTITY_MATCH_NONE:
|
||||
if not selects_device_ids and not selects_area_ids:
|
||||
return extracted
|
||||
|
||||
dev_reg, ent_reg = cast(
|
||||
Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry],
|
||||
await asyncio.gather(
|
||||
device_registry.async_get_registry(hass),
|
||||
entity_registry.async_get_registry(hass),
|
||||
),
|
||||
)
|
||||
|
||||
if not selects_device_ids:
|
||||
picked_devices = set()
|
||||
elif isinstance(device_ids, str):
|
||||
picked_devices = {device_ids}
|
||||
else:
|
||||
assert isinstance(device_ids, list)
|
||||
picked_devices = set(device_ids)
|
||||
|
||||
if selects_area_ids:
|
||||
if isinstance(area_ids, str):
|
||||
area_ids = [area_ids]
|
||||
|
||||
dev_reg, ent_reg = await asyncio.gather(
|
||||
hass.helpers.device_registry.async_get_registry(),
|
||||
hass.helpers.entity_registry.async_get_registry(),
|
||||
)
|
||||
assert isinstance(area_ids, list)
|
||||
|
||||
# Find entities tied to an area
|
||||
extracted.update(
|
||||
entry.entity_id
|
||||
for area_id in area_ids
|
||||
for entry in hass.helpers.entity_registry.async_entries_for_area(
|
||||
ent_reg, area_id
|
||||
)
|
||||
for entry in entity_registry.async_entries_for_area(ent_reg, area_id)
|
||||
)
|
||||
|
||||
devices = [
|
||||
device
|
||||
for area_id in area_ids
|
||||
for device in hass.helpers.device_registry.async_entries_for_area(
|
||||
dev_reg, area_id
|
||||
)
|
||||
]
|
||||
extracted.update(
|
||||
entry.entity_id
|
||||
for device in devices
|
||||
for entry in hass.helpers.entity_registry.async_entries_for_device(
|
||||
ent_reg, device.id, include_disabled_entities=True
|
||||
)
|
||||
if not entry.area_id
|
||||
picked_devices.update(
|
||||
[
|
||||
device.id
|
||||
for area_id in area_ids
|
||||
for device in device_registry.async_entries_for_area(dev_reg, area_id)
|
||||
]
|
||||
)
|
||||
|
||||
if not picked_devices:
|
||||
return extracted
|
||||
|
||||
extracted.update(
|
||||
entity_entry.entity_id
|
||||
for entity_entry in ent_reg.entities.values()
|
||||
if not entity_entry.area_id and entity_entry.device_id in picked_devices
|
||||
)
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ def area_mock(hass):
|
|||
hass.states.async_set("light.Kitchen", STATE_OFF)
|
||||
|
||||
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
|
||||
device_no_area = dev_reg.DeviceEntry()
|
||||
device_no_area = dev_reg.DeviceEntry(id="device-no-area-id")
|
||||
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
|
||||
|
||||
mock_device_registry(
|
||||
|
@ -947,3 +947,16 @@ async def test_extract_from_service_area_id(hass, area_mock):
|
|||
"light.diff_area",
|
||||
"light.in_area",
|
||||
]
|
||||
|
||||
call = ha.ServiceCall(
|
||||
"light",
|
||||
"turn_on",
|
||||
{"area_id": ["test-area", "diff-area"], "device_id": "device-no-area-id"},
|
||||
)
|
||||
extracted = await service.async_extract_entities(hass, entities, call)
|
||||
assert len(extracted) == 3
|
||||
assert sorted(ent.entity_id for ent in extracted) == [
|
||||
"light.diff_area",
|
||||
"light.in_area",
|
||||
"light.no_area",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue