diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index b9acfc0dde9..6410954191d 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -68,7 +68,12 @@ from homeassistant.helpers.script import ( ) from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.service import async_register_admin_service -from homeassistant.helpers.trace import TraceElement, trace_get, trace_path +from homeassistant.helpers.trace import ( + TraceElement, + trace_get, + trace_id_set, + trace_path, +) from homeassistant.helpers.trigger import async_initialize_triggers from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass @@ -374,6 +379,7 @@ class LimitedSizeDict(OrderedDict): def trace_automation(hass, unique_id, config, trigger, context): """Trace action execution of automation with automation_id.""" automation_trace = AutomationTrace(unique_id, config, trigger, context) + trace_id_set((unique_id, automation_trace.runid)) if unique_id: automation_traces = hass.data[DATA_AUTOMATION_TRACE] diff --git a/homeassistant/components/config/automation.py b/homeassistant/components/config/automation.py index 708ad55aaeb..b5aa1bf7af5 100644 --- a/homeassistant/components/config/automation.py +++ b/homeassistant/components/config/automation.py @@ -16,7 +16,25 @@ from homeassistant.components.automation.config import ( ) from homeassistant.config import AUTOMATION_CONFIG_PATH from homeassistant.const import CONF_ID, SERVICE_RELOAD +from homeassistant.core import callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, entity_registry +from homeassistant.helpers.dispatcher import ( + DATA_DISPATCHER, + async_dispatcher_connect, + async_dispatcher_send, +) +from homeassistant.helpers.script import ( + SCRIPT_BREAKPOINT_HIT, + SCRIPT_DEBUG_CONTINUE_ALL, + breakpoint_clear, + breakpoint_clear_all, + breakpoint_list, + breakpoint_set, + debug_continue, + debug_step, + debug_stop, +) from . import ACTION_DELETE, EditIdBasedConfigView @@ -26,6 +44,13 @@ async def async_setup(hass): websocket_api.async_register_command(hass, websocket_automation_trace_get) websocket_api.async_register_command(hass, websocket_automation_trace_list) + websocket_api.async_register_command(hass, websocket_automation_breakpoint_clear) + websocket_api.async_register_command(hass, websocket_automation_breakpoint_list) + websocket_api.async_register_command(hass, websocket_automation_breakpoint_set) + websocket_api.async_register_command(hass, websocket_automation_debug_continue) + websocket_api.async_register_command(hass, websocket_automation_debug_step) + websocket_api.async_register_command(hass, websocket_automation_debug_stop) + websocket_api.async_register_command(hass, websocket_subscribe_breakpoint_events) async def hook(action, config_key): """post_write_hook for Config View that reloads automations.""" @@ -92,11 +117,12 @@ class EditAutomationConfigView(EditIdBasedConfigView): data[index] = updated_value +@callback +@websocket_api.require_admin @websocket_api.websocket_command( {vol.Required("type"): "automation/trace/get", vol.Optional("automation_id"): str} ) -@websocket_api.async_response -async def websocket_automation_trace_get(hass, connection, msg): +def websocket_automation_trace_get(hass, connection, msg): """Get automation traces.""" automation_id = msg.get("automation_id") @@ -110,10 +136,171 @@ async def websocket_automation_trace_get(hass, connection, msg): connection.send_result(msg["id"], automation_traces) +@callback +@websocket_api.require_admin @websocket_api.websocket_command({vol.Required("type"): "automation/trace/list"}) -@websocket_api.async_response -async def websocket_automation_trace_list(hass, connection, msg): +def websocket_automation_trace_list(hass, connection, msg): """Summarize automation traces.""" automation_traces = get_debug_traces(hass, summary=True) connection.send_result(msg["id"], automation_traces) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "automation/debug/breakpoint/set", + vol.Required("automation_id"): str, + vol.Required("node"): str, + vol.Optional("run_id"): str, + } +) +def websocket_automation_breakpoint_set(hass, connection, msg): + """Set breakpoint.""" + automation_id = msg["automation_id"] + node = msg["node"] + run_id = msg.get("run_id") + + if ( + SCRIPT_BREAKPOINT_HIT not in hass.data.get(DATA_DISPATCHER, {}) + or not hass.data[DATA_DISPATCHER][SCRIPT_BREAKPOINT_HIT] + ): + raise HomeAssistantError("No breakpoint subscription") + + result = breakpoint_set(hass, automation_id, run_id, node) + connection.send_result(msg["id"], result) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "automation/debug/breakpoint/clear", + vol.Required("automation_id"): str, + vol.Required("node"): str, + vol.Optional("run_id"): str, + } +) +def websocket_automation_breakpoint_clear(hass, connection, msg): + """Clear breakpoint.""" + automation_id = msg["automation_id"] + node = msg["node"] + run_id = msg.get("run_id") + + result = breakpoint_clear(hass, automation_id, run_id, node) + + connection.send_result(msg["id"], result) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + {vol.Required("type"): "automation/debug/breakpoint/list"} +) +def websocket_automation_breakpoint_list(hass, connection, msg): + """List breakpoints.""" + breakpoints = breakpoint_list(hass) + for _breakpoint in breakpoints: + _breakpoint["automation_id"] = _breakpoint.pop("unique_id") + + connection.send_result(msg["id"], breakpoints) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + {vol.Required("type"): "automation/debug/breakpoint/subscribe"} +) +def websocket_subscribe_breakpoint_events(hass, connection, msg): + """Subscribe to breakpoint events.""" + + @callback + def breakpoint_hit(automation_id, run_id, node): + """Forward events to websocket.""" + connection.send_message( + websocket_api.event_message( + msg["id"], + { + "automation_id": automation_id, + "run_id": run_id, + "node": node, + }, + ) + ) + + remove_signal = async_dispatcher_connect( + hass, SCRIPT_BREAKPOINT_HIT, breakpoint_hit + ) + + @callback + def unsub(): + """Unsubscribe from breakpoint events.""" + remove_signal() + if ( + SCRIPT_BREAKPOINT_HIT not in hass.data.get(DATA_DISPATCHER, {}) + or not hass.data[DATA_DISPATCHER][SCRIPT_BREAKPOINT_HIT] + ): + breakpoint_clear_all(hass) + async_dispatcher_send(hass, SCRIPT_DEBUG_CONTINUE_ALL) + + connection.subscriptions[msg["id"]] = unsub + + connection.send_message(websocket_api.result_message(msg["id"])) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "automation/debug/continue", + vol.Required("automation_id"): str, + vol.Required("run_id"): str, + } +) +def websocket_automation_debug_continue(hass, connection, msg): + """Resume execution of halted automation.""" + automation_id = msg["automation_id"] + run_id = msg["run_id"] + + result = debug_continue(hass, automation_id, run_id) + + connection.send_result(msg["id"], result) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "automation/debug/step", + vol.Required("automation_id"): str, + vol.Required("run_id"): str, + } +) +def websocket_automation_debug_step(hass, connection, msg): + """Single step a halted automation.""" + automation_id = msg["automation_id"] + run_id = msg["run_id"] + + result = debug_step(hass, automation_id, run_id) + + connection.send_result(msg["id"], result) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "automation/debug/stop", + vol.Required("automation_id"): str, + vol.Required("run_id"): str, + } +) +def websocket_automation_debug_stop(hass, connection, msg): + """Stop a halted automation.""" + automation_id = msg["automation_id"] + run_id = msg["run_id"] + + result = debug_stop(hass, automation_id, run_id) + + connection.send_result(msg["id"], result) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index b3fcffd4944..257fd6d9715 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1,6 +1,6 @@ """Helpers to execute scripts.""" import asyncio -from contextlib import contextmanager +from contextlib import asynccontextmanager from datetime import datetime, timedelta from functools import partial import itertools @@ -65,6 +65,10 @@ from homeassistant.core import ( ) from homeassistant.helpers import condition, config_validation as cv, service, template from homeassistant.helpers.condition import trace_condition_function +from homeassistant.helpers.dispatcher import ( + async_dispatcher_connect, + async_dispatcher_send, +) from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.trigger import ( @@ -78,6 +82,7 @@ from homeassistant.util.dt import utcnow from .trace import ( TraceElement, trace_append_element, + trace_id_get, trace_path, trace_path_get, trace_set_result, @@ -111,6 +116,9 @@ ATTR_CUR = "current" ATTR_MAX = "max" DATA_SCRIPTS = "helpers.script" +DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints" +RUN_ID_ANY = "*" +NODE_ANY = "*" _LOGGER = logging.getLogger(__name__) @@ -122,6 +130,10 @@ _SHUTDOWN_MAX_WAIT = 60 ACTION_TRACE_NODE_MAX_LEN = 20 # Max length of a trace node for repeated actions +SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit" +SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}" +SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" + def action_trace_append(variables, path): """Append a TraceElement to trace[path].""" @@ -130,11 +142,57 @@ def action_trace_append(variables, path): return trace_element -@contextmanager -def trace_action(variables): +@asynccontextmanager +async def trace_action(hass, script_run, stop, variables): """Trace action execution.""" - trace_element = action_trace_append(variables, trace_path_get()) + path = trace_path_get() + trace_element = action_trace_append(variables, path) trace_stack_push(trace_stack_cv, trace_element) + + trace_id = trace_id_get() + if trace_id: + unique_id = trace_id[0] + run_id = trace_id[1] + breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] + if unique_id in breakpoints and ( + ( + run_id in breakpoints[unique_id] + and ( + path in breakpoints[unique_id][run_id] + or NODE_ANY in breakpoints[unique_id][run_id] + ) + ) + or ( + RUN_ID_ANY in breakpoints[unique_id] + and ( + path in breakpoints[unique_id][RUN_ID_ANY] + or NODE_ANY in breakpoints[unique_id][RUN_ID_ANY] + ) + ) + ): + async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, unique_id, run_id, path) + + done = asyncio.Event() + + @callback + def async_continue_stop(command=None): + if command == "stop": + stop.set() + done.set() + + signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id) + remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop) + remove_signal2 = async_dispatcher_connect( + hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop + ) + + tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)] + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for task in tasks: + task.cancel() + remove_signal1() + remove_signal2() + try: yield trace_element except Exception as ex: # pylint: disable=broad-except @@ -294,16 +352,19 @@ class _ScriptRun: self._finish() async def _async_step(self, log_exceptions): - with trace_path(str(self._step)), trace_action(self._variables): - try: - handler = f"_async_{cv.determine_script_action(self._action)}_step" - await getattr(self, handler)() - except Exception as ex: - if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and ( - self._log_exceptions or log_exceptions - ): - self._log_exception(ex) - raise + with trace_path(str(self._step)): + async with trace_action(self._hass, self, self._stop, self._variables): + if self._stop.is_set(): + return + try: + handler = f"_async_{cv.determine_script_action(self._action)}_step" + await getattr(self, handler)() + except Exception as ex: + if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and ( + self._log_exceptions or log_exceptions + ): + self._log_exception(ex) + raise def _finish(self) -> None: self._script._runs.remove(self) # pylint: disable=protected-access @@ -876,6 +937,8 @@ class Script: all_scripts.append( {"instance": self, "started_before_shutdown": not hass.is_stopping} ) + if DATA_SCRIPT_BREAKPOINTS not in hass.data: + hass.data[DATA_SCRIPT_BREAKPOINTS] = {} self._hass = hass self.sequence = sequence @@ -1213,3 +1276,71 @@ class Script: self._logger.exception(msg, *args, **kwargs) else: self._logger.log(level, msg, *args, **kwargs) + + +@callback +def breakpoint_clear(hass, unique_id, run_id, node): + """Clear a breakpoint.""" + run_id = run_id or RUN_ID_ANY + breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] + if unique_id not in breakpoints or run_id not in breakpoints[unique_id]: + return + breakpoints[unique_id][run_id].discard(node) + + +@callback +def breakpoint_clear_all(hass): + """Clear all breakpoints.""" + hass.data[DATA_SCRIPT_BREAKPOINTS] = {} + + +@callback +def breakpoint_set(hass, unique_id, run_id, node): + """Set a breakpoint.""" + run_id = run_id or RUN_ID_ANY + breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] + if unique_id not in breakpoints: + breakpoints[unique_id] = {} + if run_id not in breakpoints[unique_id]: + breakpoints[unique_id][run_id] = set() + breakpoints[unique_id][run_id].add(node) + + +@callback +def breakpoint_list(hass): + """List breakpoints.""" + breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] + + return [ + {"unique_id": unique_id, "run_id": run_id, "node": node} + for unique_id in breakpoints + for run_id in breakpoints[unique_id] + for node in breakpoints[unique_id][run_id] + ] + + +@callback +def debug_continue(hass, unique_id, run_id): + """Continue execution of a halted script.""" + # Clear any wildcard breakpoint + breakpoint_clear(hass, unique_id, run_id, NODE_ANY) + + signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id) + async_dispatcher_send(hass, signal, "continue") + + +@callback +def debug_step(hass, unique_id, run_id): + """Single step a halted script.""" + # Set a wildcard breakpoint + breakpoint_set(hass, unique_id, run_id, NODE_ANY) + + signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id) + async_dispatcher_send(hass, signal, "continue") + + +@callback +def debug_stop(hass, unique_id, run_id): + """Stop execution of a running or halted script.""" + signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id) + async_dispatcher_send(hass, signal, "stop") diff --git a/homeassistant/helpers/trace.py b/homeassistant/helpers/trace.py index 0c1969a8ac6..e0c67a1de54 100644 --- a/homeassistant/helpers/trace.py +++ b/homeassistant/helpers/trace.py @@ -2,7 +2,7 @@ from collections import deque from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Deque, Dict, Generator, List, Optional, Union, cast +from typing import Any, Deque, Dict, Generator, List, Optional, Tuple, Union, cast from homeassistant.helpers.typing import TemplateVarsType import homeassistant.util.dt as dt_util @@ -67,6 +67,20 @@ trace_path_stack_cv: ContextVar[Optional[List[str]]] = ContextVar( ) # Copy of last variables variables_cv: ContextVar[Optional[Any]] = ContextVar("variables_cv", default=None) +# Automation ID + Run ID +trace_id_cv: ContextVar[Optional[Tuple[str, str]]] = ContextVar( + "trace_id_cv", default=None +) + + +def trace_id_set(trace_id: Tuple[str, str]) -> None: + """Set id of the current trace.""" + trace_id_cv.set(trace_id) + + +def trace_id_get() -> Optional[Tuple[str, str]]: + """Get id if the current trace.""" + return trace_id_cv.get() def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None: diff --git a/tests/components/config/test_automation.py b/tests/components/config/test_automation.py index d52295c75f7..2880287be94 100644 --- a/tests/components/config/test_automation.py +++ b/tests/components/config/test_automation.py @@ -6,6 +6,7 @@ from homeassistant.bootstrap import async_setup_component from homeassistant.components import automation, config from homeassistant.helpers import entity_registry as er +from tests.common import assert_lists_same from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 @@ -511,3 +512,426 @@ async def test_list_automation_traces(hass, hass_ws_client): assert trace["timestamp"] assert trace["trigger"] == "event 'test_event2'" assert trace["unique_id"] == "moon" + + +async def test_automation_breakpoints(hass, hass_ws_client): + """Test automation breakpoints.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + async def assert_last_action(automation_id, expected_action, expected_state): + await client.send_json({"id": next_id(), "type": "automation/trace/list"}) + response = await client.receive_json() + assert response["success"] + trace = response["result"][automation_id][-1] + assert trace["last_action"] == expected_action + assert trace["state"] == expected_state + return trace["run_id"] + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + {"event": "event0"}, + {"event": "event1"}, + {"event": "event2"}, + {"event": "event3"}, + {"event": "event4"}, + {"event": "event5"}, + {"event": "event6"}, + {"event": "event7"}, + {"event": "event8"}, + ], + } + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + sun_config, + ] + }, + ) + + with patch.object(config, "SECTIONS", ["automation"]): + await async_setup_component(hass, "config", {}) + + client = await hass_ws_client() + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "1", + } + ) + response = await client.receive_json() + assert not response["success"] + + await client.send_json( + {"id": next_id(), "type": "automation/debug/breakpoint/list"} + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [] + + subscription_id = next_id() + await client.send_json( + {"id": subscription_id, "type": "automation/debug/breakpoint/subscribe"} + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "action/1", + } + ) + response = await client.receive_json() + assert response["success"] + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "action/5", + } + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + {"id": next_id(), "type": "automation/debug/breakpoint/list"} + ) + response = await client.receive_json() + assert response["success"] + assert_lists_same( + response["result"], + [ + {"node": "action/1", "run_id": "*", "automation_id": "sun"}, + {"node": "action/5", "run_id": "*", "automation_id": "sun"}, + ], + ) + + # Trigger "sun" automation + hass.bus.async_fire("test_event") + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/1", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/1", + "run_id": run_id, + } + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/step", + "automation_id": "sun", + "run_id": run_id, + } + ) + response = await client.receive_json() + assert response["success"] + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/2", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/2", + "run_id": run_id, + } + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/continue", + "automation_id": "sun", + "run_id": run_id, + } + ) + response = await client.receive_json() + assert response["success"] + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/5", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/5", + "run_id": run_id, + } + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/stop", + "automation_id": "sun", + "run_id": run_id, + } + ) + response = await client.receive_json() + assert response["success"] + await hass.async_block_till_done() + await assert_last_action("sun", "action/5", "stopped") + + +async def test_automation_breakpoints_2(hass, hass_ws_client): + """Test execution resumes and breakpoints are removed after subscription removed.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + async def assert_last_action(automation_id, expected_action, expected_state): + await client.send_json({"id": next_id(), "type": "automation/trace/list"}) + response = await client.receive_json() + assert response["success"] + trace = response["result"][automation_id][-1] + assert trace["last_action"] == expected_action + assert trace["state"] == expected_state + return trace["run_id"] + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + {"event": "event0"}, + {"event": "event1"}, + {"event": "event2"}, + {"event": "event3"}, + {"event": "event4"}, + {"event": "event5"}, + {"event": "event6"}, + {"event": "event7"}, + {"event": "event8"}, + ], + } + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + sun_config, + ] + }, + ) + + with patch.object(config, "SECTIONS", ["automation"]): + await async_setup_component(hass, "config", {}) + + client = await hass_ws_client() + + subscription_id = next_id() + await client.send_json( + {"id": subscription_id, "type": "automation/debug/breakpoint/subscribe"} + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "action/1", + } + ) + response = await client.receive_json() + assert response["success"] + + # Trigger "sun" automation + hass.bus.async_fire("test_event") + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/1", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/1", + "run_id": run_id, + } + + # Unsubscribe - execution should resume + await client.send_json( + {"id": next_id(), "type": "unsubscribe_events", "subscription": subscription_id} + ) + response = await client.receive_json() + assert response["success"] + await hass.async_block_till_done() + await assert_last_action("sun", "action/8", "stopped") + + # Should not be possible to set breakpoints + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "1", + } + ) + response = await client.receive_json() + assert not response["success"] + + # Trigger "sun" automation, should finish without stopping on breakpoints + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + new_run_id = await assert_last_action("sun", "action/8", "stopped") + assert new_run_id != run_id + + +async def test_automation_breakpoints_3(hass, hass_ws_client): + """Test breakpoints can be cleared.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + async def assert_last_action(automation_id, expected_action, expected_state): + await client.send_json({"id": next_id(), "type": "automation/trace/list"}) + response = await client.receive_json() + assert response["success"] + trace = response["result"][automation_id][-1] + assert trace["last_action"] == expected_action + assert trace["state"] == expected_state + return trace["run_id"] + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + {"event": "event0"}, + {"event": "event1"}, + {"event": "event2"}, + {"event": "event3"}, + {"event": "event4"}, + {"event": "event5"}, + {"event": "event6"}, + {"event": "event7"}, + {"event": "event8"}, + ], + } + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + sun_config, + ] + }, + ) + + with patch.object(config, "SECTIONS", ["automation"]): + await async_setup_component(hass, "config", {}) + + client = await hass_ws_client() + + subscription_id = next_id() + await client.send_json( + {"id": subscription_id, "type": "automation/debug/breakpoint/subscribe"} + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "action/1", + } + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/set", + "automation_id": "sun", + "node": "action/5", + } + ) + response = await client.receive_json() + assert response["success"] + + # Trigger "sun" automation + hass.bus.async_fire("test_event") + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/1", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/1", + "run_id": run_id, + } + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/continue", + "automation_id": "sun", + "run_id": run_id, + } + ) + response = await client.receive_json() + assert response["success"] + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/5", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/5", + "run_id": run_id, + } + + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/stop", + "automation_id": "sun", + "run_id": run_id, + } + ) + response = await client.receive_json() + assert response["success"] + await hass.async_block_till_done() + await assert_last_action("sun", "action/5", "stopped") + + # Clear 1st breakpoint + await client.send_json( + { + "id": next_id(), + "type": "automation/debug/breakpoint/clear", + "automation_id": "sun", + "node": "action/1", + } + ) + response = await client.receive_json() + assert response["success"] + + # Trigger "sun" automation + hass.bus.async_fire("test_event") + + response = await client.receive_json() + run_id = await assert_last_action("sun", "action/5", "running") + assert response["event"] == { + "automation_id": "sun", + "node": "action/5", + "run_id": run_id, + } diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 18c769fee73..7cb4b627a94 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -18,6 +18,7 @@ import homeassistant.components.scene as scene from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.core import Context, CoreState, callback from homeassistant.helpers import config_validation as cv, script, trace +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -80,14 +81,17 @@ async def test_firing_event_basic(hass, caplog): sequence = cv.SCRIPT_SCHEMA( {"alias": alias, "event": event, "event_data": {"hello": "world"}} ) - with script.trace_action(None): - script_obj = script.Script( - hass, - sequence, - "Test Name", - "test_domain", - running_description="test script", - ) + + # Prepare tracing + trace.trace_get() + + script_obj = script.Script( + hass, + sequence, + "Test Name", + "test_domain", + running_description="test script", + ) await script_obj.async_run(context=context) await hass.async_block_till_done() @@ -100,7 +104,6 @@ async def test_firing_event_basic(hass, caplog): assert f"Executing step {alias}" in caplog.text assert_action_trace( { - "": [{}], "0": [{}], } ) @@ -1215,8 +1218,11 @@ async def test_repeat_count(hass, caplog, count): }, } ) - with script.trace_action(None): - script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + # Prepare tracing + trace.trace_get() + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") await script_obj.async_run(context=Context()) await hass.async_block_till_done() @@ -1229,7 +1235,6 @@ async def test_repeat_count(hass, caplog, count): assert caplog.text.count(f"Repeating {alias}") == count assert_action_trace( { - "": [{}], "0": [{}], "0/0/0": [{}] * min(count, script.ACTION_TRACE_NODE_MAX_LEN), } @@ -2348,3 +2353,165 @@ async def test_embedded_wait_for_trigger_in_automation(hass): await hass.async_block_till_done() assert len(mock_calls) == 1 + + +async def test_breakpoints_1(hass): + """Test setting a breakpoint halts execution, and execution can be resumed.""" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 0}}, # Node "0" + {"event": event, "event_data": {"value": 1}}, # Node "1" + {"event": event, "event_data": {"value": 2}}, # Node "2" + {"event": event, "event_data": {"value": 3}}, # Node "3" + {"event": event, "event_data": {"value": 4}}, # Node "4" + {"event": event, "event_data": {"value": 5}}, # Node "5" + {"event": event, "event_data": {"value": 6}}, # Node "6" + {"event": event, "event_data": {"value": 7}}, # Node "7" + ] + ) + logger = logging.getLogger("TEST") + script_obj = script.Script( + hass, + sequence, + "Test Name", + "test_domain", + script_mode="queued", + max_runs=2, + logger=logger, + ) + trace.trace_id_set(("script_1", "1")) + script.breakpoint_set(hass, "script_1", script.RUN_ID_ANY, "1") + script.breakpoint_set(hass, "script_1", script.RUN_ID_ANY, "5") + + breakpoint_hit_event = asyncio.Event() + + @callback + def breakpoint_hit(*_): + breakpoint_hit_event.set() + + async_dispatcher_connect(hass, script.SCRIPT_BREAKPOINT_HIT, breakpoint_hit) + + watch_messages = [] + + @callback + def check_action(): + for message, flag in watch_messages: + if script_obj.last_action and message in script_obj.last_action: + flag.set() + + script_obj.change_listener = check_action + + assert not script_obj.is_running + assert script_obj.runs == 0 + + # Start script, should stop on breakpoint at node "1" + hass.async_create_task(script_obj.async_run(context=Context())) + await breakpoint_hit_event.wait() + assert script_obj.is_running + assert script_obj.runs == 1 + assert len(events) == 1 + assert events[-1].data["value"] == 0 + + # Single step script, should stop at node "2" + breakpoint_hit_event.clear() + script.debug_step(hass, "script_1", "1") + await breakpoint_hit_event.wait() + assert script_obj.is_running + assert script_obj.runs == 1 + assert len(events) == 2 + assert events[-1].data["value"] == 1 + + # Single step script, should stop at node "3" + breakpoint_hit_event.clear() + script.debug_step(hass, "script_1", "1") + await breakpoint_hit_event.wait() + assert script_obj.is_running + assert script_obj.runs == 1 + assert len(events) == 3 + assert events[-1].data["value"] == 2 + + # Resume script, should stop on breakpoint at node "5" + breakpoint_hit_event.clear() + script.debug_continue(hass, "script_1", "1") + await breakpoint_hit_event.wait() + assert script_obj.is_running + assert script_obj.runs == 1 + assert len(events) == 5 + assert events[-1].data["value"] == 4 + + # Resume script, should run until completion + script.debug_continue(hass, "script_1", "1") + await hass.async_block_till_done() + assert not script_obj.is_running + assert script_obj.runs == 0 + assert len(events) == 8 + assert events[-1].data["value"] == 7 + + +async def test_breakpoints_2(hass): + """Test setting a breakpoint halts execution, and execution can be aborted.""" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event, "event_data": {"value": 0}}, # Node "0" + {"event": event, "event_data": {"value": 1}}, # Node "1" + {"event": event, "event_data": {"value": 2}}, # Node "2" + {"event": event, "event_data": {"value": 3}}, # Node "3" + {"event": event, "event_data": {"value": 4}}, # Node "4" + {"event": event, "event_data": {"value": 5}}, # Node "5" + {"event": event, "event_data": {"value": 6}}, # Node "6" + {"event": event, "event_data": {"value": 7}}, # Node "7" + ] + ) + logger = logging.getLogger("TEST") + script_obj = script.Script( + hass, + sequence, + "Test Name", + "test_domain", + script_mode="queued", + max_runs=2, + logger=logger, + ) + trace.trace_id_set(("script_1", "1")) + script.breakpoint_set(hass, "script_1", script.RUN_ID_ANY, "1") + script.breakpoint_set(hass, "script_1", script.RUN_ID_ANY, "5") + + breakpoint_hit_event = asyncio.Event() + + @callback + def breakpoint_hit(*_): + breakpoint_hit_event.set() + + async_dispatcher_connect(hass, script.SCRIPT_BREAKPOINT_HIT, breakpoint_hit) + + watch_messages = [] + + @callback + def check_action(): + for message, flag in watch_messages: + if script_obj.last_action and message in script_obj.last_action: + flag.set() + + script_obj.change_listener = check_action + + assert not script_obj.is_running + assert script_obj.runs == 0 + + # Start script, should stop on breakpoint at node "1" + hass.async_create_task(script_obj.async_run(context=Context())) + await breakpoint_hit_event.wait() + assert script_obj.is_running + assert script_obj.runs == 1 + assert len(events) == 1 + assert events[-1].data["value"] == 0 + + # Abort script + script.debug_stop(hass, "script_1", "1") + await hass.async_block_till_done() + assert not script_obj.is_running + assert script_obj.runs == 0 + assert len(events) == 1