Allow targeting areas in service calls (#21472)
* Allow targeting areas in service calls * Lint + Type * Address commentspull/21655/head
parent
f62eb22ef8
commit
8213016eaf
|
@ -17,7 +17,7 @@ import voluptuous as vol
|
|||
import homeassistant.core as ha
|
||||
import homeassistant.config as conf_util
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.service import extract_entity_ids
|
||||
from homeassistant.helpers.service import async_extract_entity_ids
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE,
|
||||
|
@ -70,7 +70,7 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> Awaitable[bool]:
|
|||
"""Set up general services related to Home Assistant."""
|
||||
async def async_handle_turn_service(service):
|
||||
"""Handle calls to homeassistant.turn_on/off."""
|
||||
entity_ids = extract_entity_ids(hass, service)
|
||||
entity_ids = await async_extract_entity_ids(hass, service)
|
||||
|
||||
# Generic turn on/off method requires entity id
|
||||
if not entity_ids:
|
||||
|
|
|
@ -89,7 +89,7 @@ async def async_setup(hass, config):
|
|||
|
||||
async def async_handle_alert_service(service_call):
|
||||
"""Handle calls to alert services."""
|
||||
alert_ids = service.extract_entity_ids(hass, service_call)
|
||||
alert_ids = await service.async_extract_entity_ids(hass, service_call)
|
||||
|
||||
for alert_id in alert_ids:
|
||||
for alert in entities:
|
||||
|
|
|
@ -120,7 +120,7 @@ async def async_setup(hass, config):
|
|||
async def trigger_service_handler(service_call):
|
||||
"""Handle automation triggers."""
|
||||
tasks = []
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
for entity in await component.async_extract_from_service(service_call):
|
||||
tasks.append(entity.async_trigger(
|
||||
service_call.data.get(ATTR_VARIABLES),
|
||||
skip_condition=True,
|
||||
|
@ -133,7 +133,7 @@ async def async_setup(hass, config):
|
|||
"""Handle automation turn on/off service calls."""
|
||||
tasks = []
|
||||
method = 'async_{}'.format(service_call.service)
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
for entity in await component.async_extract_from_service(service_call):
|
||||
tasks.append(getattr(entity, method)())
|
||||
|
||||
if tasks:
|
||||
|
@ -142,7 +142,7 @@ async def async_setup(hass, config):
|
|||
async def toggle_service_handler(service_call):
|
||||
"""Handle automation toggle service calls."""
|
||||
tasks = []
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
for entity in await component.async_extract_from_service(service_call):
|
||||
if entity.is_on:
|
||||
tasks.append(entity.async_turn_off())
|
||||
else:
|
||||
|
|
|
@ -300,8 +300,8 @@ async def async_setup(hass, config):
|
|||
visible = service.data.get(ATTR_VISIBLE)
|
||||
|
||||
tasks = []
|
||||
for group in component.async_extract_from_service(service,
|
||||
expand_group=False):
|
||||
for group in await component.async_extract_from_service(
|
||||
service, expand_group=False):
|
||||
group.visible = visible
|
||||
tasks.append(group.async_update_ha_state())
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ async def async_setup(hass, config):
|
|||
|
||||
async def async_scan_service(service):
|
||||
"""Service handler for scan."""
|
||||
image_entities = component.async_extract_from_service(service)
|
||||
image_entities = await component.async_extract_from_service(service)
|
||||
|
||||
update_tasks = []
|
||||
for entity in image_entities:
|
||||
|
|
|
@ -256,7 +256,7 @@ async def async_setup(hass, config):
|
|||
params = service.data.copy()
|
||||
|
||||
# Convert the entity ids to valid light ids
|
||||
target_lights = component.async_extract_from_service(service)
|
||||
target_lights = await component.async_extract_from_service(service)
|
||||
params.pop(ATTR_ENTITY_ID, None)
|
||||
|
||||
if service.context.user_id:
|
||||
|
|
|
@ -68,7 +68,7 @@ async def async_setup(hass, config):
|
|||
|
||||
async def async_handle_scene_service(service):
|
||||
"""Handle calls to the switch services."""
|
||||
target_scenes = component.async_extract_from_service(service)
|
||||
target_scenes = await component.async_extract_from_service(service)
|
||||
|
||||
tasks = [scene.async_activate() for scene in target_scenes]
|
||||
if tasks:
|
||||
|
|
|
@ -74,20 +74,21 @@ async def async_setup(hass, config):
|
|||
# We could turn on script directly here, but we only want to offer
|
||||
# one way to do it. Otherwise no easy way to detect invocations.
|
||||
var = service.data.get(ATTR_VARIABLES)
|
||||
for script in component.async_extract_from_service(service):
|
||||
for script in await component.async_extract_from_service(service):
|
||||
await hass.services.async_call(DOMAIN, script.object_id, var,
|
||||
context=service.context)
|
||||
|
||||
async def turn_off_service(service):
|
||||
"""Cancel a script."""
|
||||
# Stopping a script is ok to be done in parallel
|
||||
await asyncio.wait(
|
||||
[script.async_turn_off() for script
|
||||
in component.async_extract_from_service(service)], loop=hass.loop)
|
||||
await asyncio.wait([
|
||||
script.async_turn_off() for script
|
||||
in await component.async_extract_from_service(service)
|
||||
], loop=hass.loop)
|
||||
|
||||
async def toggle_service(service):
|
||||
"""Toggle a script."""
|
||||
for script in component.async_extract_from_service(service):
|
||||
for script in await component.async_extract_from_service(service):
|
||||
await script.async_toggle(context=service.context)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
|
||||
|
|
|
@ -245,6 +245,9 @@ ATTR_NAME = 'name'
|
|||
# Contains one string or a list of strings, each being an entity id
|
||||
ATTR_ENTITY_ID = 'entity_id'
|
||||
|
||||
# Contains one string or a list of strings, each being an area id
|
||||
ATTR_AREA_ID = 'area_id'
|
||||
|
||||
# String with a friendly name for the entity
|
||||
ATTR_FRIENDLY_NAME = 'friendly_name'
|
||||
|
||||
|
|
|
@ -37,6 +37,11 @@ class AreaRegistry:
|
|||
self.areas = {} # type: MutableMapping[str, AreaEntry]
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_get_area(self, area_id: str) -> Optional[AreaEntry]:
|
||||
"""Get all areas."""
|
||||
return self.areas.get(area_id)
|
||||
|
||||
@callback
|
||||
def async_list_areas(self) -> Iterable[AreaEntry]:
|
||||
"""Get all areas."""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Provide a way to connect entities belonging to one device."""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
@ -280,3 +281,11 @@ async def async_get_registry(hass) -> DeviceRegistry:
|
|||
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
|
||||
|
||||
return await task
|
||||
|
||||
|
||||
@callback
|
||||
def async_entries_for_area(registry: DeviceRegistry, area_id: str) \
|
||||
-> List[DeviceEntry]:
|
||||
"""Return entries that match an area."""
|
||||
return [device for device in registry.devices.values()
|
||||
if device.area_id == area_id]
|
||||
|
|
|
@ -12,7 +12,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.service import extract_entity_ids
|
||||
from homeassistant.helpers.service import async_extract_entity_ids
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import slugify
|
||||
from .entity_platform import EntityPlatform
|
||||
|
@ -153,8 +153,7 @@ class EntityComponent:
|
|||
await platform.async_reset()
|
||||
return True
|
||||
|
||||
@callback
|
||||
def async_extract_from_service(self, service, expand_group=True):
|
||||
async def async_extract_from_service(self, service, expand_group=True):
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
||||
Will return all entities if no entities specified in call.
|
||||
|
@ -174,7 +173,8 @@ class EntityComponent:
|
|||
|
||||
return [entity for entity in self.entities if entity.available]
|
||||
|
||||
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
|
||||
entity_ids = await async_extract_entity_ids(
|
||||
self.hass, service, expand_group)
|
||||
return [entity for entity in self.entities
|
||||
if entity.available and entity.entity_id in entity_ids]
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ timer.
|
|||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
@ -292,6 +292,14 @@ async def async_get_registry(hass) -> EntityRegistry:
|
|||
return await task
|
||||
|
||||
|
||||
@callback
|
||||
def async_entries_for_device(registry: EntityRegistry, device_id: str) \
|
||||
-> List[RegistryEntry]:
|
||||
"""Return entries that match a device."""
|
||||
return [entry for entry in registry.entities.values()
|
||||
if entry.device_id == device_id]
|
||||
|
||||
|
||||
async def _async_migrate(entities):
|
||||
"""Migrate the YAML config file to storage helper format."""
|
||||
return {
|
||||
|
|
|
@ -6,7 +6,8 @@ from os import path
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.permissions.const import POLICY_CONTROL
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
|
||||
from homeassistant.helpers import template
|
||||
|
@ -89,30 +90,64 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
|
|||
def extract_entity_ids(hass, service_call, expand_group=True):
|
||||
"""Extract a list of entity ids from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
"""
|
||||
return run_coroutine_threadsafe(
|
||||
async_extract_entity_ids(hass, service_call, expand_group), hass.loop
|
||||
).result()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_entity_ids(hass, service_call, expand_group=True):
|
||||
"""Extract a list of entity ids from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
if not (service_call.data and ATTR_ENTITY_ID in service_call.data):
|
||||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
||||
|
||||
if not entity_ids and not area_ids:
|
||||
return []
|
||||
|
||||
group = hass.components.group
|
||||
extracted = set()
|
||||
|
||||
# Entity ID attr can be a list or a string
|
||||
service_ent_id = service_call.data[ATTR_ENTITY_ID]
|
||||
if entity_ids:
|
||||
# Entity ID attr can be a list or a string
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
if expand_group:
|
||||
if expand_group:
|
||||
entity_ids = \
|
||||
hass.components.group.expand_entity_ids(entity_ids)
|
||||
|
||||
if isinstance(service_ent_id, str):
|
||||
return group.expand_entity_ids([service_ent_id])
|
||||
extracted.update(entity_ids)
|
||||
|
||||
return [ent_id for ent_id in
|
||||
group.expand_entity_ids(service_ent_id)]
|
||||
if area_ids:
|
||||
if isinstance(area_ids, str):
|
||||
area_ids = [area_ids]
|
||||
|
||||
if isinstance(service_ent_id, str):
|
||||
return [service_ent_id]
|
||||
dev_reg, ent_reg = await asyncio.gather(
|
||||
hass.helpers.device_registry.async_get_registry(),
|
||||
hass.helpers.entity_registry.async_get_registry(),
|
||||
)
|
||||
devices = [
|
||||
device
|
||||
for area_id in area_ids
|
||||
for device in
|
||||
hass.helpers.device_registry.async_entries_for_area(
|
||||
dev_reg, area_id)
|
||||
]
|
||||
extracted.update(
|
||||
entry.entity_id
|
||||
for device in devices
|
||||
for entry in
|
||||
hass.helpers.entity_registry.async_entries_for_device(
|
||||
ent_reg, device.id)
|
||||
)
|
||||
|
||||
return service_ent_id
|
||||
return extracted
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -213,8 +248,7 @@ async def entity_service_call(hass, platforms, func, call, service_name=''):
|
|||
|
||||
if not target_all_entities:
|
||||
# A set of entities we're trying to target.
|
||||
entity_ids = set(
|
||||
extract_entity_ids(hass, call, True))
|
||||
entity_ids = await async_extract_entity_ids(hass, call, True)
|
||||
|
||||
# If the service function is a string, we'll pass it the service call data
|
||||
if isinstance(func, str):
|
||||
|
|
|
@ -206,7 +206,7 @@ def test_extract_from_service_available_device(hass):
|
|||
|
||||
assert ['test_domain.test_1', 'test_domain.test_3'] == \
|
||||
sorted(ent.entity_id for ent in
|
||||
component.async_extract_from_service(call_1))
|
||||
(yield from component.async_extract_from_service(call_1)))
|
||||
|
||||
call_2 = ha.ServiceCall('test', 'service', data={
|
||||
'entity_id': ['test_domain.test_3', 'test_domain.test_4'],
|
||||
|
@ -214,7 +214,7 @@ def test_extract_from_service_available_device(hass):
|
|||
|
||||
assert ['test_domain.test_3'] == \
|
||||
sorted(ent.entity_id for ent in
|
||||
component.async_extract_from_service(call_2))
|
||||
(yield from component.async_extract_from_service(call_2)))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -275,7 +275,7 @@ def test_extract_from_service_returns_all_if_no_entity_id(hass):
|
|||
|
||||
assert ['test_domain.test_1', 'test_domain.test_2'] == \
|
||||
sorted(ent.entity_id for ent in
|
||||
component.async_extract_from_service(call))
|
||||
(yield from component.async_extract_from_service(call)))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -293,7 +293,7 @@ def test_extract_from_service_filter_out_non_existing_entities(hass):
|
|||
|
||||
assert ['test_domain.test_2'] == \
|
||||
[ent.entity_id for ent
|
||||
in component.async_extract_from_service(call)]
|
||||
in (yield from component.async_extract_from_service(call))]
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -308,7 +308,8 @@ def test_extract_from_service_no_group_expand(hass):
|
|||
'entity_id': ['group.test_group']
|
||||
})
|
||||
|
||||
extracted = component.async_extract_from_service(call, expand_group=False)
|
||||
extracted = yield from component.async_extract_from_service(
|
||||
call, expand_group=False)
|
||||
assert extracted == [test_group]
|
||||
|
||||
|
||||
|
@ -466,7 +467,7 @@ async def test_extract_all_omit_entity_id(hass, caplog):
|
|||
|
||||
assert ['test_domain.test_1', 'test_domain.test_2'] == \
|
||||
sorted(ent.entity_id for ent in
|
||||
component.async_extract_from_service(call))
|
||||
await component.async_extract_from_service(call))
|
||||
assert ('Not passing an entity ID to a service to target all entities is '
|
||||
'deprecated') in caplog.text
|
||||
|
||||
|
@ -483,6 +484,6 @@ async def test_extract_all_use_match_all(hass, caplog):
|
|||
|
||||
assert ['test_domain.test_1', 'test_domain.test_2'] == \
|
||||
sorted(ent.entity_id for ent in
|
||||
component.async_extract_from_service(call))
|
||||
await component.async_extract_from_service(call))
|
||||
assert ('Not passing an entity ID to a service to target all entities is '
|
||||
'deprecated') not in caplog.text
|
||||
|
|
|
@ -15,8 +15,11 @@ from homeassistant.helpers import service, template
|
|||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.auth.permissions import PolicyPermissions
|
||||
|
||||
from tests.common import get_test_home_assistant, mock_service, mock_coro
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dev_reg, entity_registry as ent_reg)
|
||||
from tests.common import (
|
||||
get_test_home_assistant, mock_service, mock_coro, mock_registry,
|
||||
mock_device_registry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -163,29 +166,83 @@ class TestServiceHelpers(unittest.TestCase):
|
|||
})
|
||||
assert 3 == mock_log.call_count
|
||||
|
||||
def test_extract_entity_ids(self):
|
||||
"""Test extract_entity_ids method."""
|
||||
self.hass.states.set('light.Bowl', STATE_ON)
|
||||
self.hass.states.set('light.Ceiling', STATE_OFF)
|
||||
self.hass.states.set('light.Kitchen', STATE_OFF)
|
||||
|
||||
loader.get_component(self.hass, 'group').Group.create_group(
|
||||
self.hass, 'test', ['light.Ceiling', 'light.Kitchen'])
|
||||
async def test_extract_entity_ids(hass):
|
||||
"""Test extract_entity_ids method."""
|
||||
hass.states.async_set('light.Bowl', STATE_ON)
|
||||
hass.states.async_set('light.Ceiling', STATE_OFF)
|
||||
hass.states.async_set('light.Kitchen', STATE_OFF)
|
||||
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{ATTR_ENTITY_ID: 'light.Bowl'})
|
||||
await loader.get_component(hass, 'group').Group.async_create_group(
|
||||
hass, 'test', ['light.Ceiling', 'light.Kitchen'])
|
||||
|
||||
assert ['light.bowl'] == \
|
||||
service.extract_entity_ids(self.hass, call)
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{ATTR_ENTITY_ID: 'light.Bowl'})
|
||||
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{ATTR_ENTITY_ID: 'group.test'})
|
||||
assert {'light.bowl'} == \
|
||||
await service.async_extract_entity_ids(hass, call)
|
||||
|
||||
assert ['light.ceiling', 'light.kitchen'] == \
|
||||
service.extract_entity_ids(self.hass, call)
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{ATTR_ENTITY_ID: 'group.test'})
|
||||
|
||||
assert ['group.test'] == service.extract_entity_ids(
|
||||
self.hass, call, expand_group=False)
|
||||
assert {'light.ceiling', 'light.kitchen'} == \
|
||||
await service.async_extract_entity_ids(hass, call)
|
||||
|
||||
assert {'group.test'} == await service.async_extract_entity_ids(
|
||||
hass, call, expand_group=False)
|
||||
|
||||
|
||||
async def test_extract_entity_ids_from_area(hass):
|
||||
"""Test extract_entity_ids method with areas."""
|
||||
hass.states.async_set('light.Bowl', STATE_ON)
|
||||
hass.states.async_set('light.Ceiling', STATE_OFF)
|
||||
hass.states.async_set('light.Kitchen', STATE_OFF)
|
||||
|
||||
device_in_area = dev_reg.DeviceEntry(area_id='test-area')
|
||||
device_no_area = dev_reg.DeviceEntry()
|
||||
device_diff_area = dev_reg.DeviceEntry(area_id='diff-area')
|
||||
|
||||
mock_device_registry(hass, {
|
||||
device_in_area.id: device_in_area,
|
||||
device_no_area.id: device_no_area,
|
||||
device_diff_area.id: device_diff_area,
|
||||
})
|
||||
|
||||
entity_in_area = ent_reg.RegistryEntry(
|
||||
entity_id='light.in_area',
|
||||
unique_id='in-area-id',
|
||||
platform='test',
|
||||
device_id=device_in_area.id,
|
||||
)
|
||||
entity_no_area = ent_reg.RegistryEntry(
|
||||
entity_id='light.no_area',
|
||||
unique_id='no-area-id',
|
||||
platform='test',
|
||||
device_id=device_no_area.id,
|
||||
)
|
||||
entity_diff_area = ent_reg.RegistryEntry(
|
||||
entity_id='light.diff_area',
|
||||
unique_id='diff-area-id',
|
||||
platform='test',
|
||||
device_id=device_diff_area.id,
|
||||
)
|
||||
mock_registry(hass, {
|
||||
entity_in_area.entity_id: entity_in_area,
|
||||
entity_no_area.entity_id: entity_no_area,
|
||||
entity_diff_area.entity_id: entity_diff_area,
|
||||
})
|
||||
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{'area_id': 'test-area'})
|
||||
|
||||
assert {'light.in_area'} == \
|
||||
await service.async_extract_entity_ids(hass, call)
|
||||
|
||||
call = ha.ServiceCall('light', 'turn_on',
|
||||
{'area_id': ['test-area', 'diff-area']})
|
||||
|
||||
assert {'light.in_area', 'light.diff_area'} == \
|
||||
await service.async_extract_entity_ids(hass, call)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
|
Loading…
Reference in New Issue