# Source: core/homeassistant/helpers/script.py
#
# 823 lines
# 26 KiB
# Python

"""Helpers to execute scripts."""
from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from itertools import islice
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
from async_timeout import timeout
import voluptuous as vol
from homeassistant import exceptions
import homeassistant.components.device_automation as device_automation
import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_ALIAS,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_SCENE,
CONF_TIMEOUT,
CONF_WAIT_TEMPLATE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.core import (
CALLBACK_TYPE,
SERVICE_CALL_LIMIT,
Context,
HomeAssistant,
callback,
)
from homeassistant.helpers import (
condition,
config_validation as cv,
template as template,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_template,
)
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# Values accepted for a script's ``script_mode``; Script.async_run branches on
# them when a run is requested while the script is already running.
SCRIPT_MODE_ERROR = "error"
SCRIPT_MODE_IGNORE = "ignore"
SCRIPT_MODE_LEGACY = "legacy"
SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUE = "queue"
SCRIPT_MODE_RESTART = "restart"

# Every supported mode, in alphabetical order.
SCRIPT_MODE_CHOICES = [
    SCRIPT_MODE_ERROR,
    SCRIPT_MODE_IGNORE,
    SCRIPT_MODE_LEGACY,
    SCRIPT_MODE_PARALLEL,
    SCRIPT_MODE_QUEUE,
    SCRIPT_MODE_RESTART,
]

# Defaults used by Script.__init__ for ``script_mode`` and ``queue_max``.
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
DEFAULT_QUEUE_MAX = 10

# Private pseudo log level: Script._log sends messages at this level through
# logger.exception() instead of logger.log().
_LOG_EXCEPTION = logging.ERROR + 1

# Logged when a wait template times out and the script must not continue.
_TIMEOUT_MSG = "Timeout reached, abort script."
async def async_validate_action_config(
    hass: HomeAssistant, config: ConfigType
) -> ConfigType:
    """Run platform-specific validation for device based actions and conditions.

    Device automation actions are passed through the ``ACTION_SCHEMA`` of the
    device automation platform for ``config[CONF_DOMAIN]``; "device" conditions
    are passed through that platform's ``CONDITION_SCHEMA``. Any other action
    is returned untouched.
    """
    action_type = cv.determine_script_action(config)

    if action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
        action_platform = await device_automation.async_get_device_automation_platform(
            hass, config[CONF_DOMAIN], "action"
        )
        config = action_platform.ACTION_SCHEMA(config)  # type: ignore

    if action_type == cv.SCRIPT_ACTION_CHECK_CONDITION:
        if config[CONF_CONDITION] == "device":
            condition_platform = (
                await device_automation.async_get_device_automation_platform(
                    hass, config[CONF_DOMAIN], "condition"
                )
            )
            config = condition_platform.CONDITION_SCHEMA(config)  # type: ignore

    return config
class _StopScript(Exception):
    """Raised by a step to end the script run early."""
class _SuspendScript(Exception):
    """Raised by a step to pause the script run until it is resumed later."""
class AlreadyRunning(exceptions.HomeAssistantError):
    """Raised when a running script is started again in "error" mode."""
class QueueFull(exceptions.HomeAssistantError):
    """Raised when a "queue" mode script is started but its run queue is full."""
class _ScriptRunBase(ABC):
    """State of one run of a Script plus the step logic shared by all run types."""

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
    ) -> None:
        """Remember the run's inputs; no action is selected yet."""
        self._hass = hass
        self._script = script
        self._variables = variables
        self._context = context
        self._log_exceptions = log_exceptions
        # Position and config of the action being executed.
        self._step = -1
        self._action: Optional[Dict[str, Any]] = None

    def _changed(self):
        """Tell the owning script that its state changed."""
        self._script._changed()  # pylint: disable=protected-access

    @property
    def _config_cache(self):
        """Return the condition cache kept on the owning script."""
        return self._script._config_cache  # pylint: disable=protected-access

    def _log(self, msg, *args, level=logging.INFO):
        """Log through the owning script's logger."""
        self._script._log(msg, *args, level=level)  # pylint: disable=protected-access

    def _begin_step(self, default_label):
        """Publish the action's alias (or *default_label*) as last action and log it."""
        self._script.last_action = self._action.get(CONF_ALIAS, default_label)
        self._log("Executing step %s", self._script.last_action)

    @abstractmethod
    async def async_run(self) -> None:
        """Run script."""

    @abstractmethod
    async def async_stop(self) -> None:
        """Stop script run."""

    async def _async_step(self, log_exceptions):
        """Execute the current action via its ``_async_<type>_step`` handler.

        Exceptions are always re-raised. They are logged first when either this
        run's ``log_exceptions`` flag or the *log_exceptions* argument is set,
        unless they are the internal stop/suspend/cancel signals.
        """
        try:
            handler_name = f"_async_{cv.determine_script_action(self._action)}_step"
            await getattr(self, handler_name)()
        except Exception as ex:
            is_control_flow = isinstance(
                ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
            )
            if not is_control_flow and (self._log_exceptions or log_exceptions):
                self._log_exception(ex)
            raise

    def _log_exception(self, exception):
        """Log *exception* with a description matching its type."""
        action_type = cv.determine_script_action(self._action)
        error = str(exception)

        # Checked in order; the first matching type wins.
        known_errors = (
            (vol.Invalid, "Invalid data"),
            (exceptions.TemplateError, "Error rendering template"),
            (exceptions.Unauthorized, "Unauthorized"),
            (exceptions.ServiceNotFound, "Service not found"),
            (AlreadyRunning, "Already running"),
            (QueueFull, "Run queue is full"),
        )
        for error_type, description in known_errors:
            if isinstance(exception, error_type):
                error_desc = description
                level = logging.ERROR
                break
        else:
            # Anything else is logged at the private "exception" level.
            error_desc = "Unexpected error"
            level = _LOG_EXCEPTION

        self._log(
            "Error executing script. %s for %s at pos %s: %s",
            error_desc,
            action_type,
            self._step + 1,
            error,
            level=level,
        )

    @abstractmethod
    async def _async_delay_step(self):
        """Handle delay."""

    def _prep_delay_step(self):
        """Render and validate the delay of the current action and return it.

        Raises _StopScript (after logging) when the delay cannot be rendered or
        is not a positive time period.
        """
        try:
            validator = vol.All(cv.time_period, cv.positive_timedelta)
            delay = validator(
                template.render_complex(self._action[CONF_DELAY], self._variables)
            )
        except (exceptions.TemplateError, vol.Invalid) as ex:
            self._log(
                "Error rendering %s delay template: %s",
                self._script.name,
                ex,
                level=logging.ERROR,
            )
            raise _StopScript

        self._begin_step(f"delay {delay}")
        return delay

    @abstractmethod
    async def _async_wait_template_step(self):
        """Handle a wait template."""

    def _prep_wait_template_step(self, async_script_wait):
        """Start tracking the wait template unless it is already true.

        Returns None when the template already evaluates to true, otherwise the
        value returned by async_track_template for *async_script_wait*.
        """
        wait_template = self._action[CONF_WAIT_TEMPLATE]
        wait_template.hass = self._hass
        self._begin_step("wait template")

        already_true = condition.async_template(
            self._hass, wait_template, self._variables
        )
        if already_true:
            return None

        return async_track_template(
            self._hass, wait_template, async_script_wait, self._variables
        )

    @abstractmethod
    async def _async_call_service_step(self):
        """Call the service specified in the action."""

    def _prep_call_service_step(self):
        """Log the step and return the prepared service call parameters."""
        self._begin_step("call service")
        return async_prepare_call_from_config(self._hass, self._action, self._variables)

    async def _async_device_step(self):
        """Perform the device automation specified in the action."""
        self._begin_step("device automation")
        platform = await device_automation.async_get_device_automation_platform(
            self._hass, self._action[CONF_DOMAIN], "action"
        )
        await platform.async_call_action_from_config(
            self._hass, self._action, self._variables, self._context
        )

    async def _async_scene_step(self):
        """Activate the scene specified in the action."""
        self._begin_step("activate scene")
        service_data = {ATTR_ENTITY_ID: self._action[CONF_SCENE]}
        await self._hass.services.async_call(
            scene.DOMAIN,
            SERVICE_TURN_ON,
            service_data,
            blocking=True,
            context=self._context,
        )

    async def _async_event_step(self):
        """Fire the event named in the action.

        A failure to render the event data template is logged and the event is
        still fired with the static event data.
        """
        self._begin_step(self._action[CONF_EVENT])

        event_data = dict(self._action.get(CONF_EVENT_DATA, {}))
        if CONF_EVENT_DATA_TEMPLATE in self._action:
            try:
                rendered = template.render_complex(
                    self._action[CONF_EVENT_DATA_TEMPLATE], self._variables
                )
                event_data.update(rendered)
            except exceptions.TemplateError as ex:
                self._log(
                    "Error rendering event data template: %s", ex, level=logging.ERROR
                )

        self._hass.bus.async_fire(
            self._action[CONF_EVENT], event_data, context=self._context
        )

    async def _async_condition_step(self):
        """Evaluate the condition; raise _StopScript when it does not hold."""
        # Built condition checkers are cached on the script, keyed by the
        # stringified action config.
        cache_key = frozenset((k, str(v)) for k, v in self._action.items())
        checker = self._config_cache.get(cache_key)
        if not checker:
            checker = await condition.async_from_config(self._hass, self._action, False)
            self._config_cache[cache_key] = checker

        self._script.last_action = self._action.get(
            CONF_ALIAS, self._action[CONF_CONDITION]
        )
        check = checker(self._hass, self._variables)
        self._log("Test condition %s: %s", self._script.last_action, check)
        if not check:
            raise _StopScript
class _ScriptRun(_ScriptRunBase):
    """Manage Script sequence run.

    The whole sequence executes inside async_run(). A stop request is signalled
    through the ``_stop`` event, and ``_stopped`` is set once the run has
    finished cleaning up.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
    ) -> None:
        """Initialize the run and its stop/stopped events."""
        super().__init__(hass, script, variables, context, log_exceptions)
        # Set by async_stop() to ask the run to end.
        self._stop = asyncio.Event()
        # Set by _finish() when the run is over.
        self._stopped = asyncio.Event()

    def _changed(self):
        """Notify the script of a change unless a stop has been requested."""
        if not self._stop.is_set():
            super()._changed()

    async def async_run(self) -> None:
        """Run script."""
        try:
            if self._stop.is_set():
                return
            self._script.last_triggered = utcnow()
            self._changed()
            self._log("Running script")
            for self._step, self._action in enumerate(self._script.sequence):
                if self._stop.is_set():
                    break
                await self._async_step(log_exceptions=False)
        except _StopScript:
            pass
        finally:
            self._finish()

    def _finish(self):
        """Remove this run from the script and signal that it has stopped."""
        self._script._runs.remove(self)  # pylint: disable=protected-access
        if not self._script.is_running:
            self._script.last_action = None
        self._changed()
        self._stopped.set()

    async def async_stop(self) -> None:
        """Stop script run."""
        self._stop.set()
        await self._stopped.wait()

    async def _async_delay_step(self):
        """Handle delay."""
        delay = self._prep_delay_step().total_seconds()
        self._changed()
        # Sleep for the delay, but wake up early if a stop is requested.
        try:
            async with timeout(delay):
                await self._stop.wait()
        except asyncio.TimeoutError:
            pass

    async def _async_wait_template_step(self):
        """Handle a wait template.

        Waits until the template becomes true, a stop is requested, or the
        optional timeout expires. Raises _StopScript on timeout when the action
        sets continue_on_timeout to false.
        """
        # Created before the listener is registered so the callback always has
        # an event to set.
        done = asyncio.Event()

        @callback
        def async_script_wait(entity_id, from_s, to_s):
            """Handle script after template condition is true."""
            done.set()

        unsub = self._prep_wait_template_step(async_script_wait)
        if not unsub:
            return
        self._changed()

        try:
            delay = self._action[CONF_TIMEOUT].total_seconds()
        except KeyError:
            delay = None

        # asyncio.wait() must be given tasks/futures, not bare coroutines.
        tasks = [
            self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
        ]
        try:
            async with timeout(delay):
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except asyncio.TimeoutError:
            if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
                self._log(_TIMEOUT_MSG)
                raise _StopScript
        finally:
            # Also runs on timeout, so no waiter task is left pending.
            for task in tasks:
                task.cancel()
            unsub()

    async def _async_call_service_step(self):
        """Call the service specified in the action."""
        domain, service, service_data = self._prep_call_service_step()

        # If this might start a script then disable the call timeout.
        # Otherwise use the normal service call limit.
        if domain == "script" and service != SERVICE_TURN_OFF:
            limit = None
        else:
            limit = SERVICE_CALL_LIMIT

        coro = self._hass.services.async_call(
            domain,
            service,
            service_data,
            blocking=True,
            context=self._context,
            limit=limit,
        )

        if limit is not None:
            # There is a call limit, so just wait for it to finish.
            await coro
            return

        # No call limit (i.e., potentially starting one or more fully blocking scripts)
        # so watch for a stop request. asyncio.wait() needs tasks, not coroutines.
        service_task = self._hass.async_create_task(coro)
        stop_task = self._hass.async_create_task(self._stop.wait())
        try:
            done, _ = await asyncio.wait(
                {stop_task, service_task}, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Note that cancelling the service call, if it has not yet returned, will
            # also stop any non-background script runs that it may have started.
            # Cancelling an already finished task is a no-op.
            for task in (stop_task, service_task):
                task.cancel()

        # Propagate any exceptions that might have happened.
        for done_task in done:
            done_task.result()
class _QueuedScriptRun(_ScriptRun):
    """Manage queued Script sequence run."""

    # True while this run holds the script's shared queue lock.
    lock_acquired = False

    async def async_run(self) -> None:
        """Run script once the previous queued run, if any, has finished."""
        # Wait for previous run, if any, to finish by attempting to acquire the script's
        # shared lock. At the same time monitor if we've been told to stop.
        lock_task = self._hass.async_create_task(
            self._script._queue_lck.acquire()  # pylint: disable=protected-access
        )
        # asyncio.wait() must be given tasks/futures, not bare coroutines.
        stop_task = self._hass.async_create_task(self._stop.wait())
        done, pending = await asyncio.wait(
            {stop_task, lock_task}, return_when=asyncio.FIRST_COMPLETED
        )
        for pending_task in pending:
            pending_task.cancel()
        self.lock_acquired = lock_task in done

        # If we've been told to stop, then just finish up. Otherwise, we've acquired the
        # lock so we can go ahead and start the run.
        if self._stop.is_set():
            self._finish()
        else:
            await super().async_run()

    def _finish(self):
        """Leave the queue, release the lock if held, then finish as usual."""
        # pylint: disable=protected-access
        self._script._queue_len -= 1
        if self.lock_acquired:
            self._script._queue_lck.release()
            self.lock_acquired = False
        super()._finish()
class _LegacyScriptRun(_ScriptRunBase):
    """Manage legacy Script sequence run.

    A delay or wait template suspends the run by raising _SuspendScript after
    registering a listener; the listener later resumes the sequence at the
    stored position by scheduling _async_run again.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
        shared: Optional["_LegacyScriptRun"],
    ) -> None:
        """Initialize the run, reusing the run state of *shared* when given."""
        super().__init__(hass, script, variables, context, log_exceptions)
        if not shared:
            # To implement legacy behavior we need to share the following "run state"
            # amongst all runs, so it will only exist in the first instantiation of
            # concurrent runs, and the rest will use it, too.
            self._current = -1
            self._async_listeners: List[CALLBACK_TYPE] = []
            shared = self
        self._shared = shared

    @property
    def _cur(self):
        """Return the shared position of the next step, or -1 when not running."""
        return self._shared._current  # pylint: disable=protected-access

    @_cur.setter
    def _cur(self, value):
        """Store the shared position of the next step."""
        self._shared._current = value  # pylint: disable=protected-access

    @property
    def _async_listener(self):
        """Return the shared list of listener unsubscribe callables."""
        return self._shared._async_listeners  # pylint: disable=protected-access

    async def async_run(self) -> None:
        """Run script."""
        await self._async_run()

    async def _async_run(self, propagate_exceptions=True):
        """Start the sequence, or resume it from the stored position.

        With *propagate_exceptions* false (used when resuming from a listener)
        step exceptions are logged by the step and then swallowed here.
        """
        if self._cur == -1:
            self._script.last_triggered = utcnow()
            self._log("Running script")
            self._cur = 0

        # Unregister callback if we were in a delay or wait but turn on is
        # called again. In that case we just continue execution.
        self._async_remove_listener()

        log_step_exceptions = not propagate_exceptions
        suspended = False
        try:
            remaining_steps = islice(enumerate(self._script.sequence), self._cur, None)
            for self._step, self._action in remaining_steps:
                await self._async_step(log_exceptions=log_step_exceptions)
        except _StopScript:
            pass
        except _SuspendScript:
            # Store next step to take and notify change listeners
            self._cur = self._step + 1
            suspended = True
            return
        except Exception:  # pylint: disable=broad-except
            if propagate_exceptions:
                raise
        finally:
            # Captured first because async_stop() resets the position to -1.
            position_before_stop = self._cur
            if not suspended:
                self._script.last_action = None
                await self.async_stop()
            if position_before_stop != -1:
                self._changed()

    async def async_stop(self) -> None:
        """Stop script run."""
        if self._cur == -1:
            return
        self._cur = -1
        self._async_remove_listener()
        self._script._runs.clear()  # pylint: disable=protected-access

    async def _async_delay_step(self):
        """Handle delay by scheduling a resume and suspending the run."""
        delay = self._prep_delay_step()

        @callback
        def async_script_delay(now):
            """Handle delay."""
            with suppress(ValueError):
                self._async_listener.remove(cancel_timer)
            self._hass.async_create_task(self._async_run(False))

        resume_at = utcnow() + delay
        cancel_timer = async_track_point_in_utc_time(
            self._hass, async_script_delay, resume_at
        )
        self._async_listener.append(cancel_timer)
        raise _SuspendScript

    async def _async_wait_template_step(self):
        """Handle a wait template by tracking it and suspending the run."""

        @callback
        def async_script_wait(entity_id, from_s, to_s):
            """Handle script after template condition is true."""
            self._async_remove_listener()
            self._hass.async_create_task(self._async_run(False))

        @callback
        def async_script_timeout(now):
            """Call after timeout has expired."""
            with suppress(ValueError):
                self._async_listener.remove(unsub_timeout)

            # Check if we want to continue to execute
            # the script after the timeout
            if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
                self._log(_TIMEOUT_MSG)
                # NOTE(review): this stops the run without calling _changed(), so
                # the change listener is not notified here - confirm intended.
                self._hass.async_create_task(self.async_stop())
                return
            self._hass.async_create_task(self._async_run(False))

        unsub_wait = self._prep_wait_template_step(async_script_wait)
        if not unsub_wait:
            # Template is already true; carry on with the next step.
            return
        self._async_listener.append(unsub_wait)

        if CONF_TIMEOUT in self._action:
            timeout_at = utcnow() + self._action[CONF_TIMEOUT]
            unsub_timeout = async_track_point_in_utc_time(
                self._hass, async_script_timeout, timeout_at
            )
            self._async_listener.append(unsub_timeout)

        raise _SuspendScript

    async def _async_call_service_step(self):
        """Call the service specified in the action."""
        call_args = self._prep_call_service_step()
        await self._hass.services.async_call(
            *call_args, blocking=True, context=self._context
        )

    def _async_remove_listener(self):
        """Remove listeners, if any."""
        listeners = self._async_listener
        for unsub in listeners:
            unsub()
        listeners.clear()
class Script:
    """Representation of a script.

    Holds the action sequence, the configured script mode and the list of
    currently active runs. New runs are started with async_run() and all
    active runs are stopped with async_stop().
    """

    def __init__(
        self,
        hass: HomeAssistant,
        sequence: Sequence[Dict[str, Any]],
        name: Optional[str] = None,
        change_listener: Optional[Callable[..., Any]] = None,
        script_mode: str = DEFAULT_SCRIPT_MODE,
        queue_max: int = DEFAULT_QUEUE_MAX,
        logger: Optional[logging.Logger] = None,
        log_exceptions: bool = True,
    ) -> None:
        """Initialize the script.

        When no *logger* is given, a logger named after this module is used,
        with the slugified script *name* appended if there is one.
        """
        self._hass = hass
        self.sequence = sequence
        template.attach(hass, self.sequence)
        self.name = name
        self.change_listener = change_listener
        self._script_mode = script_mode
        if logger:
            self._logger = logger
        else:
            logger_name = __name__
            if name:
                logger_name = ".".join([logger_name, slugify(name)])
            self._logger = logging.getLogger(logger_name)
        self._log_exceptions = log_exceptions

        self.last_action = None
        self.last_triggered: Optional[datetime] = None
        # Legacy scripts are only cancellable if they contain a delay or wait.
        self.can_cancel = not self.is_legacy or any(
            CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
            for action in self.sequence
        )

        self._runs: List[_ScriptRunBase] = []
        # Queue bookkeeping only exists for queue mode scripts.
        if script_mode == SCRIPT_MODE_QUEUE:
            self._queue_max = queue_max
            self._queue_len = 0
            self._queue_lck = asyncio.Lock()
        self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
        self._referenced_entities: Optional[Set[str]] = None
        self._referenced_devices: Optional[Set[str]] = None

    def _changed(self):
        """Schedule the change listener, if one was given."""
        if self.change_listener:
            self._hass.async_run_job(self.change_listener)

    @property
    def is_running(self) -> bool:
        """Return true if script is on."""
        return len(self._runs) > 0

    @property
    def is_legacy(self) -> bool:
        """Return if using legacy mode."""
        return self._script_mode == SCRIPT_MODE_LEGACY

    @property
    def referenced_devices(self):
        """Return a set of referenced devices."""
        if self._referenced_devices is not None:
            return self._referenced_devices

        referenced = set()
        for step in self.sequence:
            action = cv.determine_script_action(step)
            if action == cv.SCRIPT_ACTION_CHECK_CONDITION:
                referenced |= condition.async_extract_devices(step)
            elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
                referenced.add(step[CONF_DEVICE_ID])

        # Computed once and cached for later calls.
        self._referenced_devices = referenced
        return referenced

    @property
    def referenced_entities(self):
        """Return a set of referenced entities."""
        if self._referenced_entities is not None:
            return self._referenced_entities

        referenced = set()
        for step in self.sequence:
            action = cv.determine_script_action(step)
            if action == cv.SCRIPT_ACTION_CALL_SERVICE:
                data = step.get(CONF_SERVICE_DATA)
                if not data:
                    continue
                entity_ids = data.get(ATTR_ENTITY_ID)
                if entity_ids is None:
                    continue
                # A single entity id may be given as a plain string.
                if isinstance(entity_ids, str):
                    entity_ids = [entity_ids]
                referenced.update(entity_ids)
            elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
                referenced |= condition.async_extract_entities(step)
            elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE:
                referenced.add(step[CONF_SCENE])

        # Computed once and cached for later calls.
        self._referenced_entities = referenced
        return referenced

    def run(self, variables=None, context=None):
        """Run script.

        Submits async_run() to the event loop and blocks for its result.
        """
        asyncio.run_coroutine_threadsafe(
            self.async_run(variables, context), self._hass.loop
        ).result()

    async def async_run(
        self, variables: Optional[Sequence] = None, context: Optional[Context] = None
    ) -> None:
        """Run script.

        If the script is already running, the script mode decides what happens:
        "ignore" skips the request, "error" raises AlreadyRunning, "restart"
        stops the active runs first, and "queue" raises QueueFull once the
        queue holds ``queue_max`` runs.
        """
        if self.is_running:
            if self._script_mode == SCRIPT_MODE_IGNORE:
                self._log("Skipping script")
                return

            if self._script_mode == SCRIPT_MODE_ERROR:
                raise AlreadyRunning

            if self._script_mode == SCRIPT_MODE_RESTART:
                self._log("Restarting script")
                await self.async_stop(update_state=False)
            elif self._script_mode == SCRIPT_MODE_QUEUE:
                # Reject before logging so a refused run is not reported as queued.
                if self._queue_len >= self._queue_max:
                    raise QueueFull
                self._log(
                    "Queueing script behind %i run%s",
                    self._queue_len,
                    "s" if self._queue_len > 1 else "",
                )

        if self.is_legacy:
            # Concurrent legacy runs share the run state of the first one.
            if self._runs:
                shared = cast(Optional[_LegacyScriptRun], self._runs[0])
            else:
                shared = None
            run: _ScriptRunBase = _LegacyScriptRun(
                self._hass, self, variables, context, self._log_exceptions, shared
            )
        else:
            if self._script_mode != SCRIPT_MODE_QUEUE:
                cls = _ScriptRun
            else:
                cls = _QueuedScriptRun
                self._queue_len += 1
            run = cls(self._hass, self, variables, context, self._log_exceptions)
        self._runs.append(run)

        try:
            if self.is_legacy:
                await run.async_run()
            else:
                # Shielded so that cancelling the caller does not cancel the run
                # itself; it is stopped explicitly below instead.
                await asyncio.shield(run.async_run())
        except asyncio.CancelledError:
            await run.async_stop()
            self._changed()
            raise

    async def async_stop(self, update_state: bool = True) -> None:
        """Stop running script.

        The change listener is notified afterwards unless *update_state* is
        false.
        """
        if not self.is_running:
            return
        await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
        if update_state:
            self._changed()

    def _log(self, msg, *args, level=logging.INFO):
        """Log *msg*, prefixed with the script name when there is one."""
        if self.name:
            msg = f"%s: {msg}"
            args = [self.name, *args]

        # The private exception level also records the active traceback.
        if level == _LOG_EXCEPTION:
            self._logger.exception(msg, *args)
        else:
            self._logger.log(level, msg, *args)