"""Service calling related helpers.""" from __future__ import annotations import asyncio import dataclasses from functools import partial, wraps import logging from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypedDict import voluptuous as vol from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.const import ( ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_ENTITY_ID, CONF_SERVICE, CONF_SERVICE_DATA, CONF_SERVICE_TEMPLATE, CONF_TARGET, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE, ) from homeassistant.core import Context, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import ( HomeAssistantError, TemplateError, Unauthorized, UnknownUser, ) from homeassistant.helpers import ( area_registry, config_validation as cv, device_registry, entity_registry, template, ) from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.loader import ( MAX_LOAD_CONCURRENTLY, Integration, async_get_integration, bind_hass, ) from homeassistant.util.async_ import gather_with_concurrency from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml.loader import JSON_TYPE if TYPE_CHECKING: from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_platform import EntityPlatform CONF_SERVICE_ENTITY_ID = "entity_id" CONF_SERVICE_DATA_TEMPLATE = "data_template" _LOGGER = logging.getLogger(__name__) SERVICE_DESCRIPTION_CACHE = "service_description_cache" class ServiceParams(TypedDict): """Type for service call parameters.""" domain: str service: str service_data: dict[str, Any] target: dict | None class ServiceTargetSelector: """Class to hold a target selector for a service.""" def __init__(self, service_call: ServiceCall): """Extract ids from service call data.""" entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID) device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID) area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID) self.entity_ids = ( set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() ) self.device_ids = ( set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() ) self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() @property def has_any_selector(self) -> bool: """Determine if any selectors are present.""" return bool(self.entity_ids or self.device_ids or self.area_ids) @dataclasses.dataclass class SelectedEntities: """Class to hold the selected entities.""" # Entities that were explicitly mentioned. referenced: set[str] = dataclasses.field(default_factory=set) # Entities that were referenced via device/area ID. # Should not trigger a warning when they don't exist. indirectly_referenced: set[str] = dataclasses.field(default_factory=set) # Referenced items that could not be found. missing_devices: set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set) # Referenced devices referenced_devices: set[str] = dataclasses.field(default_factory=set) def log_missing(self, missing_entities: set[str]) -> None: """Log about missing items.""" parts = [] for label, items in ( ("areas", self.missing_areas), ("devices", self.missing_devices), ("entities", missing_entities), ): if items: parts.append(f"{label} {', '.join(sorted(items))}") if not parts: return _LOGGER.warning("Unable to find referenced %s", ", ".join(parts)) @bind_hass def call_from_config( hass: HomeAssistant, config: ConfigType, blocking: bool = False, variables: TemplateVarsType = None, validate_config: bool = True, ) -> None: """Call a service based on a config hash.""" asyncio.run_coroutine_threadsafe( async_call_from_config(hass, config, blocking, variables, validate_config), hass.loop, ).result() @bind_hass async def async_call_from_config( hass: HomeAssistant, config: ConfigType, blocking: bool = False, variables: TemplateVarsType = None, validate_config: bool = True, context: Context | None = None, ) -> None: """Call a service based on a config hash.""" try: params = async_prepare_call_from_config( hass, config, variables, validate_config ) except HomeAssistantError as ex: if blocking: raise _LOGGER.error(ex) else: await hass.services.async_call(**params, blocking=blocking, context=context) @callback @bind_hass def async_prepare_call_from_config( hass: HomeAssistant, config: ConfigType, variables: TemplateVarsType = None, validate_config: bool = False, ) -> ServiceParams: """Prepare to call a service based on a config hash.""" if validate_config: try: config = cv.SERVICE_SCHEMA(config) except vol.Invalid as ex: raise HomeAssistantError( f"Invalid config for calling service: {ex}" ) from ex if CONF_SERVICE in config: domain_service = config[CONF_SERVICE] else: domain_service = config[CONF_SERVICE_TEMPLATE] if isinstance(domain_service, template.Template): try: domain_service.hass = hass domain_service = domain_service.async_render(variables) domain_service = cv.service(domain_service) except TemplateError as ex: raise HomeAssistantError( f"Error rendering service name template: {ex}" ) from ex except vol.Invalid as ex: raise HomeAssistantError( f"Template rendered invalid service: {domain_service}" ) from ex domain, service = domain_service.split(".", 1) target = {} if CONF_TARGET in config: conf = config[CONF_TARGET] try: if isinstance(conf, template.Template): conf.hass = hass target.update(conf.async_render(variables)) else: template.attach(hass, conf) target.update(template.render_complex(conf, variables)) if CONF_ENTITY_ID in target: target[CONF_ENTITY_ID] = cv.comp_entity_ids(target[CONF_ENTITY_ID]) except TemplateError as ex: raise HomeAssistantError( f"Error rendering service target template: {ex}" ) from ex except vol.Invalid as ex: raise HomeAssistantError( f"Template rendered invalid entity IDs: {target[CONF_ENTITY_ID]}" ) from ex service_data = {} for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]: if conf not in config: continue try: template.attach(hass, config[conf]) service_data.update(template.render_complex(config[conf], variables)) except TemplateError as ex: raise HomeAssistantError(f"Error rendering data template: {ex}") from ex if CONF_SERVICE_ENTITY_ID in config: if target: target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] else: target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]} return { "domain": domain, "service": service, "service_data": service_data, "target": target, } @bind_hass def extract_entity_ids( hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> set[str]: """Extract a list of entity ids from a service call. Will convert group entity ids to the entity ids it represents. """ return asyncio.run_coroutine_threadsafe( async_extract_entity_ids(hass, service_call, expand_group), hass.loop ).result() @bind_hass async def async_extract_entities( hass: HomeAssistant, entities: Iterable[Entity], service_call: ServiceCall, expand_group: bool = True, ) -> list[Entity]: """Extract a list of entity objects from a service call. Will convert group entity ids to the entity ids it represents. """ data_ent_id = service_call.data.get(ATTR_ENTITY_ID) if data_ent_id == ENTITY_MATCH_ALL: return [entity for entity in entities if entity.available] referenced = await async_extract_referenced_entity_ids( hass, service_call, expand_group ) combined = referenced.referenced | referenced.indirectly_referenced found = [] for entity in entities: if entity.entity_id not in combined: continue combined.remove(entity.entity_id) if not entity.available: continue found.append(entity) referenced.log_missing(referenced.referenced & combined) return found @bind_hass async def async_extract_entity_ids( hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> set[str]: """Extract a set of entity ids from a service call. Will convert group entity ids to the entity ids it represents. """ referenced = await async_extract_referenced_entity_ids( hass, service_call, expand_group ) return referenced.referenced | referenced.indirectly_referenced def _has_match(ids: str | list | None) -> bool: """Check if ids can match anything.""" return ids not in (None, ENTITY_MATCH_NONE) @bind_hass async def async_extract_referenced_entity_ids( hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> SelectedEntities: """Extract referenced entity IDs from a service call.""" selector = ServiceTargetSelector(service_call) selected = SelectedEntities() if not selector.has_any_selector: return selected entity_ids = selector.entity_ids if expand_group: entity_ids = hass.components.group.expand_entity_ids(entity_ids) selected.referenced.update(entity_ids) if not selector.device_ids and not selector.area_ids: return selected ent_reg = entity_registry.async_get(hass) dev_reg = device_registry.async_get(hass) area_reg = area_registry.async_get(hass) for device_id in selector.device_ids: if device_id not in dev_reg.devices: selected.missing_devices.add(device_id) for area_id in selector.area_ids: if area_id not in area_reg.areas: selected.missing_areas.add(area_id) # Find devices for this area selected.referenced_devices.update(selector.device_ids) for device_entry in dev_reg.devices.values(): if device_entry.area_id in selector.area_ids: selected.referenced_devices.add(device_entry.id) if not selector.area_ids and not selected.referenced_devices: return selected for ent_entry in ent_reg.entities.values(): if ( # when area matches the target area ent_entry.area_id in selector.area_ids # when device matches a referenced devices with no explicitly set area or ( not ent_entry.area_id and ent_entry.device_id in selected.referenced_devices ) # when device matches target device or ent_entry.device_id in selector.device_ids ): selected.indirectly_referenced.add(ent_entry.entity_id) return selected @bind_hass async def async_extract_config_entry_ids( hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> set: """Extract referenced config entry ids from a service call.""" referenced = await async_extract_referenced_entity_ids( hass, service_call, expand_group ) ent_reg = entity_registry.async_get(hass) dev_reg = device_registry.async_get(hass) config_entry_ids: set[str] = set() # Some devices may have no entities for device_id in referenced.referenced_devices: if device_id in dev_reg.devices: device = dev_reg.async_get(device_id) if device is not None: config_entry_ids.update(device.config_entries) for entity_id in referenced.referenced | referenced.indirectly_referenced: entry = ent_reg.async_get(entity_id) if entry is not None and entry.config_entry_id is not None: config_entry_ids.add(entry.config_entry_id) return config_entry_ids def _load_services_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE: """Load services file for an integration.""" try: return load_yaml(str(integration.file_path / "services.yaml")) except FileNotFoundError: _LOGGER.warning( "Unable to find services.yaml for the %s integration", integration.domain ) return {} except HomeAssistantError: _LOGGER.warning( "Unable to parse services.yaml for the %s integration", integration.domain ) return {} def _load_services_files( hass: HomeAssistant, integrations: Iterable[Integration] ) -> list[JSON_TYPE]: """Load service files for multiple intergrations.""" return [_load_services_file(hass, integration) for integration in integrations] @bind_hass async def async_get_all_descriptions( hass: HomeAssistant, ) -> dict[str, dict[str, Any]]: """Return descriptions (i.e. user documentation) for all service calls.""" descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) format_cache_key = "{}.{}".format services = hass.services.async_services() # See if there are new services not seen before. # Any service that we saw before already has an entry in description_cache. missing = set() for domain in services: for service in services[domain]: if format_cache_key(domain, service) not in descriptions_cache: missing.add(domain) break # Files we loaded for missing descriptions loaded = {} if missing: integrations = await gather_with_concurrency( MAX_LOAD_CONCURRENTLY, *(async_get_integration(hass, domain) for domain in missing), ) contents = await hass.async_add_executor_job( _load_services_files, hass, integrations ) for domain, content in zip(missing, contents): loaded[domain] = content # Build response descriptions: dict[str, dict[str, Any]] = {} for domain in services: descriptions[domain] = {} for service in services[domain]: cache_key = format_cache_key(domain, service) description = descriptions_cache.get(cache_key) # Cache missing descriptions if description is None: domain_yaml = loaded[domain] yaml_description = domain_yaml.get(service, {}) # type: ignore # Don't warn for missing services, because it triggers false # positives for things like scripts, that register as a service description = { "name": yaml_description.get("name", ""), "description": yaml_description.get("description", ""), "fields": yaml_description.get("fields", {}), } if "target" in yaml_description: description["target"] = yaml_description["target"] descriptions_cache[cache_key] = description descriptions[domain][service] = description return descriptions @callback @bind_hass def async_set_service_schema( hass: HomeAssistant, domain: str, service: str, schema: dict[str, Any] ) -> None: """Register a description for a service.""" hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) description = { "name": schema.get("name", ""), "description": schema.get("description", ""), "fields": schema.get("fields", {}), } if "target" in schema: description["target"] = schema["target"] hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description @bind_hass async def entity_service_call( hass: HomeAssistant, platforms: Iterable[EntityPlatform], func: str | Callable[..., Any], call: ServiceCall, required_features: Iterable[int] | None = None, ) -> None: """Handle an entity service call. Calls all platforms simultaneously. """ if call.context.user_id: user = await hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser(context=call.context) entity_perms: None | ( Callable[[str, str], bool] ) = user.permissions.check_entity else: entity_perms = None target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL if target_all_entities: referenced: SelectedEntities | None = None all_referenced: set[str] | None = None else: # A set of entities we're trying to target. referenced = await async_extract_referenced_entity_ids(hass, call, True) all_referenced = referenced.referenced | referenced.indirectly_referenced # If the service function is a string, we'll pass it the service call data if isinstance(func, str): data: dict | ServiceCall = { key: val for key, val in call.data.items() if key not in cv.ENTITY_SERVICE_FIELDS } # If the service function is not a string, we pass the service call else: data = call # Check the permissions # A list with entities to call the service on. entity_candidates: list[Entity] = [] if entity_perms is None: for platform in platforms: if target_all_entities: entity_candidates.extend(platform.entities.values()) else: assert all_referenced is not None entity_candidates.extend( [ entity for entity in platform.entities.values() if entity.entity_id in all_referenced ] ) elif target_all_entities: # If we target all entities, we will select all entities the user # is allowed to control. for platform in platforms: entity_candidates.extend( [ entity for entity in platform.entities.values() if entity_perms(entity.entity_id, POLICY_CONTROL) ] ) else: assert all_referenced is not None for platform in platforms: platform_entities = [] for entity in platform.entities.values(): if entity.entity_id not in all_referenced: continue if not entity_perms(entity.entity_id, POLICY_CONTROL): raise Unauthorized( context=call.context, entity_id=entity.entity_id, permission=POLICY_CONTROL, ) platform_entities.append(entity) entity_candidates.extend(platform_entities) if not target_all_entities: assert referenced is not None # Only report on explicit referenced entities missing = set(referenced.referenced) for entity in entity_candidates: missing.discard(entity.entity_id) referenced.log_missing(missing) entities = [] for entity in entity_candidates: if not entity.available: continue # Skip entities that don't have the required feature. if required_features is not None and ( entity.supported_features is None or not any( entity.supported_features & feature_set == feature_set for feature_set in required_features ) ): continue entities.append(entity) if not entities: return done, pending = await asyncio.wait( [ asyncio.create_task( entity.async_request_call( _handle_entity_call(hass, entity, func, data, call.context) ) ) for entity in entities ] ) assert not pending for future in done: future.result() # pop exception if have tasks = [] for entity in entities: if not entity.should_poll: continue # Context expires if the turn on commands took a long time. # Set context again so it's there when we update entity.async_set_context(call.context) tasks.append(asyncio.create_task(entity.async_update_ha_state(True))) if tasks: done, pending = await asyncio.wait(tasks) assert not pending for future in done: future.result() # pop exception if have async def _handle_entity_call( hass: HomeAssistant, entity: Entity, func: str | Callable[..., Any], data: dict | ServiceCall, context: Context, ) -> None: """Handle calling service method.""" entity.async_set_context(context) if isinstance(func, str): result = hass.async_run_job(partial(getattr(entity, func), **data)) # type: ignore else: result = hass.async_run_job(func, entity, data) # Guard because callback functions do not return a task when passed to async_run_job. if result is not None: await result if asyncio.iscoroutine(result): _LOGGER.error( "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author", func, entity.entity_id, ) await result # type: ignore @bind_hass @callback def async_register_admin_service( hass: HomeAssistant, domain: str, service: str, service_func: Callable[[ServiceCall], Awaitable | None], schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA), ) -> None: """Register a service that requires admin access.""" @wraps(service_func) async def admin_handler(call: ServiceCall) -> None: if call.context.user_id: user = await hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser(context=call.context) if not user.is_admin: raise Unauthorized(context=call.context) result = hass.async_run_job(service_func, call) if result is not None: await result hass.services.async_register(domain, service, admin_handler, schema) @bind_hass @callback def verify_domain_control( hass: HomeAssistant, domain: str ) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: """Ensure permission to access any entity under domain in service call.""" def decorator( service_handler: Callable[[ServiceCall], Any] ) -> Callable[[ServiceCall], Any]: """Decorate.""" if not asyncio.iscoroutinefunction(service_handler): raise HomeAssistantError("Can only decorate async functions.") async def check_permissions(call: ServiceCall) -> Any: """Check user permission and raise before call if unauthorized.""" if not call.context.user_id: return await service_handler(call) user = await hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser( context=call.context, permission=POLICY_CONTROL, user_id=call.context.user_id, ) reg = await hass.helpers.entity_registry.async_get_registry() authorized = False for entity in reg.entities.values(): if entity.platform != domain: continue if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL): authorized = True break if not authorized: raise Unauthorized( context=call.context, permission=POLICY_CONTROL, user_id=call.context.user_id, perm_category=CAT_ENTITIES, ) return await service_handler(call) return check_permissions return decorator