From 17401cbc29ae7ba07b4d4fa0ea7fad3acc1ae71f Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 4 Mar 2021 14:16:24 +0100 Subject: [PATCH] Initial automation tracing (#46755) * Initial prototype of automation tracing * Small fixes * Lint * Move trace helpers to its own file * Improve trace for state and numeric_state conditions * Tweaks + apply suggestions from code review * Index traces by automation_id, trace while script is running * Refactor condition tracing * Improve WS API to get traces for single automation * Add tests * Fix imports * Fix imports * Address review comments * Cap logging of loops * Remove unused ContextVar action_config --- .../components/automation/__init__.py | 261 +++++++++++--- homeassistant/components/automation/config.py | 18 + homeassistant/components/config/automation.py | 27 ++ homeassistant/components/script/__init__.py | 11 +- homeassistant/helpers/condition.py | 220 +++++++++++- homeassistant/helpers/script.py | 206 +++++++++-- homeassistant/helpers/trace.py | 78 +++++ tests/components/config/test_automation.py | 161 +++++++++ tests/helpers/test_condition.py | 330 ++++++++++++++++++ tests/helpers/test_script.py | 43 ++- 10 files changed, 1255 insertions(+), 100 deletions(-) create mode 100644 homeassistant/helpers/trace.py diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 1b1a927d147..acb28df05b0 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -1,4 +1,6 @@ """Allow to set up simple automation rules via the config file.""" +from collections import deque +from contextlib import contextmanager import logging from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union, cast @@ -39,6 +41,11 @@ from homeassistant.exceptions import ( HomeAssistantError, ) from homeassistant.helpers import condition, extract_domain_configs, template +from homeassistant.helpers.condition import ( + condition_path, + condition_trace_clear, + condition_trace_get, +) import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent @@ -50,17 +57,22 @@ from homeassistant.helpers.script import ( CONF_MAX, CONF_MAX_EXCEEDED, Script, + action_path, + action_trace_clear, + action_trace_get, ) from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.trigger import async_initialize_triggers from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass +from homeassistant.util import dt as dt_util from homeassistant.util.dt import parse_datetime +from .config import AutomationConfig, async_validate_config_item + # Not used except by packages to check config structure from .config import PLATFORM_SCHEMA # noqa: F401 -from .config import async_validate_config_item from .const import ( CONF_ACTION, CONF_INITIAL_STATE, @@ -90,6 +102,10 @@ ATTR_SOURCE = "source" ATTR_VARIABLES = "variables" SERVICE_TRIGGER = "trigger" +DATA_AUTOMATION_TRACE = "automation_trace" +STORED_TRACES = 5 # Stored traces per automation + +_LOGGER = logging.getLogger(__name__) AutomationActionType = Callable[[HomeAssistant, TemplateVarsType], Awaitable[None]] @@ -166,8 +182,9 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]: async def async_setup(hass, config): - """Set up the automation.""" + """Set up all automations.""" hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) + hass.data.setdefault(DATA_AUTOMATION_TRACE, {}) # To register the automation blueprints async_get_blueprints(hass) @@ -176,7 +193,7 @@ async def async_setup(hass, config): await async_get_blueprints(hass).async_populate() async def trigger_service_handler(entity, service_call): - """Handle automation triggers.""" + """Handle forced automation trigger, e.g. from frontend.""" await entity.async_trigger( service_call.data[ATTR_VARIABLES], skip_condition=service_call.data[CONF_SKIP_CONDITION], @@ -215,6 +232,103 @@ async def async_setup(hass, config): return True +class AutomationTrace: + """Container for automation trace.""" + + def __init__(self, unique_id, config, trigger, context, action_trace): + """Container for automation trace.""" + self._action_trace = action_trace + self._condition_trace = None + self._config = config + self._context = context + self._error = None + self._state = "running" + self._timestamp_finish = None + self._timestamp_start = dt_util.utcnow() + self._trigger = trigger + self._unique_id = unique_id + self._variables = None + + def set_error(self, ex): + """Set error.""" + self._error = ex + + def set_variables(self, variables): + """Set variables.""" + self._variables = variables + + def set_condition_trace(self, condition_trace): + """Set condition trace.""" + self._condition_trace = condition_trace + + def finished(self): + """Set finish time.""" + self._timestamp_finish = dt_util.utcnow() + self._state = "stopped" + + def as_dict(self): + """Return dictionary version of this AutomationTrace.""" + + action_traces = {} + condition_traces = {} + for key, trace_list in self._action_trace.items(): + action_traces[key] = [item.as_dict() for item in trace_list] + + if self._condition_trace: + for key, trace_list in self._condition_trace.items(): + condition_traces[key] = [item.as_dict() for item in trace_list] + + result = { + "action_trace": action_traces, + "condition_trace": condition_traces, + "config": self._config, + "context": self._context, + "state": self._state, + "timestamp": { + "start": self._timestamp_start, + "finish": self._timestamp_finish, + }, + "trigger": self._trigger, + "unique_id": self._unique_id, + "variables": self._variables, + } + if self._error is not None: + result["error"] = str(self._error) + return result + + +@contextmanager +def trace_automation(hass, unique_id, config, trigger, context): + """Trace action execution of automation with automation_id.""" + action_trace_clear() + action_trace = action_trace_get() + automation_trace = AutomationTrace( + unique_id, config, trigger, context, action_trace + ) + + if unique_id: + if unique_id not in hass.data[DATA_AUTOMATION_TRACE]: + hass.data[DATA_AUTOMATION_TRACE][unique_id] = deque([], STORED_TRACES) + traces = hass.data[DATA_AUTOMATION_TRACE][unique_id] + traces.append(automation_trace) + + try: + yield automation_trace + except Exception as ex: # pylint: disable=broad-except + if unique_id: + automation_trace.set_error(ex) + raise ex + finally: + if unique_id: + automation_trace.finished() + _LOGGER.debug( + "Automation finished. Summary:\n\ttrigger: %s\n\tcondition: %s\n\taction: %s", + automation_trace._trigger, # pylint: disable=protected-access + automation_trace._condition_trace, # pylint: disable=protected-access + action_trace, + ) + + class AutomationEntity(ToggleEntity, RestoreEntity): """Entity to show status of entity.""" @@ -228,6 +342,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): initial_state, variables, trigger_variables, + raw_config, ): """Initialize an automation entity.""" self._id = automation_id @@ -244,6 +359,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._logger = LOGGER self._variables: ScriptVariables = variables self._trigger_variables: ScriptVariables = trigger_variables + self._raw_config = raw_config @property def name(self): @@ -374,52 +490,73 @@ class AutomationEntity(ToggleEntity, RestoreEntity): This method is a coroutine. """ - if self._variables: - try: - variables = self._variables.async_render(self.hass, run_variables) - except template.TemplateError as err: - self._logger.error("Error rendering variables: %s", err) + reason = "" + if "trigger" in run_variables and "description" in run_variables["trigger"]: + reason = f' by {run_variables["trigger"]["description"]}' + self._logger.debug("Automation triggered%s", reason) + + trigger = run_variables["trigger"] if "trigger" in run_variables else None + with trace_automation( + self.hass, self.unique_id, self._raw_config, trigger, context + ) as automation_trace: + if self._variables: + try: + variables = self._variables.async_render(self.hass, run_variables) + except template.TemplateError as err: + self._logger.error("Error rendering variables: %s", err) + automation_trace.set_error(err) + return + else: + variables = run_variables + automation_trace.set_variables(variables) + + if ( + not skip_condition + and self._cond_func is not None + and not self._cond_func(variables) + ): + self._logger.debug( + "Conditions not met, aborting automation. Condition summary: %s", + condition_trace_get(), + ) + automation_trace.set_condition_trace(condition_trace_get()) return - else: - variables = run_variables + automation_trace.set_condition_trace(condition_trace_get()) + condition_trace_clear() - if ( - not skip_condition - and self._cond_func is not None - and not self._cond_func(variables) - ): - return + # Create a new context referring to the old context. + parent_id = None if context is None else context.id + trigger_context = Context(parent_id=parent_id) - # Create a new context referring to the old context. - parent_id = None if context is None else context.id - trigger_context = Context(parent_id=parent_id) + self.async_set_context(trigger_context) + event_data = { + ATTR_NAME: self._name, + ATTR_ENTITY_ID: self.entity_id, + } + if "trigger" in variables and "description" in variables["trigger"]: + event_data[ATTR_SOURCE] = variables["trigger"]["description"] - self.async_set_context(trigger_context) - event_data = { - ATTR_NAME: self._name, - ATTR_ENTITY_ID: self.entity_id, - } - if "trigger" in variables and "description" in variables["trigger"]: - event_data[ATTR_SOURCE] = variables["trigger"]["description"] + @callback + def started_action(): + self.hass.bus.async_fire( + EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context + ) - @callback - def started_action(): - self.hass.bus.async_fire( - EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context - ) - - try: - await self.action_script.async_run( - variables, trigger_context, started_action - ) - except (vol.Invalid, HomeAssistantError) as err: - self._logger.error( - "Error while executing automation %s: %s", - self.entity_id, - err, - ) - except Exception: # pylint: disable=broad-except - self._logger.exception("While executing automation %s", self.entity_id) + try: + with action_path("action"): + await self.action_script.async_run( + variables, trigger_context, started_action + ) + except (vol.Invalid, HomeAssistantError) as err: + self._logger.error( + "Error while executing automation %s: %s", + self.entity_id, + err, + ) + automation_trace.set_error(err) + except Exception as err: # pylint: disable=broad-except + self._logger.exception("While executing automation %s", self.entity_id) + automation_trace.set_error(err) async def async_will_remove_from_hass(self): """Remove listeners when removing automation from Home Assistant.""" @@ -527,16 +664,16 @@ async def _async_process_config( ] for list_no, config_block in enumerate(conf): + raw_config = None if isinstance(config_block, blueprint.BlueprintInputs): # type: ignore blueprints_used = True blueprint_inputs = config_block try: + raw_config = blueprint_inputs.async_substitute() config_block = cast( Dict[str, Any], - await async_validate_config_item( - hass, blueprint_inputs.async_substitute() - ), + await async_validate_config_item(hass, raw_config), ) except vol.Invalid as err: LOGGER.error( @@ -546,6 +683,8 @@ async def _async_process_config( humanize_error(config_block, err), ) continue + else: + raw_config = cast(AutomationConfig, config_block).raw_config automation_id = config_block.get(CONF_ID) name = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" @@ -596,6 +735,7 @@ async def _async_process_config( initial_state, variables, config_block.get(CONF_TRIGGER_VARIABLES), + raw_config, ) entities.append(entity) @@ -623,8 +763,9 @@ async def _async_process_if(hass, name, config, p_config): errors = [] for index, check in enumerate(checks): try: - if not check(hass, variables): - return False + with condition_path(["condition", str(index)]): + if not check(hass, variables): + return False except ConditionError as ex: errors.append( ConditionErrorIndex( @@ -672,3 +813,25 @@ def _trigger_extract_entities(trigger_conf: dict) -> List[str]: return ["sun.sun"] return [] + + +@callback +def get_debug_traces_for_automation(hass, automation_id): + """Return a serializable list of debug traces for an automation.""" + traces = [] + + for trace in hass.data[DATA_AUTOMATION_TRACE].get(automation_id, []): + traces.append(trace.as_dict()) + + return traces + + +@callback +def get_debug_traces(hass): + """Return a serializable list of debug traces.""" + traces = {} + + for automation_id in hass.data[DATA_AUTOMATION_TRACE]: + traces[automation_id] = get_debug_traces_for_automation(hass, automation_id) + + return traces diff --git a/homeassistant/components/automation/config.py b/homeassistant/components/automation/config.py index 32ad92cb86e..5abff8fe974 100644 --- a/homeassistant/components/automation/config.py +++ b/homeassistant/components/automation/config.py @@ -79,8 +79,21 @@ async def async_validate_config_item(hass, config, full_config=None): return config +class AutomationConfig(dict): + """Dummy class to allow adding attributes.""" + + raw_config = None + + async def _try_async_validate_config_item(hass, config, full_config=None): """Validate config item.""" + raw_config = None + try: + raw_config = dict(config) + except ValueError: + # Invalid config + pass + try: config = await async_validate_config_item(hass, config, full_config) except ( @@ -92,6 +105,11 @@ async def _try_async_validate_config_item(hass, config, full_config=None): async_log_exception(ex, DOMAIN, full_config or config, hass) return None + if isinstance(config, blueprint.BlueprintInputs): + return config + + config = AutomationConfig(config) + config.raw_config = raw_config return config diff --git a/homeassistant/components/config/automation.py b/homeassistant/components/config/automation.py index 01e22297c0d..23baa0c8843 100644 --- a/homeassistant/components/config/automation.py +++ b/homeassistant/components/config/automation.py @@ -2,6 +2,13 @@ from collections import OrderedDict import uuid +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.components.automation import ( + get_debug_traces, + get_debug_traces_for_automation, +) from homeassistant.components.automation.config import ( DOMAIN, PLATFORM_SCHEMA, @@ -17,6 +24,8 @@ from . import ACTION_DELETE, EditIdBasedConfigView async def async_setup(hass): """Set up the Automation config API.""" + websocket_api.async_register_command(hass, websocket_automation_trace) + async def hook(action, config_key): """post_write_hook for Config View that reloads automations.""" await hass.services.async_call(DOMAIN, SERVICE_RELOAD) @@ -80,3 +89,21 @@ class EditAutomationConfigView(EditIdBasedConfigView): updated_value.update(cur_value) updated_value.update(new_value) data[index] = updated_value + + +@websocket_api.websocket_command( + {vol.Required("type"): "automation/trace", vol.Optional("automation_id"): str} +) +@websocket_api.async_response +async def websocket_automation_trace(hass, connection, msg): + """Get automation traces.""" + automation_id = msg.get("automation_id") + + if not automation_id: + automation_traces = get_debug_traces(hass) + else: + automation_traces = { + automation_id: get_debug_traces_for_automation(hass, automation_id) + } + + connection.send_result(msg["id"], automation_traces) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 5de3cb8264f..429e97230ce 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -308,7 +308,11 @@ class ScriptEntity(ToggleEntity): self._changed.set() async def async_turn_on(self, **kwargs): - """Turn the script on.""" + """Run the script. + + Depending on the script's run mode, this may do nothing, restart the script or + fire an additional parallel run. + """ variables = kwargs.get("variables") context = kwargs.get("context") wait = kwargs.get("wait", True) @@ -331,7 +335,10 @@ class ScriptEntity(ToggleEntity): await self._changed.wait() async def async_turn_off(self, **kwargs): - """Turn script off.""" + """Stop running the script. + + If multiple runs are in progress, all will be stopped. + """ await self.script.async_stop() async def async_will_remove_from_hass(self): diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 40087650141..1abbf550bb1 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -1,12 +1,25 @@ """Offer reusable conditions.""" import asyncio from collections import deque +from contextlib import contextmanager +from contextvars import ContextVar from datetime import datetime, timedelta import functools as ft import logging import re import sys -from typing import Any, Callable, Container, List, Optional, Set, Union, cast +from typing import ( + Any, + Callable, + Container, + Dict, + Generator, + List, + Optional, + Set, + Union, + cast, +) from homeassistant.components import zone as zone_cmp from homeassistant.components.device_automation import ( @@ -51,6 +64,14 @@ from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as dt_util +from .trace import ( + TraceElement, + trace_append_element, + trace_stack_pop, + trace_stack_push, + trace_stack_top, +) + FROM_CONFIG_FORMAT = "{}_from_config" ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config" @@ -63,6 +84,126 @@ INPUT_ENTITY_ID = re.compile( ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool] +# Context variables for tracing +# Trace of condition being evaluated +condition_trace = ContextVar("condition_trace", default=None) +# Stack of TraceElements +condition_trace_stack: ContextVar[Optional[List[TraceElement]]] = ContextVar( + "condition_trace_stack", default=None +) +# Current location in config tree +condition_path_stack: ContextVar[Optional[List[str]]] = ContextVar( + "condition_path_stack", default=None +) + + +def condition_trace_stack_push(node: TraceElement) -> None: + """Push a TraceElement to the top of the trace stack.""" + trace_stack_push(condition_trace_stack, node) + + +def condition_trace_stack_pop() -> None: + """Remove the top element from the trace stack.""" + trace_stack_pop(condition_trace_stack) + + +def condition_trace_stack_top() -> Optional[TraceElement]: + """Return the element at the top of the trace stack.""" + return cast(Optional[TraceElement], trace_stack_top(condition_trace_stack)) + + +def condition_path_push(suffix: Union[str, List[str]]) -> int: + """Go deeper in the config tree.""" + if isinstance(suffix, str): + suffix = [suffix] + for node in suffix: + trace_stack_push(condition_path_stack, node) + return len(suffix) + + +def condition_path_pop(count: int) -> None: + """Go n levels up in the config tree.""" + for _ in range(count): + trace_stack_pop(condition_path_stack) + + +def condition_path_get() -> str: + """Return a string representing the current location in the config tree.""" + path = condition_path_stack.get() + if not path: + return "" + return "/".join(path) + + +def condition_trace_get() -> Optional[Dict[str, TraceElement]]: + """Return the trace of the condition that was evaluated.""" + return condition_trace.get() + + +def condition_trace_clear() -> None: + """Clear the condition trace.""" + condition_trace.set(None) + condition_trace_stack.set(None) + condition_path_stack.set(None) + + +def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement: + """Append a TraceElement to trace[path].""" + trace_element = TraceElement(variables) + trace_append_element(condition_trace, trace_element, path) + return trace_element + + +def condition_trace_set_result(result: bool, **kwargs: Any) -> None: + """Set the result of TraceElement at the top of the stack.""" + node = condition_trace_stack_top() + + # The condition function may be called directly, in which case tracing + # is not setup + if not node: + return + + node.set_result(result=result, **kwargs) + + +@contextmanager +def trace_condition(variables: TemplateVarsType) -> Generator: + """Trace condition evaluation.""" + trace_element = condition_trace_append(variables, condition_path_get()) + condition_trace_stack_push(trace_element) + try: + yield trace_element + except Exception as ex: # pylint: disable=broad-except + trace_element.set_error(ex) + raise ex + finally: + condition_trace_stack_pop() + + +@contextmanager +def condition_path(suffix: Union[str, List[str]]) -> Generator: + """Go deeper in the config tree.""" + count = condition_path_push(suffix) + try: + yield + finally: + condition_path_pop(count) + + +def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType: + """Wrap a condition function to enable basic tracing.""" + + @ft.wraps(condition) + def wrapper(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + """Trace condition.""" + with trace_condition(variables): + result = condition(hass, variables) + condition_trace_set_result(result) + return result + + return wrapper + + async def async_from_config( hass: HomeAssistant, config: Union[ConfigType, Template], @@ -111,6 +252,7 @@ async def async_and_from_config( await async_from_config(hass, entry, False) for entry in config["conditions"] ] + @trace_condition_function def if_and_condition( hass: HomeAssistant, variables: TemplateVarsType = None ) -> bool: @@ -118,8 +260,9 @@ async def async_and_from_config( errors = [] for index, check in enumerate(checks): try: - if not check(hass, variables): - return False + with condition_path(["conditions", str(index)]): + if not check(hass, variables): + return False except ConditionError as ex: errors.append( ConditionErrorIndex("and", index=index, total=len(checks), error=ex) @@ -144,6 +287,7 @@ async def async_or_from_config( await async_from_config(hass, entry, False) for entry in config["conditions"] ] + @trace_condition_function def if_or_condition( hass: HomeAssistant, variables: TemplateVarsType = None ) -> bool: @@ -151,8 +295,9 @@ async def async_or_from_config( errors = [] for index, check in enumerate(checks): try: - if check(hass, variables): - return True + with condition_path(["conditions", str(index)]): + if check(hass, variables): + return True except ConditionError as ex: errors.append( ConditionErrorIndex("or", index=index, total=len(checks), error=ex) @@ -177,6 +322,7 @@ async def async_not_from_config( await async_from_config(hass, entry, False) for entry in config["conditions"] ] + @trace_condition_function def if_not_condition( hass: HomeAssistant, variables: TemplateVarsType = None ) -> bool: @@ -184,8 +330,9 @@ async def async_not_from_config( errors = [] for index, check in enumerate(checks): try: - if check(hass, variables): - return False + with condition_path(["conditions", str(index)]): + if check(hass, variables): + return False except ConditionError as ex: errors.append( ConditionErrorIndex("not", index=index, total=len(checks), error=ex) @@ -290,6 +437,11 @@ def async_numeric_state( ) try: if fvalue >= float(below_entity.state): + condition_trace_set_result( + False, + state=fvalue, + wanted_state_below=float(below_entity.state), + ) return False except (ValueError, TypeError) as ex: raise ConditionErrorMessage( @@ -297,6 +449,7 @@ def async_numeric_state( f"the 'below' entity {below} state '{below_entity.state}' cannot be processed as a number", ) from ex elif fvalue >= below: + condition_trace_set_result(False, state=fvalue, wanted_state_below=below) return False if above is not None: @@ -311,6 +464,11 @@ def async_numeric_state( ) try: if fvalue <= float(above_entity.state): + condition_trace_set_result( + False, + state=fvalue, + wanted_state_above=float(above_entity.state), + ) return False except (ValueError, TypeError) as ex: raise ConditionErrorMessage( @@ -318,8 +476,10 @@ def async_numeric_state( f"the 'above' entity {above} state '{above_entity.state}' cannot be processed as a number", ) from ex elif fvalue <= above: + condition_trace_set_result(False, state=fvalue, wanted_state_above=above) return False + condition_trace_set_result(True, state=fvalue) return True @@ -335,6 +495,7 @@ def async_numeric_state_from_config( above = config.get(CONF_ABOVE) value_template = config.get(CONF_VALUE_TEMPLATE) + @trace_condition_function def if_numeric_state( hass: HomeAssistant, variables: TemplateVarsType = None ) -> bool: @@ -345,10 +506,19 @@ def async_numeric_state_from_config( errors = [] for index, entity_id in enumerate(entity_ids): try: - if not async_numeric_state( - hass, entity_id, below, above, value_template, variables, attribute + with condition_path(["entity_id", str(index)]), trace_condition( + variables ): - return False + if not async_numeric_state( + hass, + entity_id, + below, + above, + value_template, + variables, + attribute, + ): + return False except ConditionError as ex: errors.append( ConditionErrorIndex( @@ -421,9 +591,13 @@ def state( break if for_period is None or not is_state: + condition_trace_set_result(is_state, state=value, wanted_state=state_value) return is_state - return dt_util.utcnow() - for_period > entity.last_changed + duration = dt_util.utcnow() - for_period + duration_ok = duration > entity.last_changed + condition_trace_set_result(duration_ok, state=value, duration=duration) + return duration_ok def state_from_config( @@ -440,13 +614,17 @@ def state_from_config( if not isinstance(req_states, list): req_states = [req_states] + @trace_condition_function def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" errors = [] for index, entity_id in enumerate(entity_ids): try: - if not state(hass, entity_id, req_states, for_period, attribute): - return False + with condition_path(["entity_id", str(index)]), trace_condition( + variables + ): + if not state(hass, entity_id, req_states, for_period, attribute): + return False except ConditionError as ex: errors.append( ConditionErrorIndex( @@ -529,11 +707,12 @@ def sun_from_config( before_offset = config.get("before_offset") after_offset = config.get("after_offset") - def time_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + @trace_condition_function + def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Validate time based if-condition.""" return sun(hass, before, after, before_offset, after_offset) - return time_if + return sun_if def template( @@ -565,6 +744,7 @@ def async_template_from_config( config = cv.TEMPLATE_CONDITION_SCHEMA(config) value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE)) + @trace_condition_function def template_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Validate template based if-condition.""" value_template.hass = hass @@ -645,6 +825,7 @@ def time_from_config( after = config.get(CONF_AFTER) weekday = config.get(CONF_WEEKDAY) + @trace_condition_function def time_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Validate time based if-condition.""" return time(hass, before, after, weekday) @@ -710,6 +891,7 @@ def zone_from_config( entity_ids = config.get(CONF_ENTITY_ID, []) zone_entity_ids = config.get(CONF_ZONE, []) + @trace_condition_function def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" errors = [] @@ -750,9 +932,11 @@ async def async_device_from_config( platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], "condition" ) - return cast( - ConditionCheckerType, - platform.async_condition_from_config(config, config_validation), # type: ignore + return trace_condition_function( + cast( + ConditionCheckerType, + platform.async_condition_from_config(config, config_validation), # type: ignore + ) ) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index e4eb0d4a901..aaed10f7814 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1,5 +1,7 @@ """Helpers to execute scripts.""" import asyncio +from contextlib import contextmanager +from contextvars import ContextVar from datetime import datetime, timedelta from functools import partial import itertools @@ -63,6 +65,12 @@ from homeassistant.core import ( callback, ) from homeassistant.helpers import condition, config_validation as cv, service, template +from homeassistant.helpers.condition import ( + condition_path, + condition_trace_clear, + condition_trace_get, + trace_condition_function, +) from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.trigger import ( @@ -73,6 +81,14 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.util import slugify from homeassistant.util.dt import utcnow +from .trace import ( + TraceElement, + trace_append_element, + trace_stack_pop, + trace_stack_push, + trace_stack_top, +) + # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs SCRIPT_MODE_PARALLEL = "parallel" @@ -108,6 +124,115 @@ _TIMEOUT_MSG = "Timeout reached, abort script." _SHUTDOWN_MAX_WAIT = 60 +ACTION_TRACE_NODE_MAX_LEN = 20 # Max the length of a trace node for repeated actions + +action_trace = ContextVar("action_trace", default=None) +action_trace_stack = ContextVar("action_trace_stack", default=None) +action_path_stack = ContextVar("action_path_stack", default=None) + + +def action_trace_stack_push(node): + """Push a TraceElement to the top of the trace stack.""" + trace_stack_push(action_trace_stack, node) + + +def action_trace_stack_pop(): + """Remove the top element from the trace stack.""" + trace_stack_pop(action_trace_stack) + + +def action_trace_stack_top(): + """Return the element at the top of the trace stack.""" + return trace_stack_top(action_trace_stack) + + +def action_path_push(suffix): + """Go deeper in the config tree.""" + if isinstance(suffix, str): + suffix = [suffix] + for node in suffix: + trace_stack_push(action_path_stack, node) + return len(suffix) + + +def action_path_pop(count): + """Go n levels up in the config tree.""" + for _ in range(count): + trace_stack_pop(action_path_stack) + + +def action_path_get(): + """Return a string representing the current location in the config tree.""" + path = action_path_stack.get() + if not path: + return "" + return "/".join(path) + + +def action_trace_get(): + """Return the trace of the script that was executed.""" + return action_trace.get() + + +def action_trace_clear(): + """Clear the action trace.""" + action_trace.set({}) + action_trace_stack.set(None) + action_path_stack.set(None) + + +def action_trace_append(variables, path): + """Append a TraceElement to trace[path].""" + trace_element = TraceElement(variables) + trace_append_element(action_trace, trace_element, path, ACTION_TRACE_NODE_MAX_LEN) + return trace_element + + +def action_trace_set_result(**kwargs): + """Set the result of TraceElement at the top of the stack.""" + node = action_trace_stack_top() + node.set_result(**kwargs) + + +def action_trace_add_conditions(): + """Add the result of condition evaluation to the action trace.""" + condition_trace = condition_trace_get() + condition_trace_clear() + + if condition_trace is None: + return + + _action_path = action_path_get() + for cond_path, conditions in condition_trace.items(): + path = _action_path + "/" + cond_path if cond_path else _action_path + for cond in conditions: + trace_append_element(action_trace, cond, path) + + +@contextmanager +def trace_action(variables): + """Trace action execution.""" + trace_element = action_trace_append(variables, action_path_get()) + action_trace_stack_push(trace_element) + try: + yield trace_element + except Exception as ex: # pylint: disable=broad-except + trace_element.set_error(ex) + raise ex + finally: + action_trace_stack_pop() + + +@contextmanager +def action_path(suffix): + """Go deeper in the config tree.""" + count = action_path_push(suffix) + try: + yield + finally: + action_path_pop(count) + + def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA): """Make a schema for a component that uses the script helper.""" return vol.Schema( @@ -258,16 +383,16 @@ class _ScriptRun: self._finish() async def _async_step(self, log_exceptions): - try: - await getattr( - self, f"_async_{cv.determine_script_action(self._action)}_step" - )() - except Exception as ex: - if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and ( - self._log_exceptions or log_exceptions - ): - self._log_exception(ex) - raise + with action_path(str(self._step)), trace_action(None): + try: + handler = f"_async_{cv.determine_script_action(self._action)}_step" + await getattr(self, handler)() + except Exception as ex: + if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and ( + self._log_exceptions or log_exceptions + ): + self._log_exception(ex) + raise def _finish(self) -> None: self._script._runs.remove(self) # pylint: disable=protected-access @@ -514,15 +639,37 @@ class _ScriptRun: ) cond = await self._async_get_condition(self._action) try: - check = cond(self._hass, self._variables) + with condition_path("condition"): + check = cond(self._hass, self._variables) except exceptions.ConditionError as ex: _LOGGER.warning("Error in 'condition' evaluation:\n%s", ex) check = False self._log("Test condition %s: %s", self._script.last_action, check) + action_trace_set_result(result=check) + action_trace_add_conditions() if not check: raise _StopScript + def _test_conditions(self, conditions, name): + @trace_condition_function + def traced_test_conditions(hass, variables): + try: + with condition_path("conditions"): + for idx, cond in enumerate(conditions): + with condition_path(str(idx)): + if not cond(hass, variables): + return False + except exceptions.ConditionError as ex: + _LOGGER.warning("Error in '%s[%s]' evaluation: %s", name, idx, ex) + return None + + return True + + result = traced_test_conditions(self._hass, self._variables) + action_trace_add_conditions() + return result + async def _async_repeat_step(self): """Repeat a sequence.""" description = self._action.get(CONF_ALIAS, "sequence") @@ -541,7 +688,8 @@ class _ScriptRun: async def async_run_sequence(iteration, extra_msg=""): self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) - await self._async_run_script(script) + with action_path(str(self._step)): + await self._async_run_script(script) if CONF_COUNT in repeat: count = repeat[CONF_COUNT] @@ -570,9 +718,9 @@ class _ScriptRun: for iteration in itertools.count(1): set_repeat_var(iteration) try: - if self._stop.is_set() or not all( - cond(self._hass, self._variables) for cond in conditions - ): + if self._stop.is_set(): + break + if not self._test_conditions(conditions, "while"): break except exceptions.ConditionError as ex: _LOGGER.warning("Error in 'while' evaluation:\n%s", ex) @@ -588,9 +736,9 @@ class _ScriptRun: set_repeat_var(iteration) await async_run_sequence(iteration) try: - if self._stop.is_set() or all( - cond(self._hass, self._variables) for cond in conditions - ): + if self._stop.is_set(): + break + if self._test_conditions(conditions, "until") in [True, None]: break except exceptions.ConditionError as ex: _LOGGER.warning("Error in 'until' evaluation:\n%s", ex) @@ -606,18 +754,20 @@ class _ScriptRun: # pylint: disable=protected-access choose_data = await self._script._async_get_choose_data(self._step) - for conditions, script in choose_data["choices"]: - try: - if all( - condition(self._hass, self._variables) for condition in conditions - ): - await self._async_run_script(script) - return - except exceptions.ConditionError as ex: - _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex) + for idx, (conditions, script) in enumerate(choose_data["choices"]): + with action_path(str(idx)): + try: + if self._test_conditions(conditions, "choose"): + action_trace_set_result(choice=idx) + await self._async_run_script(script) + return + except exceptions.ConditionError as ex: + _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex) if choose_data["default"]: - await self._async_run_script(choose_data["default"]) + action_trace_set_result(choice="default") + with action_path("default"): + await self._async_run_script(choose_data["default"]) async def _async_wait_for_trigger_step(self): """Wait for a trigger event.""" diff --git a/homeassistant/helpers/trace.py b/homeassistant/helpers/trace.py new file mode 100644 index 00000000000..450faa0336f --- /dev/null +++ b/homeassistant/helpers/trace.py @@ -0,0 +1,78 @@ +"""Helpers for script and condition tracing.""" +from collections import deque +from contextvars import ContextVar +from typing import Any, Dict, Optional + +from homeassistant.helpers.typing import TemplateVarsType +import homeassistant.util.dt as dt_util + + +def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None: + """Push an element to the top of a trace stack.""" + trace_stack = trace_stack_var.get() + if trace_stack is None: + trace_stack = [] + trace_stack_var.set(trace_stack) + trace_stack.append(node) + + +def trace_stack_pop(trace_stack_var: ContextVar) -> None: + """Remove the top element from a trace stack.""" + trace_stack = trace_stack_var.get() + trace_stack.pop() + + +def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]: + """Return the element at the top of a trace stack.""" + trace_stack = trace_stack_var.get() + return trace_stack[-1] if trace_stack else None + + +class TraceElement: + """Container for trace data.""" + + def __init__(self, variables: TemplateVarsType): + """Container for trace data.""" + self._error: Optional[Exception] = None + self._result: Optional[dict] = None + self._timestamp = dt_util.utcnow() + self._variables = variables + + def __repr__(self) -> str: + """Container for trace data.""" + return str(self.as_dict()) + + def set_error(self, ex: Exception) -> None: + """Set error.""" + self._error = ex + + def set_result(self, **kwargs: Any) -> None: + """Set result.""" + self._result = {**kwargs} + + def as_dict(self) -> Dict[str, Any]: + """Return dictionary version of this TraceElement.""" + result: Dict[str, Any] = {"timestamp": self._timestamp} + # Commented out because we get too many copies of the same data + # result["variables"] = self._variables + if self._error is not None: + result["error"] = str(self._error) + if self._result is not None: + result["result"] = self._result + return result + + +def trace_append_element( + trace_var: ContextVar, + trace_element: TraceElement, + path: str, + maxlen: Optional[int] = None, +) -> None: + """Append a TraceElement to trace[path].""" + trace = trace_var.get() + if trace is None: + trace_var.set({}) + trace = trace_var.get() + if path not in trace: + trace[path] = deque(maxlen=maxlen) + trace[path].append(trace_element) diff --git a/tests/components/config/test_automation.py b/tests/components/config/test_automation.py index 541cd3068d2..7e0cf9e8e4d 100644 --- a/tests/components/config/test_automation.py +++ b/tests/components/config/test_automation.py @@ -164,3 +164,164 @@ async def test_delete_automation(hass, hass_client): assert written[0][0]["id"] == "moon" assert len(ent_reg.entities) == 1 + + +async def test_get_automation_trace(hass, hass_ws_client): + """Test deleting an automation.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": {"service": "test.automation"}, + } + moon_config = { + "id": "moon", + "trigger": [ + {"platform": "event", "event_type": "test_event2"}, + {"platform": "event", "event_type": "test_event3"}, + ], + "condition": { + "condition": "template", + "value_template": "{{ trigger.event.event_type=='test_event2' }}", + }, + "action": {"event": "another_event"}, + } + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + sun_config, + moon_config, + ] + }, + ) + + with patch.object(config, "SECTIONS", ["automation"]): + await async_setup_component(hass, "config", {}) + + client = await hass_ws_client() + + await client.send_json({"id": next_id(), "type": "automation/trace"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == {} + + await client.send_json( + {"id": next_id(), "type": "automation/trace", "automation_id": "sun"} + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == {"sun": []} + + # Trigger "sun" automation + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + # Get trace + await client.send_json({"id": next_id(), "type": "automation/trace"}) + response = await client.receive_json() + assert response["success"] + assert "moon" not in response["result"] + assert len(response["result"]["sun"]) == 1 + trace = response["result"]["sun"][0] + assert len(trace["action_trace"]) == 1 + assert len(trace["action_trace"]["action/0"]) == 1 + assert trace["action_trace"]["action/0"][0]["error"] + assert "result" not in trace["action_trace"]["action/0"][0] + assert trace["condition_trace"] == {} + assert trace["config"] == sun_config + assert trace["context"] + assert trace["error"] == "Unable to find service test.automation" + assert trace["state"] == "stopped" + assert trace["trigger"]["description"] == "event 'test_event'" + assert trace["unique_id"] == "sun" + assert trace["variables"] + + # Trigger "moon" automation, with passing condition + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + + # Get trace + await client.send_json( + {"id": next_id(), "type": "automation/trace", "automation_id": "moon"} + ) + response = await client.receive_json() + assert response["success"] + assert "sun" not in response["result"] + assert len(response["result"]["moon"]) == 1 + trace = response["result"]["moon"][0] + assert len(trace["action_trace"]) == 1 + assert len(trace["action_trace"]["action/0"]) == 1 + assert "error" not in trace["action_trace"]["action/0"][0] + assert "result" not in trace["action_trace"]["action/0"][0] + assert len(trace["condition_trace"]) == 1 + assert len(trace["condition_trace"]["condition/0"]) == 1 + assert trace["condition_trace"]["condition/0"][0]["result"] == {"result": True} + assert trace["config"] == moon_config + assert trace["context"] + assert "error" not in trace + assert trace["state"] == "stopped" + assert trace["trigger"]["description"] == "event 'test_event2'" + assert trace["unique_id"] == "moon" + assert trace["variables"] + + # Trigger "moon" automation, with failing condition + hass.bus.async_fire("test_event3") + await hass.async_block_till_done() + + # Get trace + await client.send_json( + {"id": next_id(), "type": "automation/trace", "automation_id": "moon"} + ) + response = await client.receive_json() + assert response["success"] + assert "sun" not in response["result"] + assert len(response["result"]["moon"]) == 2 + trace = response["result"]["moon"][1] + assert len(trace["action_trace"]) == 0 + assert len(trace["condition_trace"]) == 1 + assert len(trace["condition_trace"]["condition/0"]) == 1 + assert trace["condition_trace"]["condition/0"][0]["result"] == {"result": False} + assert trace["config"] == moon_config + assert trace["context"] + assert "error" not in trace + assert trace["state"] == "stopped" + assert trace["trigger"]["description"] == "event 'test_event3'" + assert trace["unique_id"] == "moon" + assert trace["variables"] + + # Trigger "moon" automation, with passing condition + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + + # Get trace + await client.send_json( + {"id": next_id(), "type": "automation/trace", "automation_id": "moon"} + ) + response = await client.receive_json() + assert response["success"] + assert "sun" not in response["result"] + assert len(response["result"]["moon"]) == 3 + trace = response["result"]["moon"][2] + assert len(trace["action_trace"]) == 1 + assert len(trace["action_trace"]["action/0"]) == 1 + assert "error" not in trace["action_trace"]["action/0"][0] + assert "result" not in trace["action_trace"]["action/0"][0] + assert len(trace["condition_trace"]) == 1 + assert len(trace["condition_trace"]["condition/0"]) == 1 + assert trace["condition_trace"]["condition/0"][0]["result"] == {"result": True} + assert trace["config"] == moon_config + assert trace["context"] + assert "error" not in trace + assert trace["state"] == "stopped" + assert trace["trigger"]["description"] == "event 'test_event2'" + assert trace["unique_id"] == "moon" + assert trace["variables"] diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index 5074b6e70c4..b3e950131b0 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -11,6 +11,32 @@ from homeassistant.setup import async_setup_component from homeassistant.util import dt +def assert_element(trace_element, expected_element, path): + """Assert a trace element is as expected. + + Note: Unused variable path is passed to get helpful errors from pytest. + """ + for result_key, result in expected_element.get("result", {}).items(): + assert trace_element._result[result_key] == result + if "error_type" in expected_element: + assert isinstance(trace_element._error, expected_element["error_type"]) + else: + assert trace_element._error is None + + +def assert_condition_trace(expected): + """Assert a trace condition sequence is as expected.""" + condition_trace = condition.condition_trace_get() + condition.condition_trace_clear() + expected_trace_keys = list(expected.keys()) + assert list(condition_trace.keys()) == expected_trace_keys + for trace_key_index, key in enumerate(expected_trace_keys): + assert len(condition_trace[key]) == len(expected[key]) + for index, element in enumerate(expected[key]): + path = f"[{trace_key_index}][{index}]" + assert_element(condition_trace[key][index], element, path) + + async def test_invalid_condition(hass): """Test if invalid condition raises.""" with pytest.raises(HomeAssistantError): @@ -53,15 +79,112 @@ async def test_and_condition(hass): with pytest.raises(ConditionError): test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) hass.states.async_set("sensor.temperature", 120) assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + } + ) hass.states.async_set("sensor.temperature", 105) assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + } + ) hass.states.async_set("sensor.temperature", 100) assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"result": {"result": True}}], + "conditions/0/entity_id/0": [{"result": {"result": True}}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) + + +async def test_and_condition_raises(hass): + """Test the 'and' condition.""" + test = await condition.async_from_config( + hass, + { + "alias": "And Condition", + "condition": "and", + "conditions": [ + { + "condition": "state", + "entity_id": "sensor.temperature", + "state": "100", + }, + { + "condition": "numeric_state", + "entity_id": "sensor.temperature2", + "above": 110, + }, + ], + }, + ) + + # All subconditions raise, the AND-condition should raise + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) + + # The first subconditions raises, the second returns True, the AND-condition + # should raise + hass.states.async_set("sensor.temperature2", 120) + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) + + # The first subconditions raises, the second returns False, the AND-condition + # should return False + hass.states.async_set("sensor.temperature2", 90) + assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) async def test_and_condition_with_template(hass): @@ -119,15 +242,114 @@ async def test_or_condition(hass): with pytest.raises(ConditionError): test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) hass.states.async_set("sensor.temperature", 120) assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) hass.states.async_set("sensor.temperature", 105) assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) hass.states.async_set("sensor.temperature", 100) assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"result": {"result": True}}], + "conditions/0/entity_id/0": [{"result": {"result": True}}], + } + ) + + +async def test_or_condition_raises(hass): + """Test the 'or' condition.""" + test = await condition.async_from_config( + hass, + { + "alias": "Or Condition", + "condition": "or", + "conditions": [ + { + "condition": "state", + "entity_id": "sensor.temperature", + "state": "100", + }, + { + "condition": "numeric_state", + "entity_id": "sensor.temperature2", + "above": 110, + }, + ], + }, + ) + + # All subconditions raise, the OR-condition should raise + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) + + # The first subconditions raises, the second returns False, the OR-condition + # should raise + hass.states.async_set("sensor.temperature2", 100) + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) + + # The first subconditions raises, the second returns True, the OR-condition + # should return True + hass.states.async_set("sensor.temperature2", 120) + assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) async def test_or_condition_with_template(hass): @@ -181,18 +403,126 @@ async def test_not_condition(hass): with pytest.raises(ConditionError): test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) hass.states.async_set("sensor.temperature", 101) assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) hass.states.async_set("sensor.temperature", 50) assert test(hass) + assert_condition_trace( + { + "": [{"result": {"result": True}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) hass.states.async_set("sensor.temperature", 49) assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"result": {"result": False}}], + "conditions/0/entity_id/0": [{"result": {"result": False}}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) hass.states.async_set("sensor.temperature", 100) assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"result": {"result": True}}], + "conditions/0/entity_id/0": [{"result": {"result": True}}], + } + ) + + +async def test_not_condition_raises(hass): + """Test the 'and' condition.""" + test = await condition.async_from_config( + hass, + { + "alias": "Not Condition", + "condition": "not", + "conditions": [ + { + "condition": "state", + "entity_id": "sensor.temperature", + "state": "100", + }, + { + "condition": "numeric_state", + "entity_id": "sensor.temperature2", + "below": 50, + }, + ], + }, + ) + + # All subconditions raise, the NOT-condition should raise + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"error_type": ConditionError}], + "conditions/1/entity_id/0": [{"error_type": ConditionError}], + } + ) + + # The first subconditions raises, the second returns False, the NOT-condition + # should raise + hass.states.async_set("sensor.temperature2", 90) + with pytest.raises(ConditionError): + test(hass) + assert_condition_trace( + { + "": [{"error_type": ConditionError}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": False}}], + "conditions/1/entity_id/0": [{"result": {"result": False}}], + } + ) + + # The first subconditions raises, the second returns True, the NOT-condition + # should return False + hass.states.async_set("sensor.temperature2", 40) + assert not test(hass) + assert_condition_trace( + { + "": [{"result": {"result": False}}], + "conditions/0": [{"error_type": ConditionError}], + "conditions/0/entity_id/0": [{"error_type": ConditionError}], + "conditions/1": [{"result": {"result": True}}], + "conditions/1/entity_id/0": [{"result": {"result": True}}], + } + ) async def test_not_condition_with_template(hass): diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index d2946fcd494..04f922b685e 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -30,6 +30,32 @@ from tests.common import ( ENTITY_ID = "script.test" +def assert_element(trace_element, expected_element, path): + """Assert a trace element is as expected. + + Note: Unused variable 'path' is passed to get helpful errors from pytest. + """ + for result_key, result in expected_element.get("result", {}).items(): + assert trace_element._result[result_key] == result + if "error_type" in expected_element: + assert isinstance(trace_element._error, expected_element["error_type"]) + else: + assert trace_element._error is None + + +def assert_action_trace(expected): + """Assert a trace condition sequence is as expected.""" + action_trace = script.action_trace_get() + script.action_trace_clear() + expected_trace_keys = list(expected.keys()) + assert list(action_trace.keys()) == expected_trace_keys + for trace_key_index, key in enumerate(expected_trace_keys): + assert len(action_trace[key]) == len(expected[key]) + for index, element in enumerate(expected[key]): + path = f"[{trace_key_index}][{index}]" + assert_element(action_trace[key][index], element, path) + + def async_watch_for_action(script_obj, message): """Watch for message in last_action.""" flag = asyncio.Event() @@ -54,9 +80,14 @@ async def test_firing_event_basic(hass, caplog): sequence = cv.SCRIPT_SCHEMA( {"alias": alias, "event": event, "event_data": {"hello": "world"}} ) - script_obj = script.Script( - hass, sequence, "Test Name", "test_domain", running_description="test script" - ) + with script.trace_action(None): + script_obj = script.Script( + hass, + sequence, + "Test Name", + "test_domain", + running_description="test script", + ) await script_obj.async_run(context=context) await hass.async_block_till_done() @@ -67,6 +98,12 @@ async def test_firing_event_basic(hass, caplog): assert ".test_name:" in caplog.text assert "Test Name: Running test script" in caplog.text assert f"Executing step {alias}" in caplog.text + assert_action_trace( + { + "": [{}], + "0": [{}], + } + ) async def test_firing_event_template(hass):