Merge action and condition traces (#47373)

* Merge action and condition traces

* Update __init__.py

* Add typing to AutomationTrace

* Make trace_get prepare a new trace by default

* Correct typing of trace_cv

* Fix tests
pull/47510/head
Erik Montnemery 2021-03-06 12:57:21 +01:00 committed by GitHub
parent 022184176a
commit 2f9d03d115
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 191 additions and 284 deletions

View File

@ -1,8 +1,20 @@
"""Allow to set up simple automation rules via the config file.""" """Allow to set up simple automation rules via the config file."""
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
import datetime as dt
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union, cast from typing import (
Any,
Awaitable,
Callable,
Deque,
Dict,
List,
Optional,
Set,
Union,
cast,
)
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -42,11 +54,6 @@ from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
) )
from homeassistant.helpers import condition, extract_domain_configs, template from homeassistant.helpers import condition, extract_domain_configs, template
from homeassistant.helpers.condition import (
condition_path,
condition_trace_clear,
condition_trace_get,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -57,12 +64,10 @@ from homeassistant.helpers.script import (
CONF_MAX, CONF_MAX,
CONF_MAX_EXCEEDED, CONF_MAX_EXCEEDED,
Script, Script,
action_path,
action_trace_clear,
action_trace_get,
) )
from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.trace import TraceElement, trace_get, trace_path
from homeassistant.helpers.trigger import async_initialize_triggers from homeassistant.helpers.trigger import async_initialize_triggers
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -235,44 +240,55 @@ async def async_setup(hass, config):
class AutomationTrace: class AutomationTrace:
"""Container for automation trace.""" """Container for automation trace."""
def __init__(self, unique_id, config, trigger, context, action_trace): def __init__(
self,
unique_id: Optional[str],
config: Dict[str, Any],
trigger: Dict[str, Any],
context: Context,
):
"""Container for automation trace.""" """Container for automation trace."""
self._action_trace = action_trace self._action_trace: Optional[Dict[str, Deque[TraceElement]]] = None
self._condition_trace = None self._condition_trace: Optional[Dict[str, Deque[TraceElement]]] = None
self._config = config self._config: Dict[str, Any] = config
self._context = context self._context: Context = context
self._error = None self._error: Optional[Exception] = None
self._state = "running" self._state: str = "running"
self._timestamp_finish = None self._timestamp_finish: Optional[dt.datetime] = None
self._timestamp_start = dt_util.utcnow() self._timestamp_start: dt.datetime = dt_util.utcnow()
self._trigger = trigger self._trigger: Dict[str, Any] = trigger
self._unique_id = unique_id self._unique_id: Optional[str] = unique_id
self._variables = None self._variables: Optional[Dict[str, Any]] = None
def set_error(self, ex): def set_action_trace(self, trace: Dict[str, Deque[TraceElement]]) -> None:
"""Set action trace."""
self._action_trace = trace
def set_condition_trace(self, trace: Dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def set_error(self, ex: Exception) -> None:
"""Set error.""" """Set error."""
self._error = ex self._error = ex
def set_variables(self, variables): def set_variables(self, variables: Dict[str, Any]) -> None:
"""Set variables.""" """Set variables."""
self._variables = variables self._variables = variables
def set_condition_trace(self, condition_trace): def finished(self) -> None:
"""Set condition trace."""
self._condition_trace = condition_trace
def finished(self):
"""Set finish time.""" """Set finish time."""
self._timestamp_finish = dt_util.utcnow() self._timestamp_finish = dt_util.utcnow()
self._state = "stopped" self._state = "stopped"
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this AutomationTrace.""" """Return dictionary version of this AutomationTrace."""
action_traces = {} action_traces = {}
condition_traces = {} condition_traces = {}
for key, trace_list in self._action_trace.items(): if self._action_trace:
action_traces[key] = [item.as_dict() for item in trace_list] for key, trace_list in self._action_trace.items():
action_traces[key] = [item.as_dict() for item in trace_list]
if self._condition_trace: if self._condition_trace:
for key, trace_list in self._condition_trace.items(): for key, trace_list in self._condition_trace.items():
@ -300,11 +316,7 @@ class AutomationTrace:
@contextmanager @contextmanager
def trace_automation(hass, unique_id, config, trigger, context): def trace_automation(hass, unique_id, config, trigger, context):
"""Trace action execution of automation with automation_id.""" """Trace action execution of automation with automation_id."""
action_trace_clear() automation_trace = AutomationTrace(unique_id, config, trigger, context)
action_trace = action_trace_get()
automation_trace = AutomationTrace(
unique_id, config, trigger, context, action_trace
)
if unique_id: if unique_id:
if unique_id not in hass.data[DATA_AUTOMATION_TRACE]: if unique_id not in hass.data[DATA_AUTOMATION_TRACE]:
@ -325,7 +337,7 @@ def trace_automation(hass, unique_id, config, trigger, context):
"Automation finished. Summary:\n\ttrigger: %s\n\tcondition: %s\n\taction: %s", "Automation finished. Summary:\n\ttrigger: %s\n\tcondition: %s\n\taction: %s",
automation_trace._trigger, # pylint: disable=protected-access automation_trace._trigger, # pylint: disable=protected-access
automation_trace._condition_trace, # pylint: disable=protected-access automation_trace._condition_trace, # pylint: disable=protected-access
action_trace, automation_trace._action_trace, # pylint: disable=protected-access
) )
@ -510,6 +522,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
variables = run_variables variables = run_variables
automation_trace.set_variables(variables) automation_trace.set_variables(variables)
# Prepare tracing the evaluation of the automation's conditions
automation_trace.set_condition_trace(trace_get())
if ( if (
not skip_condition not skip_condition
and self._cond_func is not None and self._cond_func is not None
@ -517,12 +532,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
): ):
self._logger.debug( self._logger.debug(
"Conditions not met, aborting automation. Condition summary: %s", "Conditions not met, aborting automation. Condition summary: %s",
condition_trace_get(), trace_get(clear=False),
) )
automation_trace.set_condition_trace(condition_trace_get())
return return
automation_trace.set_condition_trace(condition_trace_get())
condition_trace_clear() # Prepare tracing the execution of the automation's actions
automation_trace.set_action_trace(trace_get())
# Create a new context referring to the old context. # Create a new context referring to the old context.
parent_id = None if context is None else context.id parent_id = None if context is None else context.id
@ -543,7 +558,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
) )
try: try:
with action_path("action"): with trace_path("action"):
await self.action_script.async_run( await self.action_script.async_run(
variables, trigger_context, started_action variables, trigger_context, started_action
) )
@ -763,7 +778,7 @@ async def _async_process_if(hass, name, config, p_config):
errors = [] errors = []
for index, check in enumerate(checks): for index, check in enumerate(checks):
try: try:
with condition_path(["condition", str(index)]): with trace_path(["condition", str(index)]):
if not check(hass, variables): if not check(hass, variables):
return False return False
except ConditionError as ex: except ConditionError as ex:

View File

@ -2,24 +2,12 @@
import asyncio import asyncio
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import re import re
import sys import sys
from typing import ( from typing import Any, Callable, Container, Generator, List, Optional, Set, Union, cast
Any,
Callable,
Container,
Dict,
Generator,
List,
Optional,
Set,
Union,
cast,
)
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import ( from homeassistant.components.device_automation import (
@ -67,6 +55,9 @@ import homeassistant.util.dt as dt_util
from .trace import ( from .trace import (
TraceElement, TraceElement,
trace_append_element, trace_append_element,
trace_path,
trace_path_get,
trace_stack_cv,
trace_stack_pop, trace_stack_pop,
trace_stack_push, trace_stack_push,
trace_stack_top, trace_stack_top,
@ -84,79 +75,16 @@ INPUT_ENTITY_ID = re.compile(
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool] ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
# Context variables for tracing
# Trace of condition being evaluated
condition_trace = ContextVar("condition_trace", default=None)
# Stack of TraceElements
condition_trace_stack: ContextVar[Optional[List[TraceElement]]] = ContextVar(
"condition_trace_stack", default=None
)
# Current location in config tree
condition_path_stack: ContextVar[Optional[List[str]]] = ContextVar(
"condition_path_stack", default=None
)
def condition_trace_stack_push(node: TraceElement) -> None:
"""Push a TraceElement to the top of the trace stack."""
trace_stack_push(condition_trace_stack, node)
def condition_trace_stack_pop() -> None:
"""Remove the top element from the trace stack."""
trace_stack_pop(condition_trace_stack)
def condition_trace_stack_top() -> Optional[TraceElement]:
"""Return the element at the top of the trace stack."""
return cast(Optional[TraceElement], trace_stack_top(condition_trace_stack))
def condition_path_push(suffix: Union[str, List[str]]) -> int:
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(condition_path_stack, node)
return len(suffix)
def condition_path_pop(count: int) -> None:
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(condition_path_stack)
def condition_path_get() -> str:
"""Return a string representing the current location in the config tree."""
path = condition_path_stack.get()
if not path:
return ""
return "/".join(path)
def condition_trace_get() -> Optional[Dict[str, TraceElement]]:
"""Return the trace of the condition that was evaluated."""
return condition_trace.get()
def condition_trace_clear() -> None:
"""Clear the condition trace."""
condition_trace.set(None)
condition_trace_stack.set(None)
condition_path_stack.set(None)
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement: def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables) trace_element = TraceElement(variables)
trace_append_element(condition_trace, trace_element, path) trace_append_element(trace_element, path)
return trace_element return trace_element
def condition_trace_set_result(result: bool, **kwargs: Any) -> None: def condition_trace_set_result(result: bool, **kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack.""" """Set the result of TraceElement at the top of the stack."""
node = condition_trace_stack_top() node = trace_stack_top(trace_stack_cv)
# The condition function may be called directly, in which case tracing # The condition function may be called directly, in which case tracing
# is not setup # is not setup
@ -169,25 +97,15 @@ def condition_trace_set_result(result: bool, **kwargs: Any) -> None:
@contextmanager @contextmanager
def trace_condition(variables: TemplateVarsType) -> Generator: def trace_condition(variables: TemplateVarsType) -> Generator:
"""Trace condition evaluation.""" """Trace condition evaluation."""
trace_element = condition_trace_append(variables, condition_path_get()) trace_element = condition_trace_append(variables, trace_path_get())
condition_trace_stack_push(trace_element) trace_stack_push(trace_stack_cv, trace_element)
try: try:
yield trace_element yield trace_element
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
trace_element.set_error(ex) trace_element.set_error(ex)
raise ex raise ex
finally: finally:
condition_trace_stack_pop() trace_stack_pop(trace_stack_cv)
@contextmanager
def condition_path(suffix: Union[str, List[str]]) -> Generator:
"""Go deeper in the config tree."""
count = condition_path_push(suffix)
try:
yield
finally:
condition_path_pop(count)
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType: def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
@ -260,7 +178,7 @@ async def async_and_from_config(
errors = [] errors = []
for index, check in enumerate(checks): for index, check in enumerate(checks):
try: try:
with condition_path(["conditions", str(index)]): with trace_path(["conditions", str(index)]):
if not check(hass, variables): if not check(hass, variables):
return False return False
except ConditionError as ex: except ConditionError as ex:
@ -295,7 +213,7 @@ async def async_or_from_config(
errors = [] errors = []
for index, check in enumerate(checks): for index, check in enumerate(checks):
try: try:
with condition_path(["conditions", str(index)]): with trace_path(["conditions", str(index)]):
if check(hass, variables): if check(hass, variables):
return True return True
except ConditionError as ex: except ConditionError as ex:
@ -330,7 +248,7 @@ async def async_not_from_config(
errors = [] errors = []
for index, check in enumerate(checks): for index, check in enumerate(checks):
try: try:
with condition_path(["conditions", str(index)]): with trace_path(["conditions", str(index)]):
if check(hass, variables): if check(hass, variables):
return False return False
except ConditionError as ex: except ConditionError as ex:
@ -509,9 +427,7 @@ def async_numeric_state_from_config(
errors = [] errors = []
for index, entity_id in enumerate(entity_ids): for index, entity_id in enumerate(entity_ids):
try: try:
with condition_path(["entity_id", str(index)]), trace_condition( with trace_path(["entity_id", str(index)]), trace_condition(variables):
variables
):
if not async_numeric_state( if not async_numeric_state(
hass, hass,
entity_id, entity_id,
@ -623,9 +539,7 @@ def state_from_config(
errors = [] errors = []
for index, entity_id in enumerate(entity_ids): for index, entity_id in enumerate(entity_ids):
try: try:
with condition_path(["entity_id", str(index)]), trace_condition( with trace_path(["entity_id", str(index)]), trace_condition(variables):
variables
):
if not state(hass, entity_id, req_states, for_period, attribute): if not state(hass, entity_id, req_states, for_period, attribute):
return False return False
except ConditionError as ex: except ConditionError as ex:

View File

@ -1,7 +1,6 @@
"""Helpers to execute scripts.""" """Helpers to execute scripts."""
import asyncio import asyncio
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import itertools import itertools
@ -65,12 +64,7 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.helpers import condition, config_validation as cv, service, template from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import trace_condition_function
condition_path,
condition_trace_clear,
condition_trace_get,
trace_condition_function,
)
from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
@ -84,9 +78,12 @@ from homeassistant.util.dt import utcnow
from .trace import ( from .trace import (
TraceElement, TraceElement,
trace_append_element, trace_append_element,
trace_path,
trace_path_get,
trace_set_result,
trace_stack_cv,
trace_stack_pop, trace_stack_pop,
trace_stack_push, trace_stack_push,
trace_stack_top,
) )
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
@ -125,111 +122,26 @@ _SHUTDOWN_MAX_WAIT = 60
ACTION_TRACE_NODE_MAX_LEN = 20 # Max the length of a trace node for repeated actions ACTION_TRACE_NODE_MAX_LEN = 20 # Max the length of a trace node for repeated actions
action_trace = ContextVar("action_trace", default=None)
action_trace_stack = ContextVar("action_trace_stack", default=None)
action_path_stack = ContextVar("action_path_stack", default=None)
def action_trace_stack_push(node):
"""Push a TraceElement to the top of the trace stack."""
trace_stack_push(action_trace_stack, node)
def action_trace_stack_pop():
"""Remove the top element from the trace stack."""
trace_stack_pop(action_trace_stack)
def action_trace_stack_top():
"""Return the element at the top of the trace stack."""
return trace_stack_top(action_trace_stack)
def action_path_push(suffix):
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(action_path_stack, node)
return len(suffix)
def action_path_pop(count):
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(action_path_stack)
def action_path_get():
"""Return a string representing the current location in the config tree."""
path = action_path_stack.get()
if not path:
return ""
return "/".join(path)
def action_trace_get():
"""Return the trace of the script that was executed."""
return action_trace.get()
def action_trace_clear():
"""Clear the action trace."""
action_trace.set({})
action_trace_stack.set(None)
action_path_stack.set(None)
def action_trace_append(variables, path): def action_trace_append(variables, path):
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables) trace_element = TraceElement(variables)
trace_append_element(action_trace, trace_element, path, ACTION_TRACE_NODE_MAX_LEN) trace_append_element(trace_element, path, ACTION_TRACE_NODE_MAX_LEN)
return trace_element return trace_element
def action_trace_set_result(**kwargs):
"""Set the result of TraceElement at the top of the stack."""
node = action_trace_stack_top()
node.set_result(**kwargs)
def action_trace_add_conditions():
"""Add the result of condition evaluation to the action trace."""
condition_trace = condition_trace_get()
condition_trace_clear()
if condition_trace is None:
return
_action_path = action_path_get()
for cond_path, conditions in condition_trace.items():
path = _action_path + "/" + cond_path if cond_path else _action_path
for cond in conditions:
trace_append_element(action_trace, cond, path)
@contextmanager @contextmanager
def trace_action(variables): def trace_action(variables):
"""Trace action execution.""" """Trace action execution."""
trace_element = action_trace_append(variables, action_path_get()) trace_element = action_trace_append(variables, trace_path_get())
action_trace_stack_push(trace_element) trace_stack_push(trace_stack_cv, trace_element)
try: try:
yield trace_element yield trace_element
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
trace_element.set_error(ex) trace_element.set_error(ex)
raise ex raise ex
finally: finally:
action_trace_stack_pop() trace_stack_pop(trace_stack_cv)
@contextmanager
def action_path(suffix):
"""Go deeper in the config tree."""
count = action_path_push(suffix)
try:
yield
finally:
action_path_pop(count)
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA): def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
@ -382,7 +294,7 @@ class _ScriptRun:
self._finish() self._finish()
async def _async_step(self, log_exceptions): async def _async_step(self, log_exceptions):
with action_path(str(self._step)), trace_action(None): with trace_path(str(self._step)), trace_action(None):
try: try:
handler = f"_async_{cv.determine_script_action(self._action)}_step" handler = f"_async_{cv.determine_script_action(self._action)}_step"
await getattr(self, handler)() await getattr(self, handler)()
@ -638,15 +550,14 @@ class _ScriptRun:
) )
cond = await self._async_get_condition(self._action) cond = await self._async_get_condition(self._action)
try: try:
with condition_path("condition"): with trace_path("condition"):
check = cond(self._hass, self._variables) check = cond(self._hass, self._variables)
except exceptions.ConditionError as ex: except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'condition' evaluation:\n%s", ex) _LOGGER.warning("Error in 'condition' evaluation:\n%s", ex)
check = False check = False
self._log("Test condition %s: %s", self._script.last_action, check) self._log("Test condition %s: %s", self._script.last_action, check)
action_trace_set_result(result=check) trace_set_result(result=check)
action_trace_add_conditions()
if not check: if not check:
raise _StopScript raise _StopScript
@ -654,9 +565,9 @@ class _ScriptRun:
@trace_condition_function @trace_condition_function
def traced_test_conditions(hass, variables): def traced_test_conditions(hass, variables):
try: try:
with condition_path("conditions"): with trace_path("conditions"):
for idx, cond in enumerate(conditions): for idx, cond in enumerate(conditions):
with condition_path(str(idx)): with trace_path(str(idx)):
if not cond(hass, variables): if not cond(hass, variables):
return False return False
except exceptions.ConditionError as ex: except exceptions.ConditionError as ex:
@ -666,7 +577,6 @@ class _ScriptRun:
return True return True
result = traced_test_conditions(self._hass, self._variables) result = traced_test_conditions(self._hass, self._variables)
action_trace_add_conditions()
return result return result
async def _async_repeat_step(self): async def _async_repeat_step(self):
@ -687,7 +597,7 @@ class _ScriptRun:
async def async_run_sequence(iteration, extra_msg=""): async def async_run_sequence(iteration, extra_msg=""):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
with action_path(str(self._step)): with trace_path(str(self._step)):
await self._async_run_script(script) await self._async_run_script(script)
if CONF_COUNT in repeat: if CONF_COUNT in repeat:
@ -754,18 +664,18 @@ class _ScriptRun:
choose_data = await self._script._async_get_choose_data(self._step) choose_data = await self._script._async_get_choose_data(self._step)
for idx, (conditions, script) in enumerate(choose_data["choices"]): for idx, (conditions, script) in enumerate(choose_data["choices"]):
with action_path(str(idx)): with trace_path(str(idx)):
try: try:
if self._test_conditions(conditions, "choose"): if self._test_conditions(conditions, "choose"):
action_trace_set_result(choice=idx) trace_set_result(choice=idx)
await self._async_run_script(script) await self._async_run_script(script)
return return
except exceptions.ConditionError as ex: except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex) _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
if choose_data["default"]: if choose_data["default"]:
action_trace_set_result(choice="default") trace_set_result(choice="default")
with action_path("default"): with trace_path("default"):
await self._async_run_script(choose_data["default"]) await self._async_run_script(choose_data["default"])
async def _async_wait_for_trigger_step(self): async def _async_wait_for_trigger_step(self):

View File

@ -1,33 +1,13 @@
"""Helpers for script and condition tracing.""" """Helpers for script and condition tracing."""
from collections import deque from collections import deque
from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, Optional from typing import Any, Deque, Dict, Generator, List, Optional, Union, cast
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get()
if trace_stack is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
"""Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get()
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
class TraceElement: class TraceElement:
"""Container for trace data.""" """Container for trace data."""
@ -62,17 +42,105 @@ class TraceElement:
return result return result
# Context variables for tracing
# Current trace
trace_cv: ContextVar[Optional[Dict[str, Deque[TraceElement]]]] = ContextVar(
"trace_cv", default=None
)
# Stack of TraceElements
trace_stack_cv: ContextVar[Optional[List[TraceElement]]] = ContextVar(
"trace_stack_cv", default=None
)
# Current location in config tree
trace_path_stack_cv: ContextVar[Optional[List[str]]] = ContextVar(
"trace_path_stack_cv", default=None
)
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get()
if trace_stack is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
"""Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get()
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
def trace_path_push(suffix: Union[str, List[str]]) -> int:
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
for node in suffix:
trace_stack_push(trace_path_stack_cv, node)
return len(suffix)
def trace_path_pop(count: int) -> None:
"""Go n levels up in the config tree."""
for _ in range(count):
trace_stack_pop(trace_path_stack_cv)
def trace_path_get() -> str:
"""Return a string representing the current location in the config tree."""
path = trace_path_stack_cv.get()
if not path:
return ""
return "/".join(path)
def trace_append_element( def trace_append_element(
trace_var: ContextVar,
trace_element: TraceElement, trace_element: TraceElement,
path: str, path: str,
maxlen: Optional[int] = None, maxlen: Optional[int] = None,
) -> None: ) -> None:
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
trace = trace_var.get() trace = trace_cv.get()
if trace is None: if trace is None:
trace_var.set({}) trace = {}
trace = trace_var.get() trace_cv.set(trace)
if path not in trace: if path not in trace:
trace[path] = deque(maxlen=maxlen) trace[path] = deque(maxlen=maxlen)
trace[path].append(trace_element) trace[path].append(trace_element)
def trace_get(clear: bool = True) -> Optional[Dict[str, Deque[TraceElement]]]:
"""Return the current trace."""
if clear:
trace_clear()
return trace_cv.get()
def trace_clear() -> None:
"""Clear the trace."""
trace_cv.set({})
trace_stack_cv.set(None)
trace_path_stack_cv.set(None)
def trace_set_result(**kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
node.set_result(**kwargs)
@contextmanager
def trace_path(suffix: Union[str, List[str]]) -> Generator:
"""Go deeper in the config tree."""
count = trace_path_push(suffix)
try:
yield
finally:
trace_path_pop(count)

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.exceptions import ConditionError, HomeAssistantError from homeassistant.exceptions import ConditionError, HomeAssistantError
from homeassistant.helpers import condition from homeassistant.helpers import condition, trace
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt from homeassistant.util import dt
@ -25,8 +25,8 @@ def assert_element(trace_element, expected_element, path):
def assert_condition_trace(expected): def assert_condition_trace(expected):
"""Assert a trace condition sequence is as expected.""" """Assert a trace condition sequence is as expected."""
condition_trace = condition.condition_trace_get() condition_trace = trace.trace_get(clear=False)
condition.condition_trace_clear() trace.trace_clear()
expected_trace_keys = list(expected.keys()) expected_trace_keys = list(expected.keys())
assert list(condition_trace.keys()) == expected_trace_keys assert list(condition_trace.keys()) == expected_trace_keys
for trace_key_index, key in enumerate(expected_trace_keys): for trace_key_index, key in enumerate(expected_trace_keys):

View File

@ -17,7 +17,7 @@ from homeassistant import exceptions
import homeassistant.components.scene as scene import homeassistant.components.scene as scene
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import Context, CoreState, callback from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import config_validation as cv, script from homeassistant.helpers import config_validation as cv, script, trace
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -45,8 +45,8 @@ def assert_element(trace_element, expected_element, path):
def assert_action_trace(expected): def assert_action_trace(expected):
"""Assert a trace condition sequence is as expected.""" """Assert a trace condition sequence is as expected."""
action_trace = script.action_trace_get() action_trace = trace.trace_get(clear=False)
script.action_trace_clear() trace.trace_clear()
expected_trace_keys = list(expected.keys()) expected_trace_keys = list(expected.keys())
assert list(action_trace.keys()) == expected_trace_keys assert list(action_trace.keys()) == expected_trace_keys
for trace_key_index, key in enumerate(expected_trace_keys): for trace_key_index, key in enumerate(expected_trace_keys):