# Listing metadata: 453 lines, 14 KiB, Python
"""Service calling related helpers."""
|
|
import asyncio
|
|
from functools import wraps
|
|
import logging
|
|
from typing import Callable
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
|
|
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID
|
|
import homeassistant.core as ha
|
|
from homeassistant.exceptions import (
|
|
HomeAssistantError,
|
|
TemplateError,
|
|
Unauthorized,
|
|
UnknownUser,
|
|
)
|
|
from homeassistant.helpers import template, typing
|
|
from homeassistant.loader import async_get_integration, bind_hass
|
|
from homeassistant.util.yaml import load_yaml
|
|
from homeassistant.util.yaml.loader import JSON_TYPE
|
|
import homeassistant.helpers.config_validation as cv
|
|
from homeassistant.helpers.typing import HomeAssistantType
|
|
|
|
|
|
# mypy: allow-untyped-defs, no-check-untyped-defs
|
|
|
|
# Keys recognised in a service-call configuration mapping.
CONF_SERVICE = "service"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SERVICE_ENTITY_ID = "entity_id"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_DATA_TEMPLATE = "data_template"

# Module-level logger used by every helper in this file.
_LOGGER = logging.getLogger(__name__)

# Key in ``hass.data`` under which service descriptions are cached.
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
|
|
|
|
|
@bind_hass
def call_from_config(
    hass, config, blocking=False, variables=None, validate_config=True
):
    """Call a service based on a config hash (blocking, thread-side wrapper).

    Submits async_call_from_config to ``hass.loop`` and waits for it to
    finish. Any exception raised by the coroutine propagates to the caller.
    Nothing is returned.
    """
    coro = async_call_from_config(
        hass, config, blocking, variables, validate_config
    )
    future = asyncio.run_coroutine_threadsafe(coro, hass.loop)
    future.result()
|
|
|
|
|
|
@bind_hass
async def async_call_from_config(
    hass, config, blocking=False, variables=None, validate_config=True, context=None
):
    """Call a service based on a config hash.

    The config is optionally validated against ``cv.SERVICE_SCHEMA``. The
    service name comes either from the literal ``service`` key or from
    rendering the ``service_template`` key with ``variables``. Service data
    is the static ``data`` mapping, updated with the rendered
    ``data_template`` and finally with ``entity_id`` if present in the config.

    Failures are logged and the call is abandoned (the function returns
    None), except that errors while resolving the service template are
    re-raised when ``blocking`` is true.
    """
    if validate_config:
        try:
            config = cv.SERVICE_SCHEMA(config)
        except vol.Invalid as err:
            _LOGGER.error("Invalid config for calling service: %s", err)
            return

    if CONF_SERVICE not in config:
        # No literal service name: render and validate the template instead.
        try:
            service_template = config[CONF_SERVICE_TEMPLATE]
            service_template.hass = hass
            # Keep the rendered value bound before validating it, because the
            # vol.Invalid handler below reports it.
            domain_service = service_template.async_render(variables)
            domain_service = cv.service(domain_service)
        except TemplateError as err:
            if blocking:
                raise
            _LOGGER.error("Error rendering service name template: %s", err)
            return
        except vol.Invalid:
            if blocking:
                raise
            _LOGGER.error("Template rendered invalid service: %s", domain_service)
            return
    else:
        domain_service = config[CONF_SERVICE]

    domain, service_name = domain_service.split(".", 1)

    # Copy so the caller's config mapping is never mutated.
    service_data = dict(config.get(CONF_SERVICE_DATA, {}))

    if CONF_SERVICE_DATA_TEMPLATE in config:
        data_template = config[CONF_SERVICE_DATA_TEMPLATE]
        try:
            template.attach(hass, data_template)
            service_data.update(template.render_complex(data_template, variables))
        except TemplateError as err:
            _LOGGER.error("Error rendering data template: %s", err)
            return

    if CONF_SERVICE_ENTITY_ID in config:
        service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]

    await hass.services.async_call(
        domain, service_name, service_data, blocking=blocking, context=context
    )
|
|
|
|
|
|
@bind_hass
def extract_entity_ids(hass, service_call, expand_group=True):
    """Extract the entity ids targeted by a service call (blocking wrapper).

    Runs async_extract_entity_ids on ``hass.loop`` from the calling thread
    and returns its result once the coroutine has completed.
    """
    coro = async_extract_entity_ids(hass, service_call, expand_group)
    future = asyncio.run_coroutine_threadsafe(coro, hass.loop)
    return future.result()
|
|
|
|
|
|
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
    """Extract the entity ids targeted by a service call.

    Reads the entity-id and area-id fields of ``service_call.data``. Each may
    be a single string or a collection. Group entity ids are expanded to
    their members when ``expand_group`` is true, and area ids are resolved to
    the entities of every device registered in those areas.

    Returns an empty list when neither field holds a truthy value, otherwise
    a set of entity ids.

    Async friendly.
    """
    call_data = service_call.data
    entity_ids = call_data.get(ATTR_ENTITY_ID)
    area_ids = call_data.get(ATTR_AREA_ID)

    if not (entity_ids or area_ids):
        return []

    found = set()

    if entity_ids:
        # Entity ID attr can be a list or a string
        if isinstance(entity_ids, str):
            entity_ids = [entity_ids]

        if expand_group:
            entity_ids = hass.components.group.expand_entity_ids(entity_ids)

        found.update(entity_ids)

    if area_ids:
        if isinstance(area_ids, str):
            area_ids = [area_ids]

        device_registry = hass.helpers.device_registry
        entity_registry = hass.helpers.entity_registry

        # Both registries are fetched concurrently.
        dev_reg, ent_reg = await asyncio.gather(
            device_registry.async_get_registry(),
            entity_registry.async_get_registry(),
        )

        # First collect every device in the requested areas ...
        area_devices = []
        for area_id in area_ids:
            area_devices.extend(
                device_registry.async_entries_for_area(dev_reg, area_id)
            )

        # ... then add the entities that belong to those devices.
        for device in area_devices:
            for entry in entity_registry.async_entries_for_device(
                ent_reg, device.id
            ):
                found.add(entry.entity_id)

    return found
|
|
|
|
|
|
async def _load_services_file(hass: HomeAssistantType, domain: str) -> JSON_TYPE:
    """Load the services.yaml file of an integration.

    The file is read through ``hass.async_add_executor_job``. If the file is
    missing or cannot be parsed, a warning is logged and an empty dict is
    returned.
    """
    integration = await async_get_integration(hass, domain)
    services_path = str(integration.file_path / "services.yaml")

    try:
        content = await hass.async_add_executor_job(load_yaml, services_path)
    except FileNotFoundError:
        _LOGGER.warning("Unable to find services.yaml for the %s integration", domain)
        content = {}
    except HomeAssistantError:
        _LOGGER.warning("Unable to parse services.yaml for the %s integration", domain)
        content = {}

    return content
|
|
|
|
|
|
@bind_hass
async def async_get_all_descriptions(hass):
    """Return descriptions (i.e. user documentation) for all service calls.

    Descriptions are cached in ``hass.data[SERVICE_DESCRIPTION_CACHE]`` under
    ``"<domain>.<service>"`` keys. For every domain that has at least one
    service without a cached description, the integration's services.yaml is
    loaded (all such files concurrently) and the missing entries are added to
    the cache.

    Returns a dict mapping domain -> service -> ``{"description", "fields"}``.
    """
    descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
    format_cache_key = "{}.{}".format
    services = hass.services.async_services()

    # See if there are new services not seen before.
    # Any service that we saw before already has an entry in description_cache.
    missing = [
        domain
        for domain, domain_services in services.items()
        if any(
            format_cache_key(domain, service) not in descriptions_cache
            for service in domain_services
        )
    ]

    # Files we loaded for missing descriptions
    loaded = {}

    if missing:
        contents = await asyncio.gather(
            *(_load_services_file(hass, domain) for domain in missing)
        )
        loaded = dict(zip(missing, contents))

    # Build response
    descriptions = {}
    for domain, domain_services in services.items():
        descriptions[domain] = {}

        for service in domain_services:
            cache_key = format_cache_key(domain, service)
            description = descriptions_cache.get(cache_key)

            # Cache missing descriptions
            if description is None:
                # Don't warn for missing services, because it triggers false
                # positives for things like scripts, that register as a service
                #
                # ``or`` (rather than a .get default) also covers keys that
                # are present with a null value, e.g. a bare ``service:``
                # entry in services.yaml. This matches the defaults applied
                # by async_set_service_schema.
                yaml_description = loaded[domain].get(service) or {}

                description = descriptions_cache[cache_key] = {
                    "description": yaml_description.get("description") or "",
                    "fields": yaml_description.get("fields") or {},
                }

            descriptions[domain][service] = description

    return descriptions
|
|
|
|
|
|
@ha.callback
@bind_hass
def async_set_service_schema(hass, domain, service, schema):
    """Register a description for a service.

    Stores the ``description`` and ``fields`` entries of ``schema`` in the
    service description cache under ``"<domain>.<service>"``. Missing or
    falsy values are replaced by an empty string and an empty dict.
    """
    cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})

    summary = schema.get("description") or ""
    fields = schema.get("fields") or {}

    cache_key = "{}.{}".format(domain, service)
    cache[cache_key] = {"description": summary, "fields": fields}
|
|
|
|
|
|
@bind_hass
async def entity_service_call(
    hass, platforms, func, call, service_name="", required_features=None
):
    """Handle an entity service call.

    Calls all platforms simultaneously.

    ``func`` is either the name of an entity method, which is called with the
    service data as keyword arguments, or a callable that receives the entity
    and the service call. ``required_features`` is forwarded to the
    per-platform handler, which skips entities lacking those features.

    When the call carries a user id, only entities that user may control are
    targeted. If specific entities were requested and the user lacks control
    permission for one of them, Unauthorized is raised. UnknownUser is raised
    if the user id cannot be resolved.
    """
    if call.context.user_id:
        user = await hass.auth.async_get_user(call.context.user_id)
        if user is None:
            raise UnknownUser(context=call.context)
        entity_perms = user.permissions.check_entity
    else:
        entity_perms = None

    # Are we trying to target all entities
    if ATTR_ENTITY_ID in call.data:
        target_all_entities = call.data[ATTR_ENTITY_ID] == ENTITY_MATCH_ALL
    elif ATTR_AREA_ID in call.data:
        # An area-only call names its targets through the area; it must not
        # fall through to the deprecated "no entity id means everything" path.
        target_all_entities = False
    else:
        # Remove the service_name parameter along with this warning
        _LOGGER.warning(
            "Not passing an entity ID to a service to target all "
            "entities is deprecated. Update your call to %s to be "
            "instead: entity_id: %s",
            service_name,
            ENTITY_MATCH_ALL,
        )
        target_all_entities = True

    if not target_all_entities:
        # A set of entities we're trying to target.
        entity_ids = await async_extract_entity_ids(hass, call, True)

    # If the service function is a string, we'll pass it the service call data
    if isinstance(func, str):
        # Targeting fields select entities; they are not arguments of the
        # entity method, so neither of them is forwarded as a keyword.
        data = {
            key: val
            for key, val in call.data.items()
            if key not in (ATTR_ENTITY_ID, ATTR_AREA_ID)
        }
    # If the service function is not a string, we pass the service call
    else:
        data = call

    # Check the permissions

    # A list with for each platform in platforms a list of entities to call
    # the service on.
    platforms_entities = []

    if entity_perms is None:
        for platform in platforms:
            if target_all_entities:
                platforms_entities.append(list(platform.entities.values()))
            else:
                platforms_entities.append(
                    [
                        entity
                        for entity in platform.entities.values()
                        if entity.entity_id in entity_ids
                    ]
                )

    elif target_all_entities:
        # If we target all entities, we will select all entities the user
        # is allowed to control.
        for platform in platforms:
            platforms_entities.append(
                [
                    entity
                    for entity in platform.entities.values()
                    if entity_perms(entity.entity_id, POLICY_CONTROL)
                ]
            )

    else:
        for platform in platforms:
            platform_entities = []
            for entity in platform.entities.values():
                if entity.entity_id not in entity_ids:
                    continue

                if not entity_perms(entity.entity_id, POLICY_CONTROL):
                    raise Unauthorized(
                        context=call.context,
                        entity_id=entity.entity_id,
                        permission=POLICY_CONTROL,
                    )

                platform_entities.append(entity)

            platforms_entities.append(platform_entities)

    # asyncio.wait needs tasks/futures: passing bare coroutines is deprecated
    # since Python 3.8 and rejected from 3.11 on.
    tasks = [
        asyncio.ensure_future(
            _handle_service_platform_call(
                func, data, entities, call.context, required_features
            )
        )
        for entities in platforms_entities
    ]

    if tasks:
        done, pending = await asyncio.wait(tasks)
        assert not pending
        for future in done:
            future.result()  # pop exception if have
|
|
|
|
|
|
async def _handle_service_platform_call(
    func, data, entities, context, required_features
):
    """Handle a function call for the entities of one platform.

    Entities that are unavailable, or that support none of the feature sets
    in ``required_features`` (when given), are skipped. For every remaining
    entity the context is set and ``func`` is awaited: as an entity method
    called with ``**data`` when ``func`` is a string, otherwise as
    ``func(entity, data)``. Entities with ``should_poll`` set then get a
    forced state update; all updates run concurrently once every call is done.

    Any exception raised by a state update is re-raised.
    """
    update_coros = []

    for entity in entities:
        if not entity.available:
            continue

        # Skip entities that don't have the required feature.
        if required_features is not None and not any(
            entity.supported_features & feature_set for feature_set in required_features
        ):
            continue

        entity.async_set_context(context)

        if isinstance(func, str):
            await getattr(entity, func)(**data)
        else:
            await func(entity, data)

        if entity.should_poll:
            update_coros.append(entity.async_update_ha_state(True))

    if update_coros:
        # asyncio.wait needs tasks/futures: passing bare coroutines is
        # deprecated since Python 3.8 and rejected from 3.11 on. Wrapping
        # happens here, not at append time, so the updates still start only
        # after every entity call above has finished.
        done, pending = await asyncio.wait(
            [asyncio.ensure_future(coro) for coro in update_coros]
        )
        assert not pending
        for future in done:
            future.result()  # pop exception if have
|
|
|
|
|
|
@bind_hass
@ha.callback
def async_register_admin_service(
    hass: typing.HomeAssistantType,
    domain: str,
    service: str,
    service_func: Callable,
    schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA),
) -> None:
    """Register a service that requires admin access.

    ``service_func`` is wrapped in a handler that, when the call carries a
    user id, raises UnknownUser if that user cannot be found and Unauthorized
    if the user is not an admin. Calls without a user id are passed through
    unchecked. The wrapped handler is registered under ``domain``/``service``
    with ``schema``.
    """

    @wraps(service_func)
    async def admin_handler(call):
        if call.context.user_id:
            user = await hass.auth.async_get_user(call.context.user_id)
            if user is None:
                raise UnknownUser(context=call.context)
            if not user.is_admin:
                raise Unauthorized(context=call.context)

        # NOTE(review): hass.async_add_job is understood to return None when
        # it schedules the function without creating a future (e.g. for
        # callback-decorated functions) - confirm against homeassistant.core.
        # Awaiting None would raise TypeError, so only await a real result.
        result = hass.async_add_job(service_func, call)
        if result is not None:
            await result

    hass.services.async_register(domain, service, admin_handler, schema)
|
|
|
|
|
|
@bind_hass
@ha.callback
def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
    """Ensure permission to access any entity under domain in service call.

    Returns a decorator for async service handlers. The decorated handler
    runs when the call has no user id, or when the calling user holds control
    permission for at least one entity-registry entry whose platform equals
    ``domain``. Otherwise Unauthorized is raised; UnknownUser is raised when
    the user id cannot be resolved.
    """

    def decorator(service_handler: Callable) -> Callable:
        """Wrap an async service handler with the permission check."""
        if not asyncio.iscoroutinefunction(service_handler):
            raise HomeAssistantError("Can only decorate async functions.")

        async def check_permissions(call):
            """Check user permission and raise before call if unauthorized."""
            user_id = call.context.user_id
            if not user_id:
                return await service_handler(call)

            user = await hass.auth.async_get_user(user_id)
            if user is None:
                raise UnknownUser(
                    context=call.context,
                    permission=POLICY_CONTROL,
                    user_id=user_id,
                )

            registry = await hass.helpers.entity_registry.async_get_registry()
            domain_entity_ids = [
                entry.entity_id
                for entry in registry.entities.values()
                if entry.platform == domain
            ]

            # Control over a single entity of the domain is sufficient.
            may_control = any(
                user.permissions.check_entity(entity_id, POLICY_CONTROL)
                for entity_id in domain_entity_ids
            )
            if not may_control:
                raise Unauthorized(
                    context=call.context,
                    permission=POLICY_CONTROL,
                    user_id=user_id,
                    perm_category=CAT_ENTITIES,
                )

            return await service_handler(call)

        return check_permissions

    return decorator
|