core/homeassistant/helpers/service.py

514 lines
16 KiB
Python
Raw Normal View History

"""Service calling related helpers."""
import asyncio
from functools import partial, wraps
import logging
from typing import Callable
import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_ENTITY_ID,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
)
import homeassistant.core as ha
from homeassistant.exceptions import (
2019-07-31 19:25:30 +00:00
HomeAssistantError,
TemplateError,
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import template, typing
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.util.yaml import load_yaml
from homeassistant.util.yaml.loader import JSON_TYPE
# mypy: allow-untyped-defs, no-check-untyped-defs
2019-07-31 19:25:30 +00:00
CONF_SERVICE = "service"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SERVICE_ENTITY_ID = "entity_id"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_DATA_TEMPLATE = "data_template"
_LOGGER = logging.getLogger(__name__)
2019-07-31 19:25:30 +00:00
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@bind_hass
2019-07-31 19:25:30 +00:00
def call_from_config(
hass, config, blocking=False, variables=None, validate_config=True
):
"""Call a service based on a config hash."""
asyncio.run_coroutine_threadsafe(
2019-07-31 19:25:30 +00:00
async_call_from_config(hass, config, blocking, variables, validate_config),
hass.loop,
).result()
2016-10-01 05:34:45 +00:00
@bind_hass
2019-07-31 19:25:30 +00:00
async def async_call_from_config(
hass, config, blocking=False, variables=None, validate_config=True, context=None
):
2016-10-01 05:34:45 +00:00
"""Call a service based on a config hash."""
2016-04-23 05:11:21 +00:00
if validate_config:
try:
config = cv.SERVICE_SCHEMA(config)
except vol.Invalid as ex:
_LOGGER.error("Invalid config for calling service: %s", ex)
return
2016-01-10 00:01:27 +00:00
if CONF_SERVICE in config:
domain_service = config[CONF_SERVICE]
else:
try:
config[CONF_SERVICE_TEMPLATE].hass = hass
2019-07-31 19:25:30 +00:00
domain_service = config[CONF_SERVICE_TEMPLATE].async_render(variables)
domain_service = cv.service(domain_service)
except TemplateError as ex:
2018-12-13 11:21:16 +00:00
if blocking:
raise
2019-07-31 19:25:30 +00:00
_LOGGER.error("Error rendering service name template: %s", ex)
return
except vol.Invalid:
2018-12-13 11:21:16 +00:00
if blocking:
raise
2019-07-31 19:25:30 +00:00
_LOGGER.error("Template rendered invalid service: %s", domain_service)
return
2019-07-31 19:25:30 +00:00
domain, service_name = domain_service.split(".", 1)
service_data = dict(config.get(CONF_SERVICE_DATA, {}))
if CONF_SERVICE_DATA_TEMPLATE in config:
try:
template.attach(hass, config[CONF_SERVICE_DATA_TEMPLATE])
2019-07-31 19:25:30 +00:00
service_data.update(
template.render_complex(config[CONF_SERVICE_DATA_TEMPLATE], variables)
)
except TemplateError as ex:
2019-07-31 19:25:30 +00:00
_LOGGER.error("Error rendering data template: %s", ex)
return
if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
await hass.services.async_call(
2019-07-31 19:25:30 +00:00
domain, service_name, service_data, blocking=blocking, context=context
)
@bind_hass
def extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
2016-03-07 22:39:52 +00:00
Will convert group entity ids to the entity ids it represents.
"""
return asyncio.run_coroutine_threadsafe(
async_extract_entity_ids(hass, service_call, expand_group), hass.loop
).result()
@bind_hass
async def async_extract_entities(hass, entities, service_call, expand_group=True):
"""Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents.
"""
data_ent_id = service_call.data.get(ATTR_ENTITY_ID)
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available]
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group)
found = []
for entity in entities:
if entity.entity_id not in entity_ids:
continue
entity_ids.remove(entity.entity_id)
if not entity.available:
continue
found.append(entity)
if entity_ids:
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
return found
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
"""
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
extracted = set()
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
None,
ENTITY_MATCH_NONE,
):
return extracted
if entity_ids and entity_ids != ENTITY_MATCH_NONE:
# Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
if expand_group:
2019-07-31 19:25:30 +00:00
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
extracted.update(entity_ids)
if area_ids and area_ids != ENTITY_MATCH_NONE:
if isinstance(area_ids, str):
area_ids = [area_ids]
dev_reg, ent_reg = await asyncio.gather(
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
)
devices = [
device
for area_id in area_ids
2019-07-31 19:25:30 +00:00
for device in hass.helpers.device_registry.async_entries_for_area(
dev_reg, area_id
)
]
extracted.update(
entry.entity_id
for device in devices
2019-07-31 19:25:30 +00:00
for entry in hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id
)
)
return extracted
async def _load_services_file(hass: HomeAssistantType, domain: str) -> JSON_TYPE:
"""Load services file for an integration."""
integration = await async_get_integration(hass, domain)
try:
return await hass.async_add_executor_job(
2019-07-31 19:25:30 +00:00
load_yaml, str(integration.file_path / "services.yaml")
)
except FileNotFoundError:
2019-07-31 19:25:30 +00:00
_LOGGER.warning("Unable to find services.yaml for the %s integration", domain)
return {}
except HomeAssistantError:
2019-07-31 19:25:30 +00:00
_LOGGER.warning("Unable to parse services.yaml for the %s integration", domain)
return {}
@bind_hass
async def async_get_all_descriptions(hass):
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
2019-07-31 19:25:30 +00:00
format_cache_key = "{}.{}".format
services = hass.services.async_services()
# See if there are new services not seen before.
# Any service that we saw before already has an entry in description_cache.
missing = set()
for domain in services:
for service in services[domain]:
if format_cache_key(domain, service) not in descriptions_cache:
missing.add(domain)
break
# Files we loaded for missing descriptions
loaded = {}
if missing:
2019-07-31 19:25:30 +00:00
contents = await asyncio.gather(
*(_load_services_file(hass, domain) for domain in missing)
)
for domain, content in zip(missing, contents):
loaded[domain] = content
# Build response
descriptions = {}
for domain in services:
descriptions[domain] = {}
for service in services[domain]:
cache_key = format_cache_key(domain, service)
description = descriptions_cache.get(cache_key)
# Cache missing descriptions
if description is None:
domain_yaml = loaded[domain]
yaml_description = domain_yaml.get(service, {})
# Don't warn for missing services, because it triggers false
# positives for things like scripts, that register as a service
description = descriptions_cache[cache_key] = {
2019-07-31 19:25:30 +00:00
"description": yaml_description.get("description", ""),
"fields": yaml_description.get("fields", {}),
}
descriptions[domain][service] = description
return descriptions
@ha.callback
@bind_hass
def async_set_service_schema(hass, domain, service, schema):
"""Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
description = {
"description": schema.get("description") or "",
"fields": schema.get("fields") or {},
}
hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description
@bind_hass
async def entity_service_call(hass, platforms, func, call, required_features=None):
"""Handle an entity service call.
Calls all platforms simultaneously.
"""
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
entity_perms = user.permissions.check_entity
else:
entity_perms = None
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if not target_all_entities:
# A set of entities we're trying to target.
entity_ids = await async_extract_entity_ids(hass, call, True)
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
data = {
key: val
for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS
}
# If the service function is not a string, we pass the service call
else:
data = call
# Check the permissions
# A list with entities to call the service on.
entity_candidates = []
if entity_perms is None:
for platform in platforms:
if target_all_entities:
entity_candidates.extend(platform.entities.values())
else:
entity_candidates.extend(
2019-07-31 19:25:30 +00:00
[
entity
for entity in platform.entities.values()
if entity.entity_id in entity_ids
]
)
elif target_all_entities:
# If we target all entities, we will select all entities the user
# is allowed to control.
for platform in platforms:
entity_candidates.extend(
2019-07-31 19:25:30 +00:00
[
entity
for entity in platform.entities.values()
if entity_perms(entity.entity_id, POLICY_CONTROL)
]
)
else:
for platform in platforms:
platform_entities = []
for entity in platform.entities.values():
if entity.entity_id not in entity_ids:
continue
if not entity_perms(entity.entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity.entity_id,
2019-07-31 19:25:30 +00:00
permission=POLICY_CONTROL,
)
platform_entities.append(entity)
entity_candidates.extend(platform_entities)
if not target_all_entities:
for entity in entity_candidates:
entity_ids.remove(entity.entity_id)
if entity_ids:
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
entities = []
for entity in entity_candidates:
if not entity.available:
continue
# Skip entities that don't have the required feature.
2019-07-31 19:25:30 +00:00
if required_features is not None and not any(
entity.supported_features & feature_set for feature_set in required_features
):
continue
entities.append(entity)
if not entities:
return
done, pending = await asyncio.wait(
[
entity.async_request_call(
_handle_entity_call(hass, entity, func, data, call.context)
)
for entity in entities
]
)
assert not pending
for future in done:
future.result() # pop exception if have
tasks = []
for entity in entities:
if not entity.should_poll:
continue
# Context expires if the turn on commands took a long time.
# Set context again so it's there when we update
entity.async_set_context(call.context)
tasks.append(entity.async_update_ha_state(True))
if tasks:
done, pending = await asyncio.wait(tasks)
assert not pending
for future in done:
future.result() # pop exception if have
async def _handle_entity_call(hass, entity, func, data, context):
"""Handle calling service method."""
entity.async_set_context(context)
if isinstance(func, str):
result = hass.async_add_job(partial(getattr(entity, func), **data))
else:
result = hass.async_add_job(func, entity, data)
# Guard because callback functions do not return a task when passed to async_add_job.
if result is not None:
await result
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
)
await result
@bind_hass
@ha.callback
def async_register_admin_service(
2019-07-31 19:25:30 +00:00
hass: typing.HomeAssistantType,
domain: str,
service: str,
service_func: Callable,
schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA),
) -> None:
"""Register a service that requires admin access."""
2019-07-31 19:25:30 +00:00
@wraps(service_func)
async def admin_handler(call):
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
result = hass.async_add_job(service_func, call)
if result is not None:
await result
2019-07-31 19:25:30 +00:00
hass.services.async_register(domain, service, admin_handler, schema)
@bind_hass
@ha.callback
def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
"""Ensure permission to access any entity under domain in service call."""
2019-07-31 19:25:30 +00:00
def decorator(service_handler: Callable) -> Callable:
"""Decorate."""
if not asyncio.iscoroutinefunction(service_handler):
2019-07-31 19:25:30 +00:00
raise HomeAssistantError("Can only decorate async functions.")
async def check_permissions(call):
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL,
2019-07-31 19:25:30 +00:00
user_id=call.context.user_id,
)
reg = await hass.helpers.entity_registry.async_get_registry()
for entity in reg.entities.values():
if entity.platform != domain:
continue
if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL):
return await service_handler(call)
raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
user_id=call.context.user_id,
2019-07-31 19:25:30 +00:00
perm_category=CAT_ENTITIES,
)
return check_permissions
return decorator