"""Helpers to execute scripts."""
from __future__ import annotations

import asyncio
from collections.abc import Callable, Sequence
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timedelta
from functools import partial
import itertools
import logging
from types import MappingProxyType
from typing import Any, Dict, TypedDict, Union, cast

import async_timeout
import voluptuous as vol

from homeassistant import exceptions
from homeassistant.components import device_automation, scene
from homeassistant.components.logger import LOGSEVERITY
from homeassistant.const import (
    ATTR_AREA_ID,
    ATTR_DEVICE_ID,
    ATTR_ENTITY_ID,
    CONF_ALIAS,
    CONF_CHOOSE,
    CONF_CONDITION,
    CONF_CONDITIONS,
    CONF_CONTINUE_ON_TIMEOUT,
    CONF_COUNT,
    CONF_DEFAULT,
    CONF_DELAY,
    CONF_DEVICE_ID,
    CONF_DOMAIN,
    CONF_EVENT,
    CONF_EVENT_DATA,
    CONF_EVENT_DATA_TEMPLATE,
    CONF_MODE,
    CONF_REPEAT,
    CONF_SCENE,
    CONF_SEQUENCE,
    CONF_SERVICE,
    CONF_TARGET,
    CONF_TIMEOUT,
    CONF_UNTIL,
    CONF_VARIABLES,
    CONF_WAIT_FOR_TRIGGER,
    CONF_WAIT_TEMPLATE,
    CONF_WHILE,
    EVENT_HOMEASSISTANT_STOP,
    SERVICE_TURN_ON,
)
from homeassistant.core import (
    SERVICE_CALL_LIMIT,
    Context,
    HassJob,
    HomeAssistant,
    callback,
)
from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import (
    ConditionCheckerType,
    trace_condition_function,
)
from homeassistant.helpers.dispatcher import (
    async_dispatcher_connect,
    async_dispatcher_send,
)
from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.trace import script_execution_set
from homeassistant.helpers.trigger import (
    async_initialize_triggers,
    async_validate_trigger_config,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow

from .trace import (
    TraceElement,
    async_trace_path,
    trace_append_element,
    trace_id_get,
    trace_path,
    trace_path_get,
    trace_set_result,
    trace_stack_cv,
    trace_stack_pop,
    trace_stack_push,
    trace_stack_top,
    trace_update_result,
)

# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs

SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued"
SCRIPT_MODE_RESTART = "restart"
SCRIPT_MODE_SINGLE = "single"
SCRIPT_MODE_CHOICES = [
    SCRIPT_MODE_PARALLEL,
    SCRIPT_MODE_QUEUED,
    SCRIPT_MODE_RESTART,
    SCRIPT_MODE_SINGLE,
]
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_SINGLE

CONF_MAX = "max"
DEFAULT_MAX = 10

CONF_MAX_EXCEEDED = "max_exceeded"
_MAX_EXCEEDED_CHOICES = list(LOGSEVERITY) + ["SILENT"]
DEFAULT_MAX_EXCEEDED = "WARNING"

ATTR_CUR = "current"
ATTR_MAX = "max"

DATA_SCRIPTS = "helpers.script"
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints"
RUN_ID_ANY = "*"
NODE_ANY = "*"

_LOGGER = logging.getLogger(__name__)

_LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."

_SHUTDOWN_MAX_WAIT = 60


ACTION_TRACE_NODE_MAX_LEN = 20  # Max length of a trace node for repeated actions

SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"


def action_trace_append(variables, path):
    """Append a TraceElement to trace[path]."""
    trace_element = TraceElement(variables, path)
    trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
    return trace_element


@asynccontextmanager
async def trace_action(hass, script_run, stop, variables):
    """Trace action execution.

    Pushes a new trace element for the current trace path, pauses at a matching
    breakpoint (if any) until a continue/stop signal arrives or `stop` is set,
    then yields the trace element. Exceptions raised inside the context are
    recorded on the element and re-raised.
    """
    path = trace_path_get()
    trace_element = action_trace_append(variables, path)
    trace_stack_push(trace_stack_cv, trace_element)

    trace_id = trace_id_get()
    if trace_id:
        key = trace_id[0]
        run_id = trace_id[1]
        breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
        if key in breakpoints and (
            (
                run_id in breakpoints[key]
                and (
                    path in breakpoints[key][run_id]
                    or NODE_ANY in breakpoints[key][run_id]
                )
            )
            or (
                RUN_ID_ANY in breakpoints[key]
                and (
                    path in breakpoints[key][RUN_ID_ANY]
                    or NODE_ANY in breakpoints[key][RUN_ID_ANY]
                )
            )
        ):
            async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)

            done = asyncio.Event()

            @callback
            def async_continue_stop(command=None):
                if command == "stop":
                    stop.set()
                done.set()

            signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
            remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
            remove_signal2 = async_dispatcher_connect(
                hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
            )

            tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
            try:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Always clean up, also when we are cancelled while paused at the
                # breakpoint; otherwise the wait tasks and dispatcher connections
                # would outlive this action.
                for task in tasks:
                    task.cancel()
                remove_signal1()
                remove_signal2()

    try:
        yield trace_element
    except _StopScript as ex:
        # Record the underlying cause when the stop was raised "from" another error
        trace_element.set_error(ex.__cause__ or ex)
        raise
    except Exception as ex:
        trace_element.set_error(ex)
        raise
    finally:
        trace_stack_pop(trace_stack_cv)


def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
    """Make a schema for a component that uses the script helper."""
    return vol.Schema(
        {
            **schema,
            vol.Optional(CONF_MODE, default=default_script_mode): vol.In(
                SCRIPT_MODE_CHOICES
            ),
            vol.Optional(CONF_MAX, default=DEFAULT_MAX): vol.All(
                vol.Coerce(int), vol.Range(min=2)
            ),
            vol.Optional(CONF_MAX_EXCEEDED, default=DEFAULT_MAX_EXCEEDED): vol.All(
                vol.Upper, vol.In(_MAX_EXCEEDED_CHOICES)
            ),
        },
        extra=extra,
    )


# Action types that need no further validation in async_validate_action_config
STATIC_VALIDATION_ACTION_TYPES = (
    cv.SCRIPT_ACTION_CALL_SERVICE,
    cv.SCRIPT_ACTION_DELAY,
    cv.SCRIPT_ACTION_WAIT_TEMPLATE,
    cv.SCRIPT_ACTION_FIRE_EVENT,
    cv.SCRIPT_ACTION_ACTIVATE_SCENE,
    cv.SCRIPT_ACTION_VARIABLES,
)


async def async_validate_actions_config(
    hass: HomeAssistant, actions: list[ConfigType]
) -> list[ConfigType]:
    """Validate a list of actions."""
    return await asyncio.gather(
        *(async_validate_action_config(hass, action) for action in actions)
    )


async def async_validate_action_config(
    hass: HomeAssistant, config: ConfigType
) -> ConfigType:
    """Validate config.

    Returns the (possibly replaced) action config. Nested sequences of repeat
    and choose actions are validated recursively and stored back in place.
    Raises ValueError for an action type without validation support.
    """
    action_type = cv.determine_script_action(config)

    if action_type in STATIC_VALIDATION_ACTION_TYPES:
        pass

    elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
        platform = await device_automation.async_get_device_automation_platform(
            hass, config[CONF_DOMAIN], "action"
        )
        if hasattr(platform, "async_validate_action_config"):
            config = await platform.async_validate_action_config(hass, config)  # type: ignore
        else:
            config = platform.ACTION_SCHEMA(config)  # type: ignore

    elif action_type == cv.SCRIPT_ACTION_CHECK_CONDITION:
        if config[CONF_CONDITION] == "device":
            platform = await device_automation.async_get_device_automation_platform(
                hass, config[CONF_DOMAIN], "condition"
            )
            config = platform.CONDITION_SCHEMA(config)  # type: ignore

    elif action_type == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER:
        config[CONF_WAIT_FOR_TRIGGER] = await async_validate_trigger_config(
            hass, config[CONF_WAIT_FOR_TRIGGER]
        )

    elif action_type == cv.SCRIPT_ACTION_REPEAT:
        # Store the validated sequence back under the repeat config, which is
        # where the repeat sub-script reads it from, instead of adding a stray
        # top-level key that nothing uses.
        config[CONF_REPEAT][CONF_SEQUENCE] = await async_validate_actions_config(
            hass, config[CONF_REPEAT][CONF_SEQUENCE]
        )

    elif action_type == cv.SCRIPT_ACTION_CHOOSE:
        if CONF_DEFAULT in config:
            config[CONF_DEFAULT] = await async_validate_actions_config(
                hass, config[CONF_DEFAULT]
            )

        for choose_conf in config[CONF_CHOOSE]:
            choose_conf[CONF_SEQUENCE] = await async_validate_actions_config(
                hass, choose_conf[CONF_SEQUENCE]
            )

    else:
        raise ValueError(f"No validation for {action_type}")

    return config


class _StopScript(Exception):
    """Throw if script needs to stop."""


class _ScriptRun:
    """Manage Script sequence run."""

    def __init__(
        self,
        hass: HomeAssistant,
        script: Script,
        variables: dict[str, Any],
        context: Context | None,
        log_exceptions: bool,
    ) -> None:
        self._hass = hass
        self._script = script
        self._variables = variables
        self._context = context
        self._log_exceptions = log_exceptions
        self._step = -1
        self._action: dict[str, Any] | None = None
        # Set to request that the run stops
        self._stop = asyncio.Event()
        # Set by _finish() once the run has ended
        self._stopped = asyncio.Event()

    def _changed(self) -> None:
        """Notify the owning script of a change unless a stop was requested."""
        if not self._stop.is_set():
            self._script._changed()  # pylint: disable=protected-access

    async def _async_get_condition(self, config):
        """Return the condition checker for config via the owning script."""
        # pylint: disable=protected-access
        return await self._script._async_get_condition(config)

    def _log(
        self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any
    ) -> None:
        """Forward a log message to the owning script."""
        self._script._log(  # pylint: disable=protected-access
            msg, *args, level=level, **kwargs
        )

    def _step_log(self, default_message, timeout=None):
        """Record the step as the script's last action and log it.

        The action's alias is used when present, otherwise default_message.
        """
        self._script.last_action = self._action.get(CONF_ALIAS, default_message)
        _timeout = (
            "" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})"
        )
        self._log("Executing step %s%s", self._script.last_action, _timeout)

    async def async_run(self) -> None:
        """Run script."""
        try:
            self._log("Running %s", self._script.running_description)
            for self._step, self._action in enumerate(self._script.sequence):
                if self._stop.is_set():
                    script_execution_set("cancelled")
                    break
                await self._async_step(log_exceptions=False)
            else:
                # Loop ended without a stop request
                script_execution_set("finished")
        except _StopScript:
            script_execution_set("aborted")
        except Exception:
            script_execution_set("error")
            raise
        finally:
            self._finish()

    async def _async_step(self, log_exceptions):
        """Run the current action inside its trace path and trace context.

        The handler method is looked up by action type as
        `_async_<action type>_step`. Exceptions other than _StopScript are
        logged when either this run or the caller asked for it, then re-raised.
        """
        with trace_path(str(self._step)):
            async with trace_action(self._hass, self, self._stop, self._variables):
                if self._stop.is_set():
                    return

                try:
                    handler = f"_async_{cv.determine_script_action(self._action)}_step"
                    await getattr(self, handler)()
                except Exception as ex:
                    if not isinstance(ex, _StopScript) and (
                        self._log_exceptions or log_exceptions
                    ):
                        self._log_exception(ex)
                    raise

    def _finish(self) -> None:
        """Remove this run from the script and signal that it has stopped."""
        self._script._runs.remove(self)  # pylint: disable=protected-access
        if not self._script.is_running:
            self._script.last_action = None
        self._changed()
        self._stopped.set()

    async def async_stop(self) -> None:
        """Stop script run."""
        self._stop.set()
        await self._stopped.wait()

    def _log_exception(self, exception):
        """Log an exception raised by the current action, classified by type."""
        action_type = cv.determine_script_action(self._action)

        error = str(exception)
        level = logging.ERROR

        if isinstance(exception, vol.Invalid):
            error_desc = "Invalid data"

        elif isinstance(exception, exceptions.TemplateError):
            error_desc = "Error rendering template"

        elif isinstance(exception, exceptions.Unauthorized):
            error_desc = "Unauthorized"

        elif isinstance(exception, exceptions.ServiceNotFound):
            error_desc = "Service not found"

        elif isinstance(exception, exceptions.HomeAssistantError):
            error_desc = "Error"

        else:
            error_desc = "Unexpected error"
            level = _LOG_EXCEPTION

        self._log(
            "Error executing script. %s for %s at pos %s: %s",
            error_desc,
            action_type,
            self._step + 1,
            error,
            level=level,
        )

    def _get_pos_time_period_template(self, key):
        """Render self._action[key] and validate it as a positive time period.

        Logs the problem and raises _StopScript when rendering or validation
        fails.
        """
        try:
            return cv.positive_time_period(
                template.render_complex(self._action[key], self._variables)
            )
        except (exceptions.TemplateError, vol.Invalid) as ex:
            self._log(
                "Error rendering %s %s template: %s",
                self._script.name,
                key,
                ex,
                level=logging.ERROR,
            )
            raise _StopScript from ex

    async def _async_delay_step(self):
        """Handle delay."""
        delay = self._get_pos_time_period_template(CONF_DELAY)

        self._step_log(f"delay {delay}")

        delay = delay.total_seconds()
        self._changed()
        trace_set_result(delay=delay, done=False)
        try:
            # Waiting on the stop event lets a stop request cut the delay short;
            # the timeout firing means the full delay elapsed.
            async with async_timeout.timeout(delay):
                await self._stop.wait()
        except asyncio.TimeoutError:
            trace_set_result(delay=delay, done=True)

    async def _async_wait_template_step(self):
        """Handle a wait template."""
        if CONF_TIMEOUT in self._action:
            timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
        else:
            timeout = None

        self._step_log("wait template", timeout)

        self._variables["wait"] = {"remaining": timeout, "completed": False}
        trace_set_result(wait=self._variables["wait"])

        wait_template = self._action[CONF_WAIT_TEMPLATE]
        wait_template.hass = self._hass

        # check if condition already okay
        if condition.async_template(self._hass, wait_template, self._variables, False):
            self._variables["wait"]["completed"] = True
            return

        @callback
        def async_script_wait(entity_id, from_s, to_s):
            """Handle script after template condition is true."""
            wait_var = self._variables["wait"]
            wait_var["remaining"] = to_context.remaining if to_context else timeout
            wait_var["completed"] = True
            done.set()

        to_context = None
        unsub = async_track_template(
            self._hass, wait_template, async_script_wait, self._variables
        )

        self._changed()
        done = asyncio.Event()
        tasks = [
            self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
        ]
        try:
            async with async_timeout.timeout(timeout) as to_context:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except asyncio.TimeoutError as ex:
            self._variables["wait"]["remaining"] = 0.0
            if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
                self._log(_TIMEOUT_MSG)
                trace_set_result(wait=self._variables["wait"], timeout=True)
                raise _StopScript from ex
        finally:
            for task in tasks:
                task.cancel()
            unsub()

    async def _async_run_long_action(self, long_task: asyncio.Task) -> None:
        """Run a long task while monitoring for stop request."""

        async def async_cancel_long_task() -> None:
            # Stop long task and wait for it to finish.
            long_task.cancel()
            with suppress(Exception):
                await long_task

        # Wait for long task while monitoring for a stop request.
        stop_task = self._hass.async_create_task(self._stop.wait())
        try:
            await asyncio.wait(
                {long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
            )
        # If our task is cancelled, then cancel long task, too. Note that if long task
        # is cancelled otherwise the CancelledError exception will not be raised to
        # here due to the call to asyncio.wait(). Rather we'll check for that below.
        except asyncio.CancelledError:
            await async_cancel_long_task()
            raise
        finally:
            stop_task.cancel()

        if long_task.cancelled():
            raise asyncio.CancelledError
        if long_task.done():
            # Propagate any exceptions that occurred.
            long_task.result()
        else:
            # Stopped before long task completed, so cancel it.
            await async_cancel_long_task()

    async def _async_call_service_step(self):
        """Call the service specified in the action."""
        self._step_log("call service")

        params = service.async_prepare_call_from_config(
            self._hass, self._action, self._variables
        )

        running_script = (
            params[CONF_DOMAIN] == "automation"
            and params[CONF_SERVICE] == "trigger"
            or params[CONF_DOMAIN] in ("python_script", "script")
        )
        # If this might start a script then disable the call timeout.
        # Otherwise use the normal service call limit.
        if running_script:
            limit = None
        else:
            limit = SERVICE_CALL_LIMIT

        trace_set_result(params=params, running_script=running_script, limit=limit)
        service_task = self._hass.async_create_task(
            self._hass.services.async_call(
                **params,
                blocking=True,
                context=self._context,
                limit=limit,
            )
        )
        if limit is not None:
            # There is a call limit, so just wait for it to finish.
            await service_task
            return

        await self._async_run_long_action(service_task)

    async def _async_device_step(self):
        """Perform the device automation specified in the action."""
        self._step_log("device automation")
        platform = await device_automation.async_get_device_automation_platform(
            self._hass, self._action[CONF_DOMAIN], "action"
        )
        await platform.async_call_action_from_config(
            self._hass, self._action, self._variables, self._context
        )

    async def _async_scene_step(self):
        """Activate the scene specified in the action."""
        self._step_log("activate scene")
        trace_set_result(scene=self._action[CONF_SCENE])
        await self._hass.services.async_call(
            scene.DOMAIN,
            SERVICE_TURN_ON,
            {ATTR_ENTITY_ID: self._action[CONF_SCENE]},
            blocking=True,
            context=self._context,
        )

    async def _async_event_step(self):
        """Fire an event."""
        # _step_log already prefers the action's alias over the default message
        self._step_log(self._action[CONF_EVENT])
        event_data = {}
        for conf in (CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE):
            if conf not in self._action:
                continue

            try:
                event_data.update(
                    template.render_complex(self._action[conf], self._variables)
                )
            except exceptions.TemplateError as ex:
                # Best effort: the event is still fired without this data
                self._log(
                    "Error rendering event data template: %s", ex, level=logging.ERROR
                )

        trace_set_result(event=self._action[CONF_EVENT], event_data=event_data)
        self._hass.bus.async_fire(
            self._action[CONF_EVENT], event_data, context=self._context
        )

    async def _async_condition_step(self):
        """Test if condition is matching."""
        self._script.last_action = self._action.get(
            CONF_ALIAS, self._action[CONF_CONDITION]
        )
        cond = await self._async_get_condition(self._action)
        try:
            trace_element = trace_stack_top(trace_stack_cv)
            if trace_element:
                trace_element.reuse_by_child = True
            check = cond(self._hass, self._variables)
        except exceptions.ConditionError as ex:
            _LOGGER.warning("Error in 'condition' evaluation:\n%s", ex)
            check = False

        self._log("Test condition %s: %s", self._script.last_action, check)
        trace_update_result(result=check)
        if not check:
            raise _StopScript

    def _test_conditions(self, conditions, name, condition_path=None):
        """Evaluate conditions in order under the given trace path.

        The traced inner function returns False at the first condition that does
        not pass, None when a ConditionError is raised (after logging it) and
        True otherwise. condition_path defaults to name.
        """
        if condition_path is None:
            condition_path = name

        @trace_condition_function
        def traced_test_conditions(hass, variables):
            try:
                with trace_path(condition_path):
                    for idx, cond in enumerate(conditions):
                        with trace_path(str(idx)):
                            if not cond(hass, variables):
                                return False
            except exceptions.ConditionError as ex:
                _LOGGER.warning("Error in '%s[%s]' evaluation: %s", name, idx, ex)
                return None

            return True

        result = traced_test_conditions(self._hass, self._variables)
        return result

    @async_trace_path("repeat")
    async def _async_repeat_step(self):
        """Repeat a sequence."""
        description = self._action.get(CONF_ALIAS, "sequence")
        repeat = self._action[CONF_REPEAT]

        # Remember an outer loop's "repeat" variable so it can be restored
        saved_repeat_vars = self._variables.get("repeat")

        def set_repeat_var(iteration, count=None):
            repeat_vars = {"first": iteration == 1, "index": iteration}
            if count:
                repeat_vars["last"] = iteration == count
            self._variables["repeat"] = repeat_vars

        # pylint: disable=protected-access
        script = self._script._get_repeat_script(self._step)

        async def async_run_sequence(iteration, extra_msg=""):
            self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
            with trace_path("sequence"):
                await self._async_run_script(script)

        if CONF_COUNT in repeat:
            count = repeat[CONF_COUNT]
            if isinstance(count, template.Template):
                try:
                    count = int(count.async_render(self._variables))
                except (exceptions.TemplateError, ValueError) as ex:
                    self._log(
                        "Error rendering %s repeat count template: %s",
                        self._script.name,
                        ex,
                        level=logging.ERROR,
                    )
                    raise _StopScript from ex
            extra_msg = f" of {count}"
            for iteration in range(1, count + 1):
                set_repeat_var(iteration, count)
                await async_run_sequence(iteration, extra_msg)
                if self._stop.is_set():
                    break

        elif CONF_WHILE in repeat:
            conditions = [
                await self._async_get_condition(config) for config in repeat[CONF_WHILE]
            ]
            for iteration in itertools.count(1):
                set_repeat_var(iteration)
                try:
                    if self._stop.is_set():
                        break
                    if not self._test_conditions(conditions, "while"):
                        break
                except exceptions.ConditionError as ex:
                    _LOGGER.warning("Error in 'while' evaluation:\n%s", ex)
                    break

                await async_run_sequence(iteration)

        elif CONF_UNTIL in repeat:
            conditions = [
                await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
            ]
            for iteration in itertools.count(1):
                set_repeat_var(iteration)
                await async_run_sequence(iteration)
                try:
                    if self._stop.is_set():
                        break
                    if self._test_conditions(conditions, "until") in [True, None]:
                        break
                except exceptions.ConditionError as ex:
                    _LOGGER.warning("Error in 'until' evaluation:\n%s", ex)
                    break

        if saved_repeat_vars:
            self._variables["repeat"] = saved_repeat_vars
        else:
            # "repeat" is never set when a count repeat runs zero iterations,
            # so a plain del would raise KeyError here.
            self._variables.pop("repeat", None)

    async def _async_choose_step(self) -> None:
        """Choose a sequence."""
        # pylint: disable=protected-access
        choose_data = await self._script._async_get_choose_data(self._step)

        with trace_path("choose"):
            for idx, (conditions, script) in enumerate(choose_data["choices"]):
                with trace_path(str(idx)):
                    try:
                        if self._test_conditions(conditions, "choose", "conditions"):
                            trace_set_result(choice=idx)
                            with trace_path("sequence"):
                                await self._async_run_script(script)
                                return
                    except exceptions.ConditionError as ex:
                        _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)

        if choose_data["default"] is not None:
            trace_set_result(choice="default")
            with trace_path(["default"]):
                await self._async_run_script(choose_data["default"])

    async def _async_wait_for_trigger_step(self):
        """Wait for a trigger event."""
        if CONF_TIMEOUT in self._action:
            timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
        else:
            timeout = None

        self._step_log("wait for trigger", timeout)

        # Snapshot taken before "wait" is (re)set; passed to the triggers below
        variables = {**self._variables}
        self._variables["wait"] = {"remaining": timeout, "trigger": None}
        trace_set_result(wait=self._variables["wait"])

        done = asyncio.Event()

        async def async_done(variables, context=None):
            wait_var = self._variables["wait"]
            wait_var["remaining"] = to_context.remaining if to_context else timeout
            wait_var["trigger"] = variables["trigger"]
            done.set()

        def log_cb(level, msg, **kwargs):
            self._log(msg, level=level, **kwargs)

        to_context = None
        remove_triggers = await async_initialize_triggers(
            self._hass,
            self._action[CONF_WAIT_FOR_TRIGGER],
            async_done,
            self._script.domain,
            self._script.name,
            log_cb,
            variables=variables,
        )
        if not remove_triggers:
            return

        self._changed()
        tasks = [
            self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
        ]
        try:
            async with async_timeout.timeout(timeout) as to_context:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except asyncio.TimeoutError as ex:
            self._variables["wait"]["remaining"] = 0.0
            if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
                self._log(_TIMEOUT_MSG)
                trace_set_result(wait=self._variables["wait"], timeout=True)
                raise _StopScript from ex
        finally:
            for task in tasks:
                task.cancel()
            remove_triggers()

    async def _async_variables_step(self):
        """Set a variable value."""
        self._step_log("setting variables")
        self._variables = self._action[CONF_VARIABLES].async_render(
            self._hass, self._variables, render_as_defaults=False
        )

    async def _async_run_script(self, script: Script) -> None:
        """Execute a script."""
        await self._async_run_long_action(
            self._hass.async_create_task(
                script.async_run(self._variables, self._context)
            )
        )


class _QueuedScriptRun(_ScriptRun):
    """Manage queued Script sequence run."""

    lock_acquired = False

    async def async_run(self) -> None:
        """Run script."""
        # Wait for previous run, if any, to finish by attempting to acquire the script's
        # shared lock. At the same time monitor if we've been told to stop.
        lock_task = self._hass.async_create_task(
            self._script._queue_lck.acquire()  # pylint: disable=protected-access
        )
        stop_task = self._hass.async_create_task(self._stop.wait())
        try:
            await asyncio.wait(
                {lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            lock_task.cancel()
            self._finish()
            raise
        finally:
            stop_task.cancel()
        self.lock_acquired = lock_task.done() and not lock_task.cancelled()

        # If we've been told to stop, then just finish up. Otherwise, we've acquired the
        # lock so we can go ahead and start the run.
        if self._stop.is_set():
            self._finish()
        else:
            await super().async_run()

    def _finish(self) -> None:
        """Release the queue lock if this run holds it, then finish the run."""
        # pylint: disable=protected-access
        if self.lock_acquired:
            self._script._queue_lck.release()
            self.lock_acquired = False
        super()._finish()


async def _async_stop_scripts_after_shutdown(hass, point_in_time):
    """Stop running Script objects started after shutdown."""
    running_scripts = [
        script for script in hass.data[DATA_SCRIPTS] if script["instance"].is_running
    ]
    if running_scripts:
        names = ", ".join([script["instance"].name for script in running_scripts])
        _LOGGER.warning("Stopping scripts running too long after shutdown: %s", names)
        await asyncio.gather(
            *(
                script["instance"].async_stop(update_state=False)
                for script in running_scripts
            )
        )


async def _async_stop_scripts_at_shutdown(hass, event):
    """Stop running Script objects started before shutdown."""
    # Scripts still running after _SHUTDOWN_MAX_WAIT seconds are stopped as well
    async_call_later(
        hass, _SHUTDOWN_MAX_WAIT, partial(_async_stop_scripts_after_shutdown, hass)
    )

    running_scripts = [
        script
        for script in hass.data[DATA_SCRIPTS]
        if script["instance"].is_running and script["started_before_shutdown"]
    ]
    if running_scripts:
        names = ", ".join([script["instance"].name for script in running_scripts])
        _LOGGER.debug("Stopping scripts running at shutdown: %s", names)
        await asyncio.gather(
            *(script["instance"].async_stop() for script in running_scripts)
        )


_VarsType = Union[Dict[str, Any], MappingProxyType]


def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
    """Extract referenced IDs.

    Adds the ID or IDs stored under key in data to found. Empty data, a missing
    value and template values are ignored.
    """
    if not data:
        return

    item_ids = data.get(key)

    if item_ids is None or isinstance(item_ids, template.Template):
        return

    if isinstance(item_ids, str):
        found.add(item_ids)
    else:
        for item_id in item_ids:
            found.add(item_id)


class _ChooseData(TypedDict):
    choices: list[tuple[list[ConditionCheckerType], Script]]
    default: Script | None


class Script:
    """Representation of a script."""

    def __init__(
        self,
        hass: HomeAssistant,
        sequence: Sequence[dict[str, Any]],
        name: str,
        domain: str,
        *,
        # Used in the "Running %s" log message
        running_description: str | None = None,
        change_listener: Callable[..., Any] | None = None,
        script_mode: str = DEFAULT_SCRIPT_MODE,
        max_runs: int = DEFAULT_MAX,
        max_exceeded: str = DEFAULT_MAX_EXCEEDED,
        logger: logging.Logger | None = None,
        log_exceptions: bool = True,
        top_level: bool = True,
        variables: ScriptVariables | None = None,
    ) -> None:
        """Initialize the script."""
        if not (all_scripts := hass.data.get(DATA_SCRIPTS)):
            all_scripts = hass.data[DATA_SCRIPTS] = []
            hass.bus.async_listen_once(
                EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
            )
        self._top_level = top_level
        if top_level:
            all_scripts.append(
                {"instance": self, "started_before_shutdown": not hass.is_stopping}
            )
        if DATA_SCRIPT_BREAKPOINTS not in hass.data:
            hass.data[DATA_SCRIPT_BREAKPOINTS] = {}

        self._hass = hass
        self.sequence = sequence
        template.attach(hass, self.sequence)
        self.name = name
        self.domain = domain
        self.running_description = running_description or f"{domain} script"
        self._change_listener = change_listener
        self._change_listener_job = (
            None if change_listener is None else HassJob(change_listener)
        )

        self.script_mode = script_mode
        self._set_logger(logger)
        self._log_exceptions = log_exceptions

        self.last_action = None
        self.last_triggered: datetime | None = None

        self._runs: list[_ScriptRun] = []
        self.max_runs = max_runs
        self._max_exceeded = max_exceeded
        if script_mode == SCRIPT_MODE_QUEUED:
            self._queue_lck = asyncio.Lock()
        # Keyed by template source or by a frozenset of the config items,
        # see _async_get_condition.
        self._config_cache: dict[Any, Callable[..., bool]] = {}
        self._repeat_script: dict[int, Script] = {}
        self._choose_data: dict[int, _ChooseData] = {}
        self._referenced_entities: set[str] | None = None
        self._referenced_devices: set[str] | None = None
        self._referenced_areas: set[str] | None = None
        self.variables = variables
        self._variables_dynamic = template.is_complex(variables)
        if self._variables_dynamic:
            template.attach(hass, variables)

    @property
    def change_listener(self) -> Callable[..., Any] | None:
        """Return the change_listener."""
        return self._change_listener

    @change_listener.setter
    def change_listener(self, change_listener: Callable[..., Any]) -> None:
        """Update the change_listener."""
        self._change_listener = change_listener
        if (
            self._change_listener_job is None
            or change_listener != self._change_listener_job.target
        ):
            self._change_listener_job = HassJob(change_listener)

    def _set_logger(self, logger: logging.Logger | None = None) -> None:
        """Use the given logger, or one named after this module and script."""
        if logger:
            self._logger = logger
        else:
            self._logger = logging.getLogger(f"{__name__}.{slugify(self.name)}")

    def update_logger(self, logger: logging.Logger | None = None) -> None:
        """Update logger."""
        self._set_logger(logger)
        for script in self._repeat_script.values():
            script.update_logger(self._logger)
        for choose_data in self._choose_data.values():
            for _, script in choose_data["choices"]:
                script.update_logger(self._logger)
            if choose_data["default"] is not None:
                choose_data["default"].update_logger(self._logger)

    def _changed(self) -> None:
        """Run the change listener job, if one is set."""
        if self._change_listener_job:
            self._hass.async_run_hass_job(self._change_listener_job)

    def _chain_change_listener(self, sub_script: Script) -> None:
        """Propagate a running sub-script's last action and change to this script."""
        if sub_script.is_running:
            self.last_action = sub_script.last_action
            self._changed()

    @property
    def is_running(self) -> bool:
        """Return true if script is on."""
        return len(self._runs) > 0

    @property
    def runs(self) -> int:
        """Return the number of current runs."""
        return len(self._runs)

    @property
    def supports_max(self) -> bool:
        """Return true if the current mode support max."""
        return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)

    @property
    def referenced_areas(self):
        """Return a set of referenced areas."""
        if self._referenced_areas is not None:
            return self._referenced_areas

        self._referenced_areas: set[str] = set()
        Script._find_referenced_areas(self._referenced_areas, self.sequence)
        return self._referenced_areas

    @staticmethod
    def _find_referenced_areas(referenced, sequence):
        """Add area IDs referenced by the sequence, including choose branches."""
        for step in sequence:
            action = cv.determine_script_action(step)

            if action == cv.SCRIPT_ACTION_CALL_SERVICE:
                for data in (
                    step.get(CONF_TARGET),
                    step.get(service.CONF_SERVICE_DATA),
                    step.get(service.CONF_SERVICE_DATA_TEMPLATE),
                ):
                    _referenced_extract_ids(data, ATTR_AREA_ID, referenced)

            elif action == cv.SCRIPT_ACTION_CHOOSE:
                for choice in step[CONF_CHOOSE]:
                    Script._find_referenced_areas(referenced, choice[CONF_SEQUENCE])
                if CONF_DEFAULT in step:
                    Script._find_referenced_areas(referenced, step[CONF_DEFAULT])

    @property
    def referenced_devices(self):
        """Return a set of referenced devices."""
        if self._referenced_devices is not None:
            return self._referenced_devices

        self._referenced_devices: set[str] = set()
        Script._find_referenced_devices(self._referenced_devices, self.sequence)
        return self._referenced_devices

    @staticmethod
    def _find_referenced_devices(referenced, sequence):
        """Add device IDs referenced by the sequence, including choose branches."""
        for step in sequence:
            action = cv.determine_script_action(step)

            if action == cv.SCRIPT_ACTION_CALL_SERVICE:
                for data in (
                    step.get(CONF_TARGET),
                    step.get(service.CONF_SERVICE_DATA),
                    step.get(service.CONF_SERVICE_DATA_TEMPLATE),
                ):
                    _referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)

            elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
                referenced |= condition.async_extract_devices(step)

            elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
                referenced.add(step[CONF_DEVICE_ID])

            elif action == cv.SCRIPT_ACTION_CHOOSE:
                for choice in step[CONF_CHOOSE]:
                    for cond in choice[CONF_CONDITIONS]:
                        referenced |= condition.async_extract_devices(cond)
                    Script._find_referenced_devices(referenced, choice[CONF_SEQUENCE])
                if CONF_DEFAULT in step:
                    Script._find_referenced_devices(referenced, step[CONF_DEFAULT])

    @property
    def referenced_entities(self):
        """Return a set of referenced entities."""
        if self._referenced_entities is not None:
            return self._referenced_entities

        self._referenced_entities: set[str] = set()
        Script._find_referenced_entities(self._referenced_entities, self.sequence)
        return self._referenced_entities

    @staticmethod
    def _find_referenced_entities(referenced, sequence):
        """Add entity IDs referenced by the sequence, including choose branches."""
        for step in sequence:
            action = cv.determine_script_action(step)

            if action == cv.SCRIPT_ACTION_CALL_SERVICE:
                for data in (
                    step,
                    step.get(CONF_TARGET),
                    step.get(service.CONF_SERVICE_DATA),
                    step.get(service.CONF_SERVICE_DATA_TEMPLATE),
                ):
                    _referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)

            elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
                referenced |= condition.async_extract_entities(step)

            elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE:
                referenced.add(step[CONF_SCENE])

            elif action == cv.SCRIPT_ACTION_CHOOSE:
                for choice in step[CONF_CHOOSE]:
                    for cond in choice[CONF_CONDITIONS]:
                        referenced |= condition.async_extract_entities(cond)
                    Script._find_referenced_entities(referenced, choice[CONF_SEQUENCE])
                if CONF_DEFAULT in step:
                    Script._find_referenced_entities(referenced, step[CONF_DEFAULT])

    def run(
        self, variables: _VarsType | None = None, context: Context | None = None
    ) -> None:
        """Run script."""
        asyncio.run_coroutine_threadsafe(
            self.async_run(variables, context), self._hass.loop
        ).result()

    async def async_run(
        self,
        run_variables: _VarsType | None = None,
        context: Context | None = None,
        started_action: Callable[..., Any] | None = None,
    ) -> None:
        """Run script."""
        if context is None:
            self._log(
                "Running script requires passing in a context", level=logging.WARNING
            )
            context = Context()

        if self.is_running:
            if self.script_mode == SCRIPT_MODE_SINGLE:
                if self._max_exceeded != "SILENT":
                    self._log("Already running", level=LOGSEVERITY[self._max_exceeded])
                script_execution_set("failed_single")
                return
            if self.script_mode != SCRIPT_MODE_RESTART and self.runs == self.max_runs:
                if self._max_exceeded != "SILENT":
                    self._log(
                        "Maximum number of runs exceeded",
                        level=LOGSEVERITY[self._max_exceeded],
                    )
                script_execution_set("failed_max_runs")
                return

        # If this is a top level Script then make a copy of the variables in case they
        # are read-only, but more importantly, so as not to leak any variables created
        # during the run back to the caller.
        if self._top_level:
            if self.variables:
                try:
                    variables = self.variables.async_render(
                        self._hass,
                        run_variables,
                    )
                except template.TemplateError as err:
                    self._log("Error rendering variables: %s", err, level=logging.ERROR)
                    raise
            elif run_variables:
                variables = dict(run_variables)
            else:
                variables = {}

            variables["context"] = context
        else:
            variables = cast(dict, run_variables)

        if self.script_mode != SCRIPT_MODE_QUEUED:
            cls = _ScriptRun
        else:
            cls = _QueuedScriptRun
        run = cls(
            self._hass, self, cast(dict, variables), context, self._log_exceptions
        )
        self._runs.append(run)
        if self.script_mode == SCRIPT_MODE_RESTART:
            # When script mode is SCRIPT_MODE_RESTART, first add the new run and then
            # stop any other runs. If we stop other runs first, self.is_running will
            # return false after the other script runs were stopped until our task
            # resumes running.
            self._log("Restarting")
            await self.async_stop(update_state=False, spare=run)

        if started_action:
            self._hass.async_run_job(started_action)
        self.last_triggered = utcnow()
        self._changed()

        try:
            await asyncio.shield(run.async_run())
        except asyncio.CancelledError:
            await run.async_stop()
            self._changed()
            raise

    async def _async_stop(
        self, aws: list[asyncio.Task], update_state: bool, spare: _ScriptRun | None
    ) -> None:
        """Wait for the given stop tasks and optionally signal a change."""
        await asyncio.wait(aws)
        if update_state:
            self._changed()

    async def async_stop(
        self, update_state: bool = True, spare: _ScriptRun | None = None
    ) -> None:
        """Stop running script."""
        # Collect a list of script runs to stop. This must be done before calling
        # asyncio.shield as asyncio.shield yields to the event loop, which would cause
        # us to wait for script runs added after the call to async_stop.
        aws = [
            asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
        ]
        if not aws:
            return
        await asyncio.shield(self._async_stop(aws, update_state, spare))

    async def _async_get_condition(self, config):
        """Return the condition checker for config, building it once per config."""
        if isinstance(config, template.Template):
            config_cache_key = config.template
        else:
            config_cache_key = frozenset((k, str(v)) for k, v in config.items())
        if not (cond := self._config_cache.get(config_cache_key)):
            cond = await condition.async_from_config(self._hass, config, False)
            self._config_cache[config_cache_key] = cond
        return cond

    def _prep_repeat_script(self, step: int) -> Script:
        """Create the sub-script that runs the repeat sequence of the given step."""
        action = self.sequence[step]
        step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
        sub_script = Script(
            self._hass,
            action[CONF_REPEAT][CONF_SEQUENCE],
            f"{self.name}: {step_name}",
            self.domain,
            running_description=self.running_description,
            script_mode=SCRIPT_MODE_PARALLEL,
            max_runs=self.max_runs,
            logger=self._logger,
            top_level=False,
        )
        sub_script.change_listener = partial(self._chain_change_listener, sub_script)
        return sub_script

    def _get_repeat_script(self, step: int) -> Script:
        """Return the repeat sub-script for the step, creating it on first use."""
        if not (sub_script := self._repeat_script.get(step)):
            sub_script = self._prep_repeat_script(step)
            self._repeat_script[step] = sub_script
        return sub_script

    async def _async_prep_choose_data(self, step: int) -> _ChooseData:
        action = self.sequence[step]
        step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
        choices = []
        for idx, choice in enumerate(action[CONF_CHOOSE], start=1):
            conditions = [
                await self._async_get_condition(config)
                for config in choice.get(CONF_CONDITIONS, [])
            ]
            choice_name = choice.get(CONF_ALIAS, f"choice {idx}")
            sub_script = Script(
                self._hass,
                choice[CONF_SEQUENCE],
                f"{self.name}: {step_name}: {choice_name}",
                self.domain,
                running_description=self.running_description,
                script_mode=SCRIPT_MODE_PARALLEL,
                max_runs=self.max_runs,
                logger=self._logger,
                top_level=False,
            )
            sub_script.change_listener = partial(
                self._chain_change_listener, sub_script
            )
choices.append((conditions, sub_script)) default_script: Script | None if CONF_DEFAULT in action: default_script = Script( self._hass, action[CONF_DEFAULT], f"{self.name}: {step_name}: default", self.domain, running_description=self.running_description, script_mode=SCRIPT_MODE_PARALLEL, max_runs=self.max_runs, logger=self._logger, top_level=False, ) default_script.change_listener = partial( self._chain_change_listener, default_script ) else: default_script = None return {"choices": choices, "default": default_script} async def _async_get_choose_data(self, step: int) -> _ChooseData: if not (choose_data := self._choose_data.get(step)): choose_data = await self._async_prep_choose_data(step) self._choose_data[step] = choose_data return choose_data def _log( self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any ) -> None: msg = f"%s: {msg}" args = (self.name, *args) if level == _LOG_EXCEPTION: self._logger.exception(msg, *args, **kwargs) else: self._logger.log(level, msg, *args, **kwargs) @callback def breakpoint_clear(hass, key, run_id, node): """Clear a breakpoint.""" run_id = run_id or RUN_ID_ANY breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] if key not in breakpoints or run_id not in breakpoints[key]: return breakpoints[key][run_id].discard(node) @callback def breakpoint_clear_all(hass: HomeAssistant) -> None: """Clear all breakpoints.""" hass.data[DATA_SCRIPT_BREAKPOINTS] = {} @callback def breakpoint_set(hass, key, run_id, node): """Set a breakpoint.""" run_id = run_id or RUN_ID_ANY breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] if key not in breakpoints: breakpoints[key] = {} if run_id not in breakpoints[key]: breakpoints[key][run_id] = set() breakpoints[key][run_id].add(node) @callback def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]: """List breakpoints.""" breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] return [ {"key": key, "run_id": run_id, "node": node} for key in breakpoints for run_id in breakpoints[key] for node in 
breakpoints[key][run_id] ] @callback def debug_continue(hass, key, run_id): """Continue execution of a halted script.""" # Clear any wildcard breakpoint breakpoint_clear(hass, key, run_id, NODE_ANY) signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id) async_dispatcher_send(hass, signal, "continue") @callback def debug_step(hass, key, run_id): """Single step a halted script.""" # Set a wildcard breakpoint breakpoint_set(hass, key, run_id, NODE_ANY) signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id) async_dispatcher_send(hass, signal, "continue") @callback def debug_stop(hass, key, run_id): """Stop execution of a running or halted script.""" signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id) async_dispatcher_send(hass, signal, "stop")