Allow specifying device_id as target (#43767)

pull/43772/head
Paulus Schoutsen 2020-11-30 14:27:02 +01:00 committed by GitHub
parent e307e1315a
commit 0de9e8e952
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 31 deletions

View File

@ -34,6 +34,7 @@ import voluptuous_serialize
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_ABOVE,
CONF_ALIAS,
@ -884,6 +885,9 @@ PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
ENTITY_SERVICE_FIELDS = {
vol.Optional(ATTR_ENTITY_ID): comp_entity_ids,
vol.Optional(ATTR_DEVICE_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [str])
),
vol.Optional(ATTR_AREA_ID): vol.Any(ENTITY_MATCH_NONE, vol.All(ensure_list, [str])),
}

View File

@ -14,6 +14,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import voluptuous as vol
@ -21,6 +22,7 @@ import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_SERVICE,
CONF_SERVICE_TEMPLATE,
@ -35,9 +37,8 @@ from homeassistant.exceptions import (
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import template
from homeassistant.helpers import device_registry, entity_registry, template
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY,
@ -120,7 +121,7 @@ def async_prepare_call_from_config(
else:
domain_service = config[CONF_SERVICE_TEMPLATE]
if isinstance(domain_service, Template):
if isinstance(domain_service, template.Template):
try:
domain_service.hass = hass
domain_service = domain_service.async_render(variables)
@ -217,17 +218,19 @@ async def async_extract_entity_ids(
Will convert group entity ids to the entity ids it represents.
"""
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
device_ids = service_call.data.get(ATTR_DEVICE_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE)
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
extracted: Set[str] = set()
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
None,
ENTITY_MATCH_NONE,
):
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
return extracted
if entity_ids and entity_ids != ENTITY_MATCH_NONE:
if selects_entity_ids:
# Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
@ -237,39 +240,55 @@ async def async_extract_entity_ids(
extracted.update(entity_ids)
if area_ids and area_ids != ENTITY_MATCH_NONE:
if not selects_device_ids and not selects_area_ids:
return extracted
dev_reg, ent_reg = cast(
Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry],
await asyncio.gather(
device_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass),
),
)
if not selects_device_ids:
picked_devices = set()
elif isinstance(device_ids, str):
picked_devices = {device_ids}
else:
assert isinstance(device_ids, list)
picked_devices = set(device_ids)
if selects_area_ids:
if isinstance(area_ids, str):
area_ids = [area_ids]
dev_reg, ent_reg = await asyncio.gather(
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
)
assert isinstance(area_ids, list)
# Find entities tied to an area
extracted.update(
entry.entity_id
for area_id in area_ids
for entry in hass.helpers.entity_registry.async_entries_for_area(
ent_reg, area_id
)
for entry in entity_registry.async_entries_for_area(ent_reg, area_id)
)
devices = [
device
for area_id in area_ids
for device in hass.helpers.device_registry.async_entries_for_area(
dev_reg, area_id
)
]
extracted.update(
entry.entity_id
for device in devices
for entry in hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id, include_disabled_entities=True
)
if not entry.area_id
picked_devices.update(
[
device.id
for area_id in area_ids
for device in device_registry.async_entries_for_area(dev_reg, area_id)
]
)
if not picked_devices:
return extracted
extracted.update(
entity_entry.entity_id
for entity_entry in ent_reg.entities.values()
if not entity_entry.area_id and entity_entry.device_id in picked_devices
)
return extracted

View File

@ -93,7 +93,7 @@ def area_mock(hass):
hass.states.async_set("light.Kitchen", STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
device_no_area = dev_reg.DeviceEntry()
device_no_area = dev_reg.DeviceEntry(id="device-no-area-id")
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
mock_device_registry(
@ -947,3 +947,16 @@ async def test_extract_from_service_area_id(hass, area_mock):
"light.diff_area",
"light.in_area",
]
call = ha.ServiceCall(
"light",
"turn_on",
{"area_id": ["test-area", "diff-area"], "device_id": "device-no-area-id"},
)
extracted = await service.async_extract_entities(hass, entities, call)
assert len(extracted) == 3
assert sorted(ent.entity_id for ent in extracted) == [
"light.diff_area",
"light.in_area",
"light.no_area",
]