# core/homeassistant/helpers/script.py
#
# 744 lines
# 24 KiB
# Python

"""Helpers to execute scripts."""
from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from itertools import islice
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
import voluptuous as vol
from homeassistant import exceptions
import homeassistant.components.device_automation as device_automation
import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_CONDITION,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_TIMEOUT,
SERVICE_TURN_ON,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
from homeassistant.helpers import (
condition,
config_validation as cv,
service,
template as template,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_template,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# --- Keys that may appear in a single action's configuration ---------------
CONF_ALIAS = "alias"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SEQUENCE = "sequence"
CONF_EVENT = "event"
CONF_EVENT_DATA = "event_data"
CONF_EVENT_DATA_TEMPLATE = "event_data_template"
CONF_DELAY = "delay"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_CONTINUE = "continue_on_timeout"
CONF_SCENE = "scene"

# --- Action types, as returned by _determine_action -------------------------
# Each value is also the middle part of the "_async_<type>_step" handler name.
ACTION_DELAY = "delay"
ACTION_WAIT_TEMPLATE = "wait_template"
ACTION_CHECK_CONDITION = "condition"
ACTION_FIRE_EVENT = "event"
ACTION_CALL_SERVICE = "call_service"
ACTION_DEVICE_AUTOMATION = "device"
ACTION_ACTIVATE_SCENE = "scene"

# --- What to do when a script is started while it is already running --------
IF_RUNNING_ERROR = "error"
IF_RUNNING_IGNORE = "ignore"
IF_RUNNING_PARALLEL = "parallel"
IF_RUNNING_RESTART = "restart"

# The first entry is the default.
IF_RUNNING_CHOICES = [
    IF_RUNNING_PARALLEL,
    IF_RUNNING_ERROR,
    IF_RUNNING_IGNORE,
    IF_RUNNING_RESTART,
]

# --- How a run is executed ---------------------------------------------------
RUN_MODE_BACKGROUND = "background"
RUN_MODE_BLOCKING = "blocking"
RUN_MODE_LEGACY = "legacy"

# The first entry is the default.
RUN_MODE_CHOICES = [
    RUN_MODE_BLOCKING,
    RUN_MODE_BACKGROUND,
    RUN_MODE_LEGACY,
]

# Pseudo log level: tells Script._log to use logger.exception() so the
# traceback of the active exception is included.
_LOG_EXCEPTION = logging.ERROR + 1

_TIMEOUT_MSG = "Timeout reached, abort script."
def _determine_action(action):
    """Return the ACTION_* type of an action config.

    The type is decided by which configuration key is present. Keys are
    probed in a fixed priority order and the first hit wins; an action with
    none of the known keys is treated as a service call.
    """
    # Order matters: an action containing several of these keys resolves to
    # the earliest entry.
    key_to_action_type = (
        (CONF_DELAY, ACTION_DELAY),
        (CONF_WAIT_TEMPLATE, ACTION_WAIT_TEMPLATE),
        (CONF_CONDITION, ACTION_CHECK_CONDITION),
        (CONF_EVENT, ACTION_FIRE_EVENT),
        (CONF_DEVICE_ID, ACTION_DEVICE_AUTOMATION),
        (CONF_SCENE, ACTION_ACTIVATE_SCENE),
    )
    for config_key, action_type in key_to_action_type:
        if config_key in action:
            return action_type
    return ACTION_CALL_SERVICE
async def async_validate_action_config(
    hass: HomeAssistant, config: ConfigType
) -> ConfigType:
    """Validate an action config against its device automation platform.

    Only device actions and conditions of type "device" need the extra,
    platform-specific validation; every other config is returned untouched.
    """
    action_type = _determine_action(config)

    if action_type == ACTION_DEVICE_AUTOMATION:
        automation_type, schema_name = "action", "ACTION_SCHEMA"
    elif action_type == ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device":
        automation_type, schema_name = "condition", "CONDITION_SCHEMA"
    else:
        return config

    platform = await device_automation.async_get_device_automation_platform(
        hass, config[CONF_DOMAIN], automation_type
    )
    # The platform module exposes the schema under the selected attribute.
    return getattr(platform, schema_name)(config)
class _StopScript(Exception):
    """Signal raised by a step to end the current script run."""
class _SuspendScript(Exception):
    """Signal raised by a step to pause the run until it is resumed later."""
class _ScriptRunBase(ABC):
    """Shared state and step handlers for one run of a Script sequence.

    Subclasses provide the run/stop strategy and the delay and wait-template
    steps; the remaining action types are implemented here.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
    ) -> None:
        """Remember the run's collaborators and reset the step cursor."""
        self._hass = hass
        self._script = script
        self._variables = variables
        self._context = context
        self._log_exceptions = log_exceptions
        # Index of the step being executed (-1 until the first step starts)
        # and that step's config.
        self._step = -1
        self._action: Optional[Dict[str, Any]] = None

    def _changed(self):
        """Forward a change notification to the owning Script."""
        self._script._changed()  # pylint: disable=protected-access

    @property
    def _config_cache(self):
        """Return the owning Script's cache of built condition checkers."""
        return self._script._config_cache  # pylint: disable=protected-access

    @abstractmethod
    async def async_run(self) -> None:
        """Run script."""

    async def _async_step(self, log_exceptions):
        """Dispatch the current action to its ``_async_<type>_step`` handler.

        Every exception is re-raised. Real errors are logged first when either
        the run-wide flag or ``log_exceptions`` asks for it.
        """
        try:
            handler_name = f"_async_{_determine_action(self._action)}_step"
            handler = getattr(self, handler_name)
            await handler()
        except (_SuspendScript, _StopScript):
            # Control-flow signals, not errors: never logged.
            raise
        except Exception as err:
            if self._log_exceptions or log_exceptions:
                self._log_exception(err)
            raise

    @abstractmethod
    async def async_stop(self) -> None:
        """Stop script run."""

    def _log_exception(self, exception):
        """Log a failed step with a description matching the error type."""
        action_type = _determine_action(self._action)
        error = str(exception)

        known_errors = (
            (vol.Invalid, "Invalid data"),
            (exceptions.TemplateError, "Error rendering template"),
            (exceptions.Unauthorized, "Unauthorized"),
            (exceptions.ServiceNotFound, "Service not found"),
        )
        for error_type, description in known_errors:
            if isinstance(exception, error_type):
                error_desc, level = description, logging.ERROR
                break
        else:
            # Anything else is unexpected, so log it with its traceback.
            error_desc, level = "Unexpected error", _LOG_EXCEPTION

        self._log(
            "Error executing script. %s for %s at pos %s: %s",
            error_desc,
            action_type,
            self._step + 1,
            error,
            level=level,
        )

    @abstractmethod
    async def _async_delay_step(self):
        """Handle delay."""

    def _prep_delay_step(self):
        """Render and validate the delay of the current action.

        Returns the validated delay. A rendering or validation failure is
        logged and ends the run via ``_StopScript``.
        """
        try:
            rendered = template.render_complex(
                self._action[CONF_DELAY], self._variables
            )
            delay = vol.All(cv.time_period, cv.positive_timedelta)(rendered)
        except (exceptions.TemplateError, vol.Invalid) as ex:
            # _raise always raises, so ``delay`` is bound below.
            self._raise(
                "Error rendering %s delay template: %s",
                self._script.name,
                ex,
                exception=_StopScript,
            )

        step_name = self._action.get(CONF_ALIAS, f"delay {delay}")
        self._script.last_action = step_name
        self._log("Executing step %s", self._script.last_action)
        return delay

    @abstractmethod
    async def _async_wait_template_step(self):
        """Handle a wait template."""

    def _prep_wait_template_step(self, async_script_wait):
        """Start tracking the wait template of the current action.

        Returns None when the template already evaluates true, otherwise the
        value returned by ``async_track_template`` for ``async_script_wait``
        (callers invoke it to stop tracking).
        """
        tpl = self._action[CONF_WAIT_TEMPLATE]
        tpl.hass = self._hass

        self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
        self._log("Executing step %s", self._script.last_action)

        # Nothing to wait for if the condition already holds.
        already_true = condition.async_template(self._hass, tpl, self._variables)
        if not already_true:
            return async_track_template(
                self._hass, tpl, async_script_wait, self._variables
            )
        return None

    async def _async_call_service_step(self):
        """Call the service described by the current action and wait for it."""
        step_name = self._action.get(CONF_ALIAS, "call service")
        self._script.last_action = step_name
        self._log("Executing step %s", self._script.last_action)
        await service.async_call_from_config(
            self._hass,
            self._action,
            blocking=True,
            variables=self._variables,
            validate_config=False,
            context=self._context,
        )

    async def _async_device_step(self):
        """Run the device automation action described by the current action."""
        step_name = self._action.get(CONF_ALIAS, "device automation")
        self._script.last_action = step_name
        self._log("Executing step %s", self._script.last_action)
        platform = await device_automation.async_get_device_automation_platform(
            self._hass, self._action[CONF_DOMAIN], "action"
        )
        await platform.async_call_action_from_config(
            self._hass, self._action, self._variables, self._context
        )

    async def _async_scene_step(self):
        """Turn on the scene named by the current action and wait for it."""
        step_name = self._action.get(CONF_ALIAS, "activate scene")
        self._script.last_action = step_name
        self._log("Executing step %s", self._script.last_action)
        service_data = {ATTR_ENTITY_ID: self._action[CONF_SCENE]}
        await self._hass.services.async_call(
            scene.DOMAIN,
            SERVICE_TURN_ON,
            service_data,
            blocking=True,
            context=self._context,
        )

    async def _async_event_step(self):
        """Fire the event described by the current action.

        Static event data is merged with the rendered data template. A
        template error is logged and the event is still fired with whatever
        data was collected.
        """
        event_type = self._action[CONF_EVENT]
        self._script.last_action = self._action.get(CONF_ALIAS, event_type)
        self._log("Executing step %s", self._script.last_action)

        # Copy so the action config itself is never mutated.
        event_data = dict(self._action.get(CONF_EVENT_DATA, {}))
        if CONF_EVENT_DATA_TEMPLATE in self._action:
            try:
                rendered = template.render_complex(
                    self._action[CONF_EVENT_DATA_TEMPLATE], self._variables
                )
                event_data.update(rendered)
            except exceptions.TemplateError as ex:
                self._log(
                    "Error rendering event data template: %s", ex, level=logging.ERROR
                )

        self._hass.bus.async_fire(
            self._action[CONF_EVENT], event_data, context=self._context
        )

    async def _async_condition_step(self):
        """Evaluate the condition of the current action.

        Raises ``_StopScript`` when the condition is not met. Built condition
        checkers are cached on the Script, keyed by the action's contents.
        """
        cache = self._config_cache
        cache_key = frozenset(
            {(key, str(value)) for key, value in self._action.items()}
        )
        checker = cache.get(cache_key)
        if not checker:
            checker = await condition.async_from_config(
                self._hass, self._action, False
            )
            cache[cache_key] = checker

        self._script.last_action = self._action.get(
            CONF_ALIAS, self._action[CONF_CONDITION]
        )
        check = checker(self._hass, self._variables)
        self._log("Test condition %s: %s", self._script.last_action, check)
        if not check:
            raise _StopScript

    def _log(self, msg, *args, level=logging.INFO):
        """Log through the owning Script."""
        self._script._log(msg, *args, level=level)  # pylint: disable=protected-access

    def _raise(self, msg, *args, exception=None):
        """Log and raise through the owning Script."""
        # pylint: disable=protected-access
        self._script._raise(msg, *args, exception=exception)
class _ScriptRun(_ScriptRunBase):
    """Manage Script sequence run.

    The whole sequence runs inside a single coroutine. Stopping is
    cooperative: ``async_stop`` sets a flag that the step loop and the
    delay / wait-template steps watch, then waits for the run to finish.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
    ) -> None:
        """Initialize the run and its stop-handshake events."""
        super().__init__(hass, script, variables, context, log_exceptions)
        # Set by async_stop to request termination.
        self._stop = asyncio.Event()
        # Set by _async_run once the run has fully wound down.
        self._stopped = asyncio.Event()

    async def _async_run(self, propagate_exceptions=True):
        """Execute the sequence until it ends, is stopped, or a step fails.

        When ``propagate_exceptions`` is false, step errors are not re-raised
        to the caller and ``_async_step`` is asked to log them instead.
        """
        self._log("Running script")
        try:
            for self._step, self._action in enumerate(self._script.sequence):
                if self._stop.is_set():
                    break
                await self._async_step(not propagate_exceptions)
        except _StopScript:
            pass
        except Exception:  # pylint: disable=broad-except
            if propagate_exceptions:
                raise
        finally:
            # Notify listeners only when the run ended without a stop request.
            if not self._stop.is_set():
                self._changed()
            self._script.last_action = None
            self._script._runs.remove(self)  # pylint: disable=protected-access
            self._stopped.set()

    async def async_stop(self) -> None:
        """Stop script run."""
        self._stop.set()
        await self._stopped.wait()

    async def _async_delay_step(self):
        """Handle delay.

        Sleeps for the configured delay, returning early if a stop is
        requested in the meantime.
        """
        timeout = self._prep_delay_step().total_seconds()
        if not self._stop.is_set():
            self._changed()
        # wait_for cancels the inner wait on timeout, so no pending task is
        # left behind. Passing the bare coroutine to asyncio.wait is no longer
        # accepted on current Python and would leave its task pending.
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(self._stop.wait(), timeout)

    async def _async_wait_template_step(self):
        """Handle a wait template.

        Waits until the template becomes true, a stop is requested, or the
        optional timeout expires. On timeout the run is ended with
        ``_StopScript`` unless the action allows continuing.
        """
        done = asyncio.Event()

        @callback
        def async_script_wait(entity_id, from_s, to_s):
            """Handle script after template condition is true."""
            done.set()

        unsub = self._prep_wait_template_step(async_script_wait)
        if not unsub:
            # Template was already true; nothing to wait for.
            return

        if not self._stop.is_set():
            self._changed()

        try:
            timeout = self._action[CONF_TIMEOUT].total_seconds()
        except KeyError:
            timeout = None

        # Wrap the waits in explicit tasks: asyncio.wait no longer accepts
        # bare coroutines, and the tasks must be cancelled afterwards or the
        # one that did not fire would stay pending.
        tasks = [
            self._hass.loop.create_task(flag.wait()) for flag in (self._stop, done)
        ]
        try:
            finished, _ = await asyncio.wait(
                tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
            # An empty "finished" set means the timeout elapsed first.
            if not finished and not self._action.get(CONF_CONTINUE, True):
                self._log(_TIMEOUT_MSG)
                raise _StopScript
        finally:
            for task in tasks:
                task.cancel()
            unsub()
class _BackgroundScriptRun(_ScriptRun):
    """Script run that is scheduled as a task instead of being awaited."""

    async def async_run(self) -> None:
        """Schedule the run and return without waiting for it.

        Exceptions are not propagated, since no caller awaits the task.
        """
        background_run = self._async_run(propagate_exceptions=False)
        self._hass.async_create_task(background_run)
class _BlockingScriptRun(_ScriptRun):
    """Script run that the caller awaits until the sequence has finished."""

    async def async_run(self) -> None:
        """Run the sequence to completion.

        The run is shielded so that cancelling the caller does not cancel it
        abruptly; it is stopped through ``async_stop`` before the cancellation
        is re-raised.
        """
        shielded_run = asyncio.shield(self._async_run())
        try:
            await shielded_run
        except asyncio.CancelledError:
            await self.async_stop()
            raise
class _LegacyScriptRun(_ScriptRunBase):
    """Script run implementing the legacy suspend-and-resume behavior.

    A delay or wait step does not block: it registers a listener, records the
    next step and raises ``_SuspendScript``. The listener later resumes the
    run by scheduling ``_async_run`` again.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        script: "Script",
        variables: Optional[Sequence],
        context: Optional[Context],
        log_exceptions: bool,
        shared: Optional["_LegacyScriptRun"],
    ) -> None:
        """Initialize the run, creating or adopting the shared run state."""
        super().__init__(hass, script, variables, context, log_exceptions)
        if not shared:
            # Legacy behavior needs one "run state" common to all concurrent
            # runs. It lives on the first run created; later runs are handed
            # that run as ``shared`` and use its state.
            self._current = -1
            self._async_listeners: List[CALLBACK_TYPE] = []
            shared = self
        self._shared = shared

    @property
    def _cur(self):
        """Return the shared index of the next step (-1 when not running)."""
        return self._shared._current  # pylint: disable=protected-access

    @_cur.setter
    def _cur(self, value):
        """Store the shared index of the next step."""
        self._shared._current = value  # pylint: disable=protected-access

    @property
    def _async_listener(self):
        """Return the shared list of pending listener-removal callbacks."""
        return self._shared._async_listeners  # pylint: disable=protected-access

    async def async_run(self) -> None:
        """Run script."""
        await self._async_run()

    async def _async_run(self, propagate_exceptions=True):
        """Start the sequence, or continue it from the stored step.

        When ``propagate_exceptions`` is false, step errors are not re-raised
        and ``_async_step`` is asked to log them.
        """
        if self._cur == -1:
            self._log("Running script")
            self._cur = 0

        # If we were suspended in a delay or wait and are run again, drop the
        # pending listeners and simply carry on with the next step.
        self._async_remove_listener()

        suspended = False
        log_step_errors = not propagate_exceptions
        try:
            remaining_steps = islice(
                enumerate(self._script.sequence), self._cur, None
            )
            for self._step, self._action in remaining_steps:
                await self._async_step(log_step_errors)
        except _StopScript:
            pass
        except _SuspendScript:
            # Remember where to resume; the finally block notifies listeners.
            self._cur = self._step + 1
            suspended = True
        except Exception:  # pylint: disable=broad-except
            if propagate_exceptions:
                raise
        finally:
            if self._cur != -1:
                self._changed()
            if not suspended:
                self._script.last_action = None
                await self.async_stop()

    async def async_stop(self) -> None:
        """Stop script run."""
        if self._cur != -1:
            self._cur = -1
            self._async_remove_listener()
            self._script._runs.clear()  # pylint: disable=protected-access

    async def _async_delay_step(self):
        """Handle delay by scheduling a resume and suspending the run."""
        delay = self._prep_delay_step()

        @callback
        def async_script_delay(now):
            """Resume the run once the delay has elapsed."""
            try:
                self._async_listener.remove(cancel_timer)
            except ValueError:
                # Already removed from the list; nothing to do.
                pass
            self._hass.async_create_task(self._async_run(False))

        resume_at = utcnow() + delay
        cancel_timer = async_track_point_in_utc_time(
            self._hass, async_script_delay, resume_at
        )
        self._async_listener.append(cancel_timer)
        raise _SuspendScript

    async def _async_wait_template_step(self):
        """Handle a wait template by registering listeners and suspending."""

        @callback
        def async_script_wait(entity_id, from_s, to_s):
            """Resume the run once the template condition is true."""
            self._async_remove_listener()
            self._hass.async_create_task(self._async_run(False))

        @callback
        def async_script_timeout(now):
            """Handle expiry of the wait timeout."""
            try:
                self._async_listener.remove(cancel_timeout)
            except ValueError:
                pass

            # The action decides whether the script carries on after a
            # timeout or is aborted.
            if not self._action.get(CONF_CONTINUE, True):
                self._log(_TIMEOUT_MSG)
                self._hass.async_create_task(self.async_stop())
            else:
                self._hass.async_create_task(self._async_run(False))

        cancel_wait = self._prep_wait_template_step(async_script_wait)
        if not cancel_wait:
            # Template already true: continue with the next step right away.
            return
        self._async_listener.append(cancel_wait)

        if CONF_TIMEOUT in self._action:
            expires_at = utcnow() + self._action[CONF_TIMEOUT]
            cancel_timeout = async_track_point_in_utc_time(
                self._hass, async_script_timeout, expires_at
            )
            self._async_listener.append(cancel_timeout)

        raise _SuspendScript

    def _async_remove_listener(self):
        """Call and forget every pending listener-removal callback."""
        pending = self._async_listener
        for cancel in pending:
            cancel()
        pending.clear()
class Script:
    """A sequence of actions that can be run, possibly several times at once."""

    def __init__(
        self,
        hass: HomeAssistant,
        sequence: Sequence[Dict[str, Any]],
        name: Optional[str] = None,
        change_listener: Optional[Callable[..., Any]] = None,
        if_running: Optional[str] = None,
        run_mode: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
        log_exceptions: bool = True,
    ) -> None:
        """Set up the script.

        With neither ``if_running`` nor ``run_mode`` given, the script uses
        legacy run mode. Combining ``if_running`` with legacy run mode is an
        error; otherwise any missing option takes the first entry of its
        choices list.
        """
        self._logger = logger or logging.getLogger(__name__)
        self._hass = hass
        self.sequence = sequence
        template.attach(hass, self.sequence)
        self.name = name
        self._change_listener = change_listener
        self.last_action = None
        self.last_triggered: Optional[datetime] = None
        # True when the sequence contains at least one delay or wait step.
        self.can_cancel = any(
            key in action
            for action in self.sequence
            for key in (CONF_DELAY, CONF_WAIT_TEMPLATE)
        )

        if if_running and run_mode == RUN_MODE_LEGACY:
            self._raise('Cannot use if_running if run_mode is "legacy"')
        if if_running or run_mode:
            self._if_running = if_running or IF_RUNNING_CHOICES[0]
            self._run_mode = run_mode or RUN_MODE_CHOICES[0]
        else:
            self._if_running = IF_RUNNING_PARALLEL
            self._run_mode = RUN_MODE_LEGACY

        self._runs: List[_ScriptRunBase] = []
        self._log_exceptions = log_exceptions
        self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
        # Computed on first access by the properties of the same name.
        self._referenced_entities: Optional[Set[str]] = None
        self._referenced_devices: Optional[Set[str]] = None

    def _changed(self):
        """Schedule the change listener, if one was supplied."""
        if self._change_listener:
            self._hass.async_add_job(self._change_listener)

    @property
    def is_running(self) -> bool:
        """Return true if at least one run is active."""
        return bool(self._runs)

    @property
    def referenced_devices(self):
        """Return the set of device ids the sequence refers to (cached)."""
        if self._referenced_devices is None:
            found = set()
            for step in self.sequence:
                action_type = _determine_action(step)
                if action_type == ACTION_CHECK_CONDITION:
                    found |= condition.async_extract_devices(step)
                elif action_type == ACTION_DEVICE_AUTOMATION:
                    found.add(step[CONF_DEVICE_ID])
            self._referenced_devices = found
        return self._referenced_devices

    @property
    def referenced_entities(self):
        """Return the set of entity ids the sequence refers to (cached)."""
        if self._referenced_entities is None:
            found = set()
            for step in self.sequence:
                action_type = _determine_action(step)

                if action_type == ACTION_CALL_SERVICE:
                    data = step.get(service.CONF_SERVICE_DATA)
                    entity_ids = data.get(ATTR_ENTITY_ID) if data else None
                    if entity_ids is None:
                        continue
                    # A single id may be given as a plain string.
                    if isinstance(entity_ids, str):
                        found.add(entity_ids)
                    else:
                        found.update(entity_ids)

                elif action_type == ACTION_CHECK_CONDITION:
                    found |= condition.async_extract_entities(step)

                elif action_type == ACTION_ACTIVATE_SCENE:
                    found.add(step[CONF_SCENE])

            self._referenced_entities = found
        return self._referenced_entities

    def run(self, variables=None, context=None):
        """Run the script from another thread and wait for ``async_run``."""
        future = asyncio.run_coroutine_threadsafe(
            self.async_run(variables, context), self._hass.loop
        )
        future.result()

    async def async_run(
        self, variables: Optional[Sequence] = None, context: Optional[Context] = None
    ) -> None:
        """Start a new run, honoring the if-running policy and the run mode."""
        if self.is_running:
            policy = self._if_running
            if policy == IF_RUNNING_IGNORE:
                self._log("Skipping script")
                return
            if policy == IF_RUNNING_ERROR:
                self._raise("Already running")
            if policy == IF_RUNNING_RESTART:
                self._log("Restarting script")
                await self.async_stop()

        self.last_triggered = utcnow()

        if self._run_mode == RUN_MODE_LEGACY:
            # Later legacy runs share the state held by the first active run.
            shared = (
                cast(Optional[_LegacyScriptRun], self._runs[0]) if self._runs else None
            )
            run: _ScriptRunBase = _LegacyScriptRun(
                self._hass, self, variables, context, self._log_exceptions, shared
            )
        elif self._run_mode == RUN_MODE_BACKGROUND:
            run = _BackgroundScriptRun(
                self._hass, self, variables, context, self._log_exceptions
            )
        else:
            run = _BlockingScriptRun(
                self._hass, self, variables, context, self._log_exceptions
            )

        self._runs.append(run)
        await run.async_run()

    async def async_stop(self) -> None:
        """Stop every active run, then notify the change listener."""
        if not self.is_running:
            return
        # The stop coroutines are collected up front, before any of them gets
        # a chance to modify the list of runs.
        stop_calls = [run.async_stop() for run in self._runs]
        await asyncio.shield(asyncio.gather(*stop_calls))
        self._changed()

    def _log(self, msg, *args, level=logging.INFO):
        """Log a message, prefixed with the script name when there is one."""
        text = f"{self.name}: {msg}" if self.name else msg
        if level == _LOG_EXCEPTION:
            # Pseudo level: log as an error including the active traceback.
            self._logger.exception(text, *args)
            return
        self._logger.log(level, text, *args)

    def _raise(self, msg, *args, exception=None):
        """Log an error and raise it.

        Raises ``exception`` (HomeAssistantError when not given) built from
        the %-formatted message.
        """
        error_cls = exception or exceptions.HomeAssistantError
        self._log(msg, *args, level=logging.ERROR)
        raise error_cls(msg % args)