"""Helpers to execute scripts.""" from abc import ABC, abstractmethod import asyncio from contextlib import suppress from datetime import datetime from itertools import islice import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast from async_timeout import timeout import voluptuous as vol from homeassistant import exceptions import homeassistant.components.device_automation as device_automation import homeassistant.components.scene as scene from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ALIAS, CONF_CONDITION, CONF_CONTINUE_ON_TIMEOUT, CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, CONF_EVENT, CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, CONF_SCENE, CONF_TIMEOUT, CONF_WAIT_TEMPLATE, SERVICE_TURN_OFF, SERVICE_TURN_ON, ) from homeassistant.core import ( CALLBACK_TYPE, SERVICE_CALL_LIMIT, Context, HomeAssistant, callback, ) from homeassistant.helpers import ( condition, config_validation as cv, template as template, ) from homeassistant.helpers.event import ( async_track_point_in_utc_time, async_track_template, ) from homeassistant.helpers.service import ( CONF_SERVICE_DATA, async_prepare_call_from_config, ) from homeassistant.helpers.typing import ConfigType from homeassistant.util import slugify from homeassistant.util.dt import utcnow # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs SCRIPT_MODE_ERROR = "error" SCRIPT_MODE_IGNORE = "ignore" SCRIPT_MODE_LEGACY = "legacy" SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_QUEUE = "queue" SCRIPT_MODE_RESTART = "restart" SCRIPT_MODE_CHOICES = [ SCRIPT_MODE_ERROR, SCRIPT_MODE_IGNORE, SCRIPT_MODE_LEGACY, SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUE, SCRIPT_MODE_RESTART, ] DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY DEFAULT_QUEUE_MAX = 10 _LOG_EXCEPTION = logging.ERROR + 1 _TIMEOUT_MSG = "Timeout reached, abort script." async def async_validate_action_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: """Validate config.""" action_type = cv.determine_script_action(config) if action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: platform = await device_automation.async_get_device_automation_platform( hass, config[CONF_DOMAIN], "action" ) config = platform.ACTION_SCHEMA(config) # type: ignore if ( action_type == cv.SCRIPT_ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device" ): platform = await device_automation.async_get_device_automation_platform( hass, config[CONF_DOMAIN], "condition" ) config = platform.CONDITION_SCHEMA(config) # type: ignore return config class _StopScript(Exception): """Throw if script needs to stop.""" class _SuspendScript(Exception): """Throw if script needs to suspend.""" class AlreadyRunning(exceptions.HomeAssistantError): """Throw if script already running and user wants error.""" class QueueFull(exceptions.HomeAssistantError): """Throw if script already running, user wants new run queued, but queue is full.""" class _ScriptRunBase(ABC): """Common data & methods for managing Script sequence run.""" def __init__( self, hass: HomeAssistant, script: "Script", variables: Optional[Sequence], context: Optional[Context], log_exceptions: bool, ) -> None: self._hass = hass self._script = script self._variables = variables self._context = context self._log_exceptions = log_exceptions self._step = -1 self._action: Optional[Dict[str, Any]] = None def _changed(self): self._script._changed() # pylint: disable=protected-access @property def _config_cache(self): return self._script._config_cache # pylint: disable=protected-access @abstractmethod async def async_run(self) -> None: """Run script.""" async def _async_step(self, log_exceptions): try: await getattr( self, f"_async_{cv.determine_script_action(self._action)}_step" )() except Exception as ex: if not isinstance( ex, (_SuspendScript, _StopScript, asyncio.CancelledError) ) and (self._log_exceptions or log_exceptions): self._log_exception(ex) raise @abstractmethod async def async_stop(self) -> None: """Stop script run.""" def _log_exception(self, exception): action_type = cv.determine_script_action(self._action) error = str(exception) level = logging.ERROR if isinstance(exception, vol.Invalid): error_desc = "Invalid data" elif isinstance(exception, exceptions.TemplateError): error_desc = "Error rendering template" elif isinstance(exception, exceptions.Unauthorized): error_desc = "Unauthorized" elif isinstance(exception, exceptions.ServiceNotFound): error_desc = "Service not found" elif isinstance(exception, AlreadyRunning): error_desc = "Already running" elif isinstance(exception, QueueFull): error_desc = "Run queue is full" else: error_desc = "Unexpected error" level = _LOG_EXCEPTION self._log( "Error executing script. %s for %s at pos %s: %s", error_desc, action_type, self._step + 1, error, level=level, ) @abstractmethod async def _async_delay_step(self): """Handle delay.""" def _prep_delay_step(self): try: delay = vol.All(cv.time_period, cv.positive_timedelta)( template.render_complex(self._action[CONF_DELAY], self._variables) ) except (exceptions.TemplateError, vol.Invalid) as ex: self._log( "Error rendering %s delay template: %s", self._script.name, ex, level=logging.ERROR, ) raise _StopScript self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}") self._log("Executing step %s", self._script.last_action) return delay @abstractmethod async def _async_wait_template_step(self): """Handle a wait template.""" def _prep_wait_template_step(self, async_script_wait): wait_template = self._action[CONF_WAIT_TEMPLATE] wait_template.hass = self._hass self._script.last_action = self._action.get(CONF_ALIAS, "wait template") self._log("Executing step %s", self._script.last_action) # check if condition already okay if condition.async_template(self._hass, wait_template, self._variables): return None return async_track_template( self._hass, wait_template, async_script_wait, self._variables ) @abstractmethod async def _async_call_service_step(self): """Call the service specified in the action.""" def _prep_call_service_step(self): self._script.last_action = self._action.get(CONF_ALIAS, "call service") self._log("Executing step %s", self._script.last_action) return async_prepare_call_from_config(self._hass, self._action, self._variables) async def _async_device_step(self): """Perform the device automation specified in the action.""" self._script.last_action = self._action.get(CONF_ALIAS, "device automation") self._log("Executing step %s", self._script.last_action) platform = await device_automation.async_get_device_automation_platform( self._hass, self._action[CONF_DOMAIN], "action" ) await platform.async_call_action_from_config( self._hass, self._action, self._variables, self._context ) async def _async_scene_step(self): """Activate the scene specified in the action.""" self._script.last_action = self._action.get(CONF_ALIAS, "activate scene") self._log("Executing step %s", self._script.last_action) await self._hass.services.async_call( scene.DOMAIN, SERVICE_TURN_ON, {ATTR_ENTITY_ID: self._action[CONF_SCENE]}, blocking=True, context=self._context, ) async def _async_event_step(self): """Fire an event.""" self._script.last_action = self._action.get( CONF_ALIAS, self._action[CONF_EVENT] ) self._log("Executing step %s", self._script.last_action) event_data = dict(self._action.get(CONF_EVENT_DATA, {})) if CONF_EVENT_DATA_TEMPLATE in self._action: try: event_data.update( template.render_complex( self._action[CONF_EVENT_DATA_TEMPLATE], self._variables ) ) except exceptions.TemplateError as ex: self._log( "Error rendering event data template: %s", ex, level=logging.ERROR ) self._hass.bus.async_fire( self._action[CONF_EVENT], event_data, context=self._context ) async def _async_condition_step(self): """Test if condition is matching.""" config_cache_key = frozenset((k, str(v)) for k, v in self._action.items()) config = self._config_cache.get(config_cache_key) if not config: config = await condition.async_from_config(self._hass, self._action, False) self._config_cache[config_cache_key] = config self._script.last_action = self._action.get( CONF_ALIAS, self._action[CONF_CONDITION] ) check = config(self._hass, self._variables) self._log("Test condition %s: %s", self._script.last_action, check) if not check: raise _StopScript def _log(self, msg, *args, level=logging.INFO): self._script._log(msg, *args, level=level) # pylint: disable=protected-access class _ScriptRun(_ScriptRunBase): """Manage Script sequence run.""" def __init__( self, hass: HomeAssistant, script: "Script", variables: Optional[Sequence], context: Optional[Context], log_exceptions: bool, ) -> None: super().__init__(hass, script, variables, context, log_exceptions) self._stop = asyncio.Event() self._stopped = asyncio.Event() def _changed(self): if not self._stop.is_set(): super()._changed() async def async_run(self) -> None: """Run script.""" try: if self._stop.is_set(): return self._script.last_triggered = utcnow() self._changed() self._log("Running script") for self._step, self._action in enumerate(self._script.sequence): if self._stop.is_set(): break await self._async_step(log_exceptions=False) except _StopScript: pass finally: self._finish() def _finish(self): self._script._runs.remove(self) # pylint: disable=protected-access if not self._script.is_running: self._script.last_action = None self._changed() self._stopped.set() async def async_stop(self) -> None: """Stop script run.""" self._stop.set() await self._stopped.wait() async def _async_delay_step(self): """Handle delay.""" delay = self._prep_delay_step().total_seconds() self._changed() try: async with timeout(delay): await self._stop.wait() except asyncio.TimeoutError: pass async def _async_wait_template_step(self): """Handle a wait template.""" @callback def async_script_wait(entity_id, from_s, to_s): """Handle script after template condition is true.""" done.set() unsub = self._prep_wait_template_step(async_script_wait) if not unsub: return self._changed() try: delay = self._action[CONF_TIMEOUT].total_seconds() except KeyError: delay = None done = asyncio.Event() try: async with timeout(delay): _, pending = await asyncio.wait( {self._stop.wait(), done.wait()}, return_when=asyncio.FIRST_COMPLETED, ) for pending_task in pending: pending_task.cancel() except asyncio.TimeoutError: if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) raise _StopScript finally: unsub() async def _async_call_service_step(self): """Call the service specified in the action.""" domain, service, service_data = self._prep_call_service_step() # If this might start a script then disable the call timeout. # Otherwise use the normal service call limit. if domain == "script" and service != SERVICE_TURN_OFF: limit = None else: limit = SERVICE_CALL_LIMIT coro = self._hass.services.async_call( domain, service, service_data, blocking=True, context=self._context, limit=limit, ) if limit is not None: # There is a call limit, so just wait for it to finish. await coro return # No call limit (i.e., potentially starting one or more fully blocking scripts) # so watch for a stop request. done, pending = await asyncio.wait( {self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED, ) # Note that cancelling the service call, if it has not yet returned, will also # stop any non-background script runs that it may have started. for pending_task in pending: pending_task.cancel() # Propagate any exceptions that might have happened. for done_task in done: done_task.result() class _QueuedScriptRun(_ScriptRun): """Manage queued Script sequence run.""" lock_acquired = False async def async_run(self) -> None: """Run script.""" # Wait for previous run, if any, to finish by attempting to acquire the script's # shared lock. At the same time monitor if we've been told to stop. lock_task = self._hass.async_create_task( self._script._queue_lck.acquire() # pylint: disable=protected-access ) done, pending = await asyncio.wait( {self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED ) for pending_task in pending: pending_task.cancel() self.lock_acquired = lock_task in done # If we've been told to stop, then just finish up. Otherwise, we've acquired the # lock so we can go ahead and start the run. if self._stop.is_set(): self._finish() else: await super().async_run() def _finish(self): # pylint: disable=protected-access self._script._queue_len -= 1 if self.lock_acquired: self._script._queue_lck.release() self.lock_acquired = False super()._finish() class _LegacyScriptRun(_ScriptRunBase): """Manage legacy Script sequence run.""" def __init__( self, hass: HomeAssistant, script: "Script", variables: Optional[Sequence], context: Optional[Context], log_exceptions: bool, shared: Optional["_LegacyScriptRun"], ) -> None: super().__init__(hass, script, variables, context, log_exceptions) if shared: self._shared = shared else: # To implement legacy behavior we need to share the following "run state" # amongst all runs, so it will only exist in the first instantiation of # concurrent runs, and the rest will use it, too. self._current = -1 self._async_listeners: List[CALLBACK_TYPE] = [] self._shared = self @property def _cur(self): return self._shared._current # pylint: disable=protected-access @_cur.setter def _cur(self, value): self._shared._current = value # pylint: disable=protected-access @property def _async_listener(self): return self._shared._async_listeners # pylint: disable=protected-access async def async_run(self) -> None: """Run script.""" await self._async_run() async def _async_run(self, propagate_exceptions=True): if self._cur == -1: self._script.last_triggered = utcnow() self._log("Running script") self._cur = 0 # Unregister callback if we were in a delay or wait but turn on is # called again. In that case we just continue execution. self._async_remove_listener() suspended = False try: for self._step, self._action in islice( enumerate(self._script.sequence), self._cur, None ): await self._async_step(log_exceptions=not propagate_exceptions) except _StopScript: pass except _SuspendScript: # Store next step to take and notify change listeners self._cur = self._step + 1 suspended = True return except Exception: # pylint: disable=broad-except if propagate_exceptions: raise finally: _cur_was = self._cur if not suspended: self._script.last_action = None await self.async_stop() if _cur_was != -1: self._changed() async def async_stop(self) -> None: """Stop script run.""" if self._cur == -1: return self._cur = -1 self._async_remove_listener() self._script._runs.clear() # pylint: disable=protected-access async def _async_delay_step(self): """Handle delay.""" delay = self._prep_delay_step() @callback def async_script_delay(now): """Handle delay.""" with suppress(ValueError): self._async_listener.remove(unsub) self._hass.async_create_task(self._async_run(False)) unsub = async_track_point_in_utc_time( self._hass, async_script_delay, utcnow() + delay ) self._async_listener.append(unsub) raise _SuspendScript async def _async_wait_template_step(self): """Handle a wait template.""" @callback def async_script_wait(entity_id, from_s, to_s): """Handle script after template condition is true.""" self._async_remove_listener() self._hass.async_create_task(self._async_run(False)) @callback def async_script_timeout(now): """Call after timeout has expired.""" with suppress(ValueError): self._async_listener.remove(unsub_timeout) # Check if we want to continue to execute # the script after the timeout if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._hass.async_create_task(self._async_run(False)) else: self._log(_TIMEOUT_MSG) self._hass.async_create_task(self.async_stop()) unsub_wait = self._prep_wait_template_step(async_script_wait) if not unsub_wait: return self._async_listener.append(unsub_wait) if CONF_TIMEOUT in self._action: unsub_timeout = async_track_point_in_utc_time( self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT] ) self._async_listener.append(unsub_timeout) raise _SuspendScript async def _async_call_service_step(self): """Call the service specified in the action.""" await self._hass.services.async_call( *self._prep_call_service_step(), blocking=True, context=self._context ) def _async_remove_listener(self): """Remove listeners, if any.""" for unsub in self._async_listener: unsub() self._async_listener.clear() class Script: """Representation of a script.""" def __init__( self, hass: HomeAssistant, sequence: Sequence[Dict[str, Any]], name: Optional[str] = None, change_listener: Optional[Callable[..., Any]] = None, script_mode: str = DEFAULT_SCRIPT_MODE, queue_max: int = DEFAULT_QUEUE_MAX, logger: Optional[logging.Logger] = None, log_exceptions: bool = True, ) -> None: """Initialize the script.""" self._hass = hass self.sequence = sequence template.attach(hass, self.sequence) self.name = name self.change_listener = change_listener self._script_mode = script_mode if logger: self._logger = logger else: logger_name = __name__ if name: logger_name = ".".join([logger_name, slugify(name)]) self._logger = logging.getLogger(logger_name) self._log_exceptions = log_exceptions self.last_action = None self.last_triggered: Optional[datetime] = None self.can_cancel = not self.is_legacy or any( CONF_DELAY in action or CONF_WAIT_TEMPLATE in action for action in self.sequence ) self._runs: List[_ScriptRunBase] = [] if script_mode == SCRIPT_MODE_QUEUE: self._queue_max = queue_max self._queue_len = 0 self._queue_lck = asyncio.Lock() self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None def _changed(self): if self.change_listener: self._hass.async_run_job(self.change_listener) @property def is_running(self) -> bool: """Return true if script is on.""" return len(self._runs) > 0 @property def is_legacy(self) -> bool: """Return if using legacy mode.""" return self._script_mode == SCRIPT_MODE_LEGACY @property def referenced_devices(self): """Return a set of referenced devices.""" if self._referenced_devices is not None: return self._referenced_devices referenced = set() for step in self.sequence: action = cv.determine_script_action(step) if action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_devices(step) elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: referenced.add(step[CONF_DEVICE_ID]) self._referenced_devices = referenced return referenced @property def referenced_entities(self): """Return a set of referenced entities.""" if self._referenced_entities is not None: return self._referenced_entities referenced = set() for step in self.sequence: action = cv.determine_script_action(step) if action == cv.SCRIPT_ACTION_CALL_SERVICE: data = step.get(CONF_SERVICE_DATA) if not data: continue entity_ids = data.get(ATTR_ENTITY_ID) if entity_ids is None: continue if isinstance(entity_ids, str): entity_ids = [entity_ids] for entity_id in entity_ids: referenced.add(entity_id) elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_entities(step) elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE: referenced.add(step[CONF_SCENE]) self._referenced_entities = referenced return referenced def run(self, variables=None, context=None): """Run script.""" asyncio.run_coroutine_threadsafe( self.async_run(variables, context), self._hass.loop ).result() async def async_run( self, variables: Optional[Sequence] = None, context: Optional[Context] = None ) -> None: """Run script.""" if self.is_running: if self._script_mode == SCRIPT_MODE_IGNORE: self._log("Skipping script") return if self._script_mode == SCRIPT_MODE_ERROR: raise AlreadyRunning if self._script_mode == SCRIPT_MODE_RESTART: self._log("Restarting script") await self.async_stop(update_state=False) elif self._script_mode == SCRIPT_MODE_QUEUE: self._log( "Queueing script behind %i run%s", self._queue_len, "s" if self._queue_len > 1 else "", ) if self._queue_len >= self._queue_max: raise QueueFull if self.is_legacy: if self._runs: shared = cast(Optional[_LegacyScriptRun], self._runs[0]) else: shared = None run: _ScriptRunBase = _LegacyScriptRun( self._hass, self, variables, context, self._log_exceptions, shared ) else: if self._script_mode != SCRIPT_MODE_QUEUE: cls = _ScriptRun else: cls = _QueuedScriptRun self._queue_len += 1 run = cls(self._hass, self, variables, context, self._log_exceptions) self._runs.append(run) try: if self.is_legacy: await run.async_run() else: await asyncio.shield(run.async_run()) except asyncio.CancelledError: await run.async_stop() self._changed() raise async def async_stop(self, update_state: bool = True) -> None: """Stop running script.""" if not self.is_running: return await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs))) if update_state: self._changed() def _log(self, msg, *args, level=logging.INFO): if self.name: msg = f"%s: {msg}" args = [self.name, *args] if level == _LOG_EXCEPTION: self._logger.exception(msg, *args) else: self._logger.log(level, msg, *args)