Allow targeting areas in service calls (#21472)

* Allow targeting areas in service calls

* Lint + Type

* Address comments
pull/21655/head
Paulus Schoutsen 2019-03-04 09:51:12 -08:00 committed by GitHub
parent f62eb22ef8
commit 8213016eaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 180 additions and 62 deletions

View File

@ -17,7 +17,7 @@ import voluptuous as vol
import homeassistant.core as ha
import homeassistant.config as conf_util
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import extract_entity_ids
from homeassistant.helpers.service import async_extract_entity_ids
from homeassistant.helpers import intent
from homeassistant.const import (
ATTR_ENTITY_ID, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE,
@ -70,7 +70,7 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> Awaitable[bool]:
"""Set up general services related to Home Assistant."""
async def async_handle_turn_service(service):
"""Handle calls to homeassistant.turn_on/off."""
entity_ids = extract_entity_ids(hass, service)
entity_ids = await async_extract_entity_ids(hass, service)
# Generic turn on/off method requires entity id
if not entity_ids:

View File

@ -89,7 +89,7 @@ async def async_setup(hass, config):
async def async_handle_alert_service(service_call):
"""Handle calls to alert services."""
alert_ids = service.extract_entity_ids(hass, service_call)
alert_ids = await service.async_extract_entity_ids(hass, service_call)
for alert_id in alert_ids:
for alert in entities:

View File

@ -120,7 +120,7 @@ async def async_setup(hass, config):
async def trigger_service_handler(service_call):
"""Handle automation triggers."""
tasks = []
for entity in component.async_extract_from_service(service_call):
for entity in await component.async_extract_from_service(service_call):
tasks.append(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES),
skip_condition=True,
@ -133,7 +133,7 @@ async def async_setup(hass, config):
"""Handle automation turn on/off service calls."""
tasks = []
method = 'async_{}'.format(service_call.service)
for entity in component.async_extract_from_service(service_call):
for entity in await component.async_extract_from_service(service_call):
tasks.append(getattr(entity, method)())
if tasks:
@ -142,7 +142,7 @@ async def async_setup(hass, config):
async def toggle_service_handler(service_call):
"""Handle automation toggle service calls."""
tasks = []
for entity in component.async_extract_from_service(service_call):
for entity in await component.async_extract_from_service(service_call):
if entity.is_on:
tasks.append(entity.async_turn_off())
else:

View File

@ -300,8 +300,8 @@ async def async_setup(hass, config):
visible = service.data.get(ATTR_VISIBLE)
tasks = []
for group in component.async_extract_from_service(service,
expand_group=False):
for group in await component.async_extract_from_service(
service, expand_group=False):
group.visible = visible
tasks.append(group.async_update_ha_state())

View File

@ -75,7 +75,7 @@ async def async_setup(hass, config):
async def async_scan_service(service):
"""Service handler for scan."""
image_entities = component.async_extract_from_service(service)
image_entities = await component.async_extract_from_service(service)
update_tasks = []
for entity in image_entities:

View File

@ -256,7 +256,7 @@ async def async_setup(hass, config):
params = service.data.copy()
# Convert the entity ids to valid light ids
target_lights = component.async_extract_from_service(service)
target_lights = await component.async_extract_from_service(service)
params.pop(ATTR_ENTITY_ID, None)
if service.context.user_id:

View File

@ -68,7 +68,7 @@ async def async_setup(hass, config):
async def async_handle_scene_service(service):
"""Handle calls to the switch services."""
target_scenes = component.async_extract_from_service(service)
target_scenes = await component.async_extract_from_service(service)
tasks = [scene.async_activate() for scene in target_scenes]
if tasks:

View File

@ -74,20 +74,21 @@ async def async_setup(hass, config):
# We could turn on script directly here, but we only want to offer
# one way to do it. Otherwise no easy way to detect invocations.
var = service.data.get(ATTR_VARIABLES)
for script in component.async_extract_from_service(service):
for script in await component.async_extract_from_service(service):
await hass.services.async_call(DOMAIN, script.object_id, var,
context=service.context)
async def turn_off_service(service):
"""Cancel a script."""
# Stopping a script is ok to be done in parallel
await asyncio.wait(
[script.async_turn_off() for script
in component.async_extract_from_service(service)], loop=hass.loop)
await asyncio.wait([
script.async_turn_off() for script
in await component.async_extract_from_service(service)
], loop=hass.loop)
async def toggle_service(service):
"""Toggle a script."""
for script in component.async_extract_from_service(service):
for script in await component.async_extract_from_service(service):
await script.async_toggle(context=service.context)
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,

View File

@ -245,6 +245,9 @@ ATTR_NAME = 'name'
# Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id'
# Contains one string or a list of strings, each being an area id
ATTR_AREA_ID = 'area_id'
# String with a friendly name for the entity
ATTR_FRIENDLY_NAME = 'friendly_name'

View File

@ -37,6 +37,11 @@ class AreaRegistry:
self.areas = {} # type: MutableMapping[str, AreaEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_get_area(self, area_id: str) -> Optional[AreaEntry]:
"""Get all areas."""
return self.areas.get(area_id)
@callback
def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas."""

View File

@ -1,6 +1,7 @@
"""Provide a way to connect entities belonging to one device."""
import logging
import uuid
from typing import List
from collections import OrderedDict
@ -280,3 +281,11 @@ async def async_get_registry(hass) -> DeviceRegistry:
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return await task
@callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) \
-> List[DeviceEntry]:
"""Return entries that match an area."""
return [device for device in registry.devices.values()
if device.area_id == area_id]

View File

@ -12,7 +12,7 @@ from homeassistant.const import (
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.service import extract_entity_ids
from homeassistant.helpers.service import async_extract_entity_ids
from homeassistant.loader import bind_hass
from homeassistant.util import slugify
from .entity_platform import EntityPlatform
@ -153,8 +153,7 @@ class EntityComponent:
await platform.async_reset()
return True
@callback
def async_extract_from_service(self, service, expand_group=True):
async def async_extract_from_service(self, service, expand_group=True):
"""Extract all known and available entities from a service call.
Will return all entities if no entities specified in call.
@ -174,7 +173,8 @@ class EntityComponent:
return [entity for entity in self.entities if entity.available]
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
entity_ids = await async_extract_entity_ids(
self.hass, service, expand_group)
return [entity for entity in self.entities
if entity.available and entity.entity_id in entity_ids]

View File

@ -10,7 +10,7 @@ timer.
from collections import OrderedDict
from itertools import chain
import logging
from typing import Optional
from typing import Optional, List
import weakref
import attr
@ -292,6 +292,14 @@ async def async_get_registry(hass) -> EntityRegistry:
return await task
@callback
def async_entries_for_device(registry: EntityRegistry, device_id: str) \
-> List[RegistryEntry]:
"""Return entries that match a device."""
return [entry for entry in registry.entities.values()
if entry.device_id == device_id]
async def _async_migrate(entities):
"""Migrate the YAML config file to storage helper format."""
return {

View File

@ -6,7 +6,8 @@ from os import path
import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL
from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template
@ -89,30 +90,64 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
def extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
"""
return run_coroutine_threadsafe(
async_extract_entity_ids(hass, service_call, expand_group), hass.loop
).result()
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
Async friendly.
"""
if not (service_call.data and ATTR_ENTITY_ID in service_call.data):
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
if not entity_ids and not area_ids:
return []
group = hass.components.group
extracted = set()
# Entity ID attr can be a list or a string
service_ent_id = service_call.data[ATTR_ENTITY_ID]
if entity_ids:
# Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
if expand_group:
if expand_group:
entity_ids = \
hass.components.group.expand_entity_ids(entity_ids)
if isinstance(service_ent_id, str):
return group.expand_entity_ids([service_ent_id])
extracted.update(entity_ids)
return [ent_id for ent_id in
group.expand_entity_ids(service_ent_id)]
if area_ids:
if isinstance(area_ids, str):
area_ids = [area_ids]
if isinstance(service_ent_id, str):
return [service_ent_id]
dev_reg, ent_reg = await asyncio.gather(
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
)
devices = [
device
for area_id in area_ids
for device in
hass.helpers.device_registry.async_entries_for_area(
dev_reg, area_id)
]
extracted.update(
entry.entity_id
for device in devices
for entry in
hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id)
)
return service_ent_id
return extracted
@bind_hass
@ -213,8 +248,7 @@ async def entity_service_call(hass, platforms, func, call, service_name=''):
if not target_all_entities:
# A set of entities we're trying to target.
entity_ids = set(
extract_entity_ids(hass, call, True))
entity_ids = await async_extract_entity_ids(hass, call, True)
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):

View File

@ -206,7 +206,7 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_1', 'test_domain.test_3'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call_1))
(yield from component.async_extract_from_service(call_1)))
call_2 = ha.ServiceCall('test', 'service', data={
'entity_id': ['test_domain.test_3', 'test_domain.test_4'],
@ -214,7 +214,7 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_3'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call_2))
(yield from component.async_extract_from_service(call_2)))
@asyncio.coroutine
@ -275,7 +275,7 @@ def test_extract_from_service_returns_all_if_no_entity_id(hass):
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call))
(yield from component.async_extract_from_service(call)))
@asyncio.coroutine
@ -293,7 +293,7 @@ def test_extract_from_service_filter_out_non_existing_entities(hass):
assert ['test_domain.test_2'] == \
[ent.entity_id for ent
in component.async_extract_from_service(call)]
in (yield from component.async_extract_from_service(call))]
@asyncio.coroutine
@ -308,7 +308,8 @@ def test_extract_from_service_no_group_expand(hass):
'entity_id': ['group.test_group']
})
extracted = component.async_extract_from_service(call, expand_group=False)
extracted = yield from component.async_extract_from_service(
call, expand_group=False)
assert extracted == [test_group]
@ -466,7 +467,7 @@ async def test_extract_all_omit_entity_id(hass, caplog):
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call))
await component.async_extract_from_service(call))
assert ('Not passing an entity ID to a service to target all entities is '
'deprecated') in caplog.text
@ -483,6 +484,6 @@ async def test_extract_all_use_match_all(hass, caplog):
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call))
await component.async_extract_from_service(call))
assert ('Not passing an entity ID to a service to target all entities is '
'deprecated') not in caplog.text

View File

@ -15,8 +15,11 @@ from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv
from homeassistant.auth.permissions import PolicyPermissions
from tests.common import get_test_home_assistant, mock_service, mock_coro
from homeassistant.helpers import (
device_registry as dev_reg, entity_registry as ent_reg)
from tests.common import (
get_test_home_assistant, mock_service, mock_coro, mock_registry,
mock_device_registry)
@pytest.fixture
@ -163,29 +166,83 @@ class TestServiceHelpers(unittest.TestCase):
})
assert 3 == mock_log.call_count
def test_extract_entity_ids(self):
"""Test extract_entity_ids method."""
self.hass.states.set('light.Bowl', STATE_ON)
self.hass.states.set('light.Ceiling', STATE_OFF)
self.hass.states.set('light.Kitchen', STATE_OFF)
loader.get_component(self.hass, 'group').Group.create_group(
self.hass, 'test', ['light.Ceiling', 'light.Kitchen'])
async def test_extract_entity_ids(hass):
"""Test extract_entity_ids method."""
hass.states.async_set('light.Bowl', STATE_ON)
hass.states.async_set('light.Ceiling', STATE_OFF)
hass.states.async_set('light.Kitchen', STATE_OFF)
call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'light.Bowl'})
await loader.get_component(hass, 'group').Group.async_create_group(
hass, 'test', ['light.Ceiling', 'light.Kitchen'])
assert ['light.bowl'] == \
service.extract_entity_ids(self.hass, call)
call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'light.Bowl'})
call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'group.test'})
assert {'light.bowl'} == \
await service.async_extract_entity_ids(hass, call)
assert ['light.ceiling', 'light.kitchen'] == \
service.extract_entity_ids(self.hass, call)
call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'group.test'})
assert ['group.test'] == service.extract_entity_ids(
self.hass, call, expand_group=False)
assert {'light.ceiling', 'light.kitchen'} == \
await service.async_extract_entity_ids(hass, call)
assert {'group.test'} == await service.async_extract_entity_ids(
hass, call, expand_group=False)
async def test_extract_entity_ids_from_area(hass):
"""Test extract_entity_ids method with areas."""
hass.states.async_set('light.Bowl', STATE_ON)
hass.states.async_set('light.Ceiling', STATE_OFF)
hass.states.async_set('light.Kitchen', STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id='test-area')
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id='diff-area')
mock_device_registry(hass, {
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
})
entity_in_area = ent_reg.RegistryEntry(
entity_id='light.in_area',
unique_id='in-area-id',
platform='test',
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id='light.no_area',
unique_id='no-area-id',
platform='test',
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id='light.diff_area',
unique_id='diff-area-id',
platform='test',
device_id=device_diff_area.id,
)
mock_registry(hass, {
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
})
call = ha.ServiceCall('light', 'turn_on',
{'area_id': 'test-area'})
assert {'light.in_area'} == \
await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall('light', 'turn_on',
{'area_id': ['test-area', 'diff-area']})
assert {'light.in_area', 'light.diff_area'} == \
await service.async_extract_entity_ids(hass, call)
@asyncio.coroutine