Merge action and condition traces (#47373)

* Merge action and condition traces

* Update __init__.py

* Add typing to AutomationTrace

* Make trace_get prepare a new trace by default

* Correct typing of trace_cv

* Fix tests
pull/47510/head
Erik Montnemery 2021-03-06 12:57:21 +01:00 committed by GitHub
parent 022184176a
commit 2f9d03d115
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 191 additions and 284 deletions

View File

@ -1,8 +1,20 @@
"""Allow to set up simple automation rules via the config file."""
from collections import deque
from contextlib import contextmanager
import datetime as dt
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union, cast
from typing import (
Any,
Awaitable,
Callable,
Deque,
Dict,
List,
Optional,
Set,
Union,
cast,
)
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -42,11 +54,6 @@ from homeassistant.exceptions import (
HomeAssistantError,
)
from homeassistant.helpers import condition, extract_domain_configs, template
from homeassistant.helpers.condition import (
condition_path,
condition_trace_clear,
condition_trace_get,
)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
@ -57,12 +64,10 @@ from homeassistant.helpers.script import (
CONF_MAX,
CONF_MAX_EXCEEDED,
Script,
action_path,
action_trace_clear,
action_trace_get,
)
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.trace import TraceElement, trace_get, trace_path
from homeassistant.helpers.trigger import async_initialize_triggers
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
@ -235,44 +240,55 @@ async def async_setup(hass, config):
class AutomationTrace:
"""Container for automation trace."""
def __init__(self, unique_id, config, trigger, context, action_trace):
def __init__(
self,
unique_id: Optional[str],
config: Dict[str, Any],
trigger: Dict[str, Any],
context: Context,
):
"""Container for automation trace."""
self._action_trace = action_trace
self._condition_trace = None
self._config = config
self._context = context
self._error = None
self._state = "running"
self._timestamp_finish = None
self._timestamp_start = dt_util.utcnow()
self._trigger = trigger
self._unique_id = unique_id
self._variables = None
self._action_trace: Optional[Dict[str, Deque[TraceElement]]] = None
self._condition_trace: Optional[Dict[str, Deque[TraceElement]]] = None
self._config: Dict[str, Any] = config
self._context: Context = context
self._error: Optional[Exception] = None
self._state: str = "running"
self._timestamp_finish: Optional[dt.datetime] = None
self._timestamp_start: dt.datetime = dt_util.utcnow()
self._trigger: Dict[str, Any] = trigger
self._unique_id: Optional[str] = unique_id
self._variables: Optional[Dict[str, Any]] = None
def set_error(self, ex):
def set_action_trace(self, trace: Dict[str, Deque[TraceElement]]) -> None:
"""Set action trace."""
self._action_trace = trace
def set_condition_trace(self, trace: Dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def set_error(self, ex: Exception) -> None:
"""Set error."""
self._error = ex
def set_variables(self, variables):
def set_variables(self, variables: Dict[str, Any]) -> None:
"""Set variables."""
self._variables = variables
def set_condition_trace(self, condition_trace):
"""Set condition trace."""
self._condition_trace = condition_trace
def finished(self):
def finished(self) -> None:
"""Set finish time."""
self._timestamp_finish = dt_util.utcnow()
self._state = "stopped"
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this AutomationTrace."""
action_traces = {}
condition_traces = {}
for key, trace_list in self._action_trace.items():
action_traces[key] = [item.as_dict() for item in trace_list]
if self._action_trace:
for key, trace_list in self._action_trace.items():
action_traces[key] = [item.as_dict() for item in trace_list]
if self._condition_trace:
for key, trace_list in self._condition_trace.items():
@ -300,11 +316,7 @@ class AutomationTrace:
@contextmanager
def trace_automation(hass, unique_id, config, trigger, context):
"""Trace action execution of automation with automation_id."""
action_trace_clear()
action_trace = action_trace_get()
automation_trace = AutomationTrace(
unique_id, config, trigger, context, action_trace
)
automation_trace = AutomationTrace(unique_id, config, trigger, context)
if unique_id:
if unique_id not in hass.data[DATA_AUTOMATION_TRACE]:
@ -325,7 +337,7 @@ def trace_automation(hass, unique_id, config, trigger, context):
"Automation finished. Summary:\n\ttrigger: %s\n\tcondition: %s\n\taction: %s",
automation_trace._trigger, # pylint: disable=protected-access
automation_trace._condition_trace, # pylint: disable=protected-access
action_trace,
automation_trace._action_trace, # pylint: disable=protected-access
)
@ -510,6 +522,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
variables = run_variables
automation_trace.set_variables(variables)
# Prepare tracing the evaluation of the automation's conditions
automation_trace.set_condition_trace(trace_get())
if (
not skip_condition
and self._cond_func is not None
@ -517,12 +532,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
):
self._logger.debug(
"Conditions not met, aborting automation. Condition summary: %s",
condition_trace_get(),
trace_get(clear=False),
)
automation_trace.set_condition_trace(condition_trace_get())
return
automation_trace.set_condition_trace(condition_trace_get())
condition_trace_clear()
# Prepare tracing the execution of the automation's actions
automation_trace.set_action_trace(trace_get())
# Create a new context referring to the old context.
parent_id = None if context is None else context.id
@ -543,7 +558,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
)
try:
with action_path("action"):
with trace_path("action"):
await self.action_script.async_run(
variables, trigger_context, started_action
)
@ -763,7 +778,7 @@ async def _async_process_if(hass, name, config, p_config):
errors = []
for index, check in enumerate(checks):
try:
with condition_path(["condition", str(index)]):
with trace_path(["condition", str(index)]):
if not check(hass, variables):
return False
except ConditionError as ex:

View File

@ -2,24 +2,12 @@
import asyncio
from collections import deque
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
import functools as ft
import logging
import re
import sys
from typing import (
Any,
Callable,
Container,
Dict,
Generator,
List,
Optional,
Set,
Union,
cast,
)
from typing import Any, Callable, Container, Generator, List, Optional, Set, Union, cast
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
@ -67,6 +55,9 @@ import homeassistant.util.dt as dt_util
from .trace import (
TraceElement,
trace_append_element,
trace_path,
trace_path_get,
trace_stack_cv,
trace_stack_pop,
trace_stack_push,
trace_stack_top,
@ -84,79 +75,16 @@ INPUT_ENTITY_ID = re.compile(
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
# Context variables for tracing
# Trace of condition being evaluated
condition_trace = ContextVar("condition_trace", default=None)
# Stack of TraceElements
condition_trace_stack: ContextVar[Optional[List[TraceElement]]] = ContextVar(
"condition_trace_stack", default=None
)
# Current location in config tree
condition_path_stack: ContextVar[Optional[List[str]]] = ContextVar(
"condition_path_stack", default=None
)
def condition_trace_stack_push(node: TraceElement) -> None:
"""Push a TraceElement to the top of the trace stack."""
trace_stack_push(condition_trace_stack, node)
def condition_trace_stack_pop() -> None:
"""Remove the top element from the trace stack."""
trace_stack_pop(condition_trace_stack)
def condition_trace_stack_top() -> Optional[TraceElement]:
"""Return the element at the top of the trace stack."""
return cast(Optional[TraceElement], trace_stack_top(condition_trace_stack))
def condition_path_push(suffix: Union[str, List[str]]) -> int:
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(condition_path_stack, node)
return len(suffix)
def condition_path_pop(count: int) -> None:
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(condition_path_stack)
def condition_path_get() -> str:
"""Return a string representing the current location in the config tree."""
path = condition_path_stack.get()
if not path:
return ""
return "/".join(path)
def condition_trace_get() -> Optional[Dict[str, TraceElement]]:
"""Return the trace of the condition that was evaluated."""
return condition_trace.get()
def condition_trace_clear() -> None:
"""Clear the condition trace."""
condition_trace.set(None)
condition_trace_stack.set(None)
condition_path_stack.set(None)
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
"""Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables)
trace_append_element(condition_trace, trace_element, path)
trace_append_element(trace_element, path)
return trace_element
def condition_trace_set_result(result: bool, **kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack."""
node = condition_trace_stack_top()
node = trace_stack_top(trace_stack_cv)
# The condition function may be called directly, in which case tracing
# is not setup
@ -169,25 +97,15 @@ def condition_trace_set_result(result: bool, **kwargs: Any) -> None:
@contextmanager
def trace_condition(variables: TemplateVarsType) -> Generator:
"""Trace condition evaluation."""
trace_element = condition_trace_append(variables, condition_path_get())
condition_trace_stack_push(trace_element)
trace_element = condition_trace_append(variables, trace_path_get())
trace_stack_push(trace_stack_cv, trace_element)
try:
yield trace_element
except Exception as ex: # pylint: disable=broad-except
trace_element.set_error(ex)
raise ex
finally:
condition_trace_stack_pop()
@contextmanager
def condition_path(suffix: Union[str, List[str]]) -> Generator:
"""Go deeper in the config tree."""
count = condition_path_push(suffix)
try:
yield
finally:
condition_path_pop(count)
trace_stack_pop(trace_stack_cv)
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
@ -260,7 +178,7 @@ async def async_and_from_config(
errors = []
for index, check in enumerate(checks):
try:
with condition_path(["conditions", str(index)]):
with trace_path(["conditions", str(index)]):
if not check(hass, variables):
return False
except ConditionError as ex:
@ -295,7 +213,7 @@ async def async_or_from_config(
errors = []
for index, check in enumerate(checks):
try:
with condition_path(["conditions", str(index)]):
with trace_path(["conditions", str(index)]):
if check(hass, variables):
return True
except ConditionError as ex:
@ -330,7 +248,7 @@ async def async_not_from_config(
errors = []
for index, check in enumerate(checks):
try:
with condition_path(["conditions", str(index)]):
with trace_path(["conditions", str(index)]):
if check(hass, variables):
return False
except ConditionError as ex:
@ -509,9 +427,7 @@ def async_numeric_state_from_config(
errors = []
for index, entity_id in enumerate(entity_ids):
try:
with condition_path(["entity_id", str(index)]), trace_condition(
variables
):
with trace_path(["entity_id", str(index)]), trace_condition(variables):
if not async_numeric_state(
hass,
entity_id,
@ -623,9 +539,7 @@ def state_from_config(
errors = []
for index, entity_id in enumerate(entity_ids):
try:
with condition_path(["entity_id", str(index)]), trace_condition(
variables
):
with trace_path(["entity_id", str(index)]), trace_condition(variables):
if not state(hass, entity_id, req_states, for_period, attribute):
return False
except ConditionError as ex:

View File

@ -1,7 +1,6 @@
"""Helpers to execute scripts."""
import asyncio
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
from functools import partial
import itertools
@ -65,12 +64,7 @@ from homeassistant.core import (
callback,
)
from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import (
condition_path,
condition_trace_clear,
condition_trace_get,
trace_condition_function,
)
from homeassistant.helpers.condition import trace_condition_function
from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.trigger import (
@ -84,9 +78,12 @@ from homeassistant.util.dt import utcnow
from .trace import (
TraceElement,
trace_append_element,
trace_path,
trace_path_get,
trace_set_result,
trace_stack_cv,
trace_stack_pop,
trace_stack_push,
trace_stack_top,
)
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
@ -125,111 +122,26 @@ _SHUTDOWN_MAX_WAIT = 60
ACTION_TRACE_NODE_MAX_LEN = 20 # Max the length of a trace node for repeated actions
action_trace = ContextVar("action_trace", default=None)
action_trace_stack = ContextVar("action_trace_stack", default=None)
action_path_stack = ContextVar("action_path_stack", default=None)
def action_trace_stack_push(node):
"""Push a TraceElement to the top of the trace stack."""
trace_stack_push(action_trace_stack, node)
def action_trace_stack_pop():
"""Remove the top element from the trace stack."""
trace_stack_pop(action_trace_stack)
def action_trace_stack_top():
"""Return the element at the top of the trace stack."""
return trace_stack_top(action_trace_stack)
def action_path_push(suffix):
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(action_path_stack, node)
return len(suffix)
def action_path_pop(count):
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(action_path_stack)
def action_path_get():
"""Return a string representing the current location in the config tree."""
path = action_path_stack.get()
if not path:
return ""
return "/".join(path)
def action_trace_get():
"""Return the trace of the script that was executed."""
return action_trace.get()
def action_trace_clear():
"""Clear the action trace."""
action_trace.set({})
action_trace_stack.set(None)
action_path_stack.set(None)
def action_trace_append(variables, path):
"""Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables)
trace_append_element(action_trace, trace_element, path, ACTION_TRACE_NODE_MAX_LEN)
trace_append_element(trace_element, path, ACTION_TRACE_NODE_MAX_LEN)
return trace_element
def action_trace_set_result(**kwargs):
"""Set the result of TraceElement at the top of the stack."""
node = action_trace_stack_top()
node.set_result(**kwargs)
def action_trace_add_conditions():
"""Add the result of condition evaluation to the action trace."""
condition_trace = condition_trace_get()
condition_trace_clear()
if condition_trace is None:
return
_action_path = action_path_get()
for cond_path, conditions in condition_trace.items():
path = _action_path + "/" + cond_path if cond_path else _action_path
for cond in conditions:
trace_append_element(action_trace, cond, path)
@contextmanager
def trace_action(variables):
"""Trace action execution."""
trace_element = action_trace_append(variables, action_path_get())
action_trace_stack_push(trace_element)
trace_element = action_trace_append(variables, trace_path_get())
trace_stack_push(trace_stack_cv, trace_element)
try:
yield trace_element
except Exception as ex: # pylint: disable=broad-except
trace_element.set_error(ex)
raise ex
finally:
action_trace_stack_pop()
@contextmanager
def action_path(suffix):
"""Go deeper in the config tree."""
count = action_path_push(suffix)
try:
yield
finally:
action_path_pop(count)
trace_stack_pop(trace_stack_cv)
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
@ -382,7 +294,7 @@ class _ScriptRun:
self._finish()
async def _async_step(self, log_exceptions):
with action_path(str(self._step)), trace_action(None):
with trace_path(str(self._step)), trace_action(None):
try:
handler = f"_async_{cv.determine_script_action(self._action)}_step"
await getattr(self, handler)()
@ -638,15 +550,14 @@ class _ScriptRun:
)
cond = await self._async_get_condition(self._action)
try:
with condition_path("condition"):
with trace_path("condition"):
check = cond(self._hass, self._variables)
except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'condition' evaluation:\n%s", ex)
check = False
self._log("Test condition %s: %s", self._script.last_action, check)
action_trace_set_result(result=check)
action_trace_add_conditions()
trace_set_result(result=check)
if not check:
raise _StopScript
@ -654,9 +565,9 @@ class _ScriptRun:
@trace_condition_function
def traced_test_conditions(hass, variables):
try:
with condition_path("conditions"):
with trace_path("conditions"):
for idx, cond in enumerate(conditions):
with condition_path(str(idx)):
with trace_path(str(idx)):
if not cond(hass, variables):
return False
except exceptions.ConditionError as ex:
@ -666,7 +577,6 @@ class _ScriptRun:
return True
result = traced_test_conditions(self._hass, self._variables)
action_trace_add_conditions()
return result
async def _async_repeat_step(self):
@ -687,7 +597,7 @@ class _ScriptRun:
async def async_run_sequence(iteration, extra_msg=""):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
with action_path(str(self._step)):
with trace_path(str(self._step)):
await self._async_run_script(script)
if CONF_COUNT in repeat:
@ -754,18 +664,18 @@ class _ScriptRun:
choose_data = await self._script._async_get_choose_data(self._step)
for idx, (conditions, script) in enumerate(choose_data["choices"]):
with action_path(str(idx)):
with trace_path(str(idx)):
try:
if self._test_conditions(conditions, "choose"):
action_trace_set_result(choice=idx)
trace_set_result(choice=idx)
await self._async_run_script(script)
return
except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
if choose_data["default"]:
action_trace_set_result(choice="default")
with action_path("default"):
trace_set_result(choice="default")
with trace_path("default"):
await self._async_run_script(choose_data["default"])
async def _async_wait_for_trigger_step(self):

View File

@ -1,33 +1,13 @@
"""Helpers for script and condition tracing."""
from collections import deque
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Optional
from typing import Any, Deque, Dict, Generator, List, Optional, Union, cast
from homeassistant.helpers.typing import TemplateVarsType
import homeassistant.util.dt as dt_util
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get()
if trace_stack is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
"""Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get()
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
class TraceElement:
"""Container for trace data."""
@ -62,17 +42,105 @@ class TraceElement:
return result
# Context variables for tracing
# Current trace
trace_cv: ContextVar[Optional[Dict[str, Deque[TraceElement]]]] = ContextVar(
"trace_cv", default=None
)
# Stack of TraceElements
trace_stack_cv: ContextVar[Optional[List[TraceElement]]] = ContextVar(
"trace_stack_cv", default=None
)
# Current location in config tree
trace_path_stack_cv: ContextVar[Optional[List[str]]] = ContextVar(
"trace_path_stack_cv", default=None
)
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get()
if trace_stack is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
"""Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get()
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
def trace_path_push(suffix: Union[str, List[str]]) -> int:
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(trace_path_stack_cv, node)
return len(suffix)
def trace_path_pop(count: int) -> None:
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(trace_path_stack_cv)
def trace_path_get() -> str:
"""Return a string representing the current location in the config tree."""
path = trace_path_stack_cv.get()
if not path:
return ""
return "/".join(path)
def trace_append_element(
trace_var: ContextVar,
trace_element: TraceElement,
path: str,
maxlen: Optional[int] = None,
) -> None:
"""Append a TraceElement to trace[path]."""
trace = trace_var.get()
trace = trace_cv.get()
if trace is None:
trace_var.set({})
trace = trace_var.get()
trace = {}
trace_cv.set(trace)
if path not in trace:
trace[path] = deque(maxlen=maxlen)
trace[path].append(trace_element)
def trace_get(clear: bool = True) -> Optional[Dict[str, Deque[TraceElement]]]:
"""Return the current trace."""
if clear:
trace_clear()
return trace_cv.get()
def trace_clear() -> None:
"""Clear the trace."""
trace_cv.set({})
trace_stack_cv.set(None)
trace_path_stack_cv.set(None)
def trace_set_result(**kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
node.set_result(**kwargs)
@contextmanager
def trace_path(suffix: Union[str, List[str]]) -> Generator:
"""Go deeper in the config tree."""
count = trace_path_push(suffix)
try:
yield
finally:
trace_path_pop(count)

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from homeassistant.exceptions import ConditionError, HomeAssistantError
from homeassistant.helpers import condition
from homeassistant.helpers import condition, trace
from homeassistant.helpers.template import Template
from homeassistant.setup import async_setup_component
from homeassistant.util import dt
@ -25,8 +25,8 @@ def assert_element(trace_element, expected_element, path):
def assert_condition_trace(expected):
"""Assert a trace condition sequence is as expected."""
condition_trace = condition.condition_trace_get()
condition.condition_trace_clear()
condition_trace = trace.trace_get(clear=False)
trace.trace_clear()
expected_trace_keys = list(expected.keys())
assert list(condition_trace.keys()) == expected_trace_keys
for trace_key_index, key in enumerate(expected_trace_keys):

View File

@ -17,7 +17,7 @@ from homeassistant import exceptions
import homeassistant.components.scene as scene
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import config_validation as cv, script
from homeassistant.helpers import config_validation as cv, script, trace
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -45,8 +45,8 @@ def assert_element(trace_element, expected_element, path):
def assert_action_trace(expected):
"""Assert a trace condition sequence is as expected."""
action_trace = script.action_trace_get()
script.action_trace_clear()
action_trace = trace.trace_get(clear=False)
trace.trace_clear()
expected_trace_keys = list(expected.keys())
assert list(action_trace.keys()) == expected_trace_keys
for trace_key_index, key in enumerate(expected_trace_keys):