diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index df7102effde..9963c942f08 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -1,7 +1,8 @@ """Allow to set up simple automation rules via the config file.""" -from collections import deque +from collections import OrderedDict from contextlib import contextmanager import datetime as dt +from itertools import count import logging from typing import ( Any, @@ -240,6 +241,8 @@ async def async_setup(hass, config): class AutomationTrace: """Container for automation trace.""" + _runids = count(0) + def __init__( self, unique_id: Optional[str], @@ -254,6 +257,7 @@ class AutomationTrace: self._context: Context = context self._error: Optional[Exception] = None self._state: str = "running" + self.runid: str = str(next(self._runids)) self._timestamp_finish: Optional[dt.datetime] = None self._timestamp_start: dt.datetime = dt_util.utcnow() self._trigger: Dict[str, Any] = trigger @@ -300,6 +304,7 @@ class AutomationTrace: "config": self._config, "context": self._context, "state": self._state, + "run_id": self.runid, "timestamp": { "start": self._timestamp_start, "finish": self._timestamp_finish, @@ -313,16 +318,37 @@ class AutomationTrace: return result +class LimitedSizeDict(OrderedDict): + """OrderedDict limited in size.""" + + def __init__(self, *args, **kwds): + """Initialize OrderedDict limited in size.""" + self.size_limit = kwds.pop("size_limit", None) + OrderedDict.__init__(self, *args, **kwds) + self._check_size_limit() + + def __setitem__(self, key, value): + """Set item and check dict size.""" + OrderedDict.__setitem__(self, key, value) + self._check_size_limit() + + def _check_size_limit(self): + """Check dict size and evict items in FIFO order if needed.""" + if self.size_limit is not None: + while len(self) > self.size_limit: + self.popitem(last=False) + + @contextmanager def trace_automation(hass, unique_id, config, trigger, context): """Trace action execution of automation with automation_id.""" automation_trace = AutomationTrace(unique_id, config, trigger, context) if unique_id: - if unique_id not in hass.data[DATA_AUTOMATION_TRACE]: - hass.data[DATA_AUTOMATION_TRACE][unique_id] = deque([], STORED_TRACES) - traces = hass.data[DATA_AUTOMATION_TRACE][unique_id] - traces.append(automation_trace) + automation_traces = hass.data[DATA_AUTOMATION_TRACE] + if unique_id not in automation_traces: + automation_traces[unique_id] = LimitedSizeDict(size_limit=STORED_TRACES) + automation_traces[unique_id][automation_trace.runid] = automation_trace try: yield automation_trace @@ -835,7 +861,7 @@ def get_debug_traces_for_automation(hass, automation_id): """Return a serializable list of debug traces for an automation.""" traces = [] - for trace in hass.data[DATA_AUTOMATION_TRACE].get(automation_id, []): + for trace in hass.data[DATA_AUTOMATION_TRACE].get(automation_id, {}).values(): traces.append(trace.as_dict()) return traces diff --git a/tests/components/config/test_automation.py b/tests/components/config/test_automation.py index 7e0cf9e8e4d..cb192befade 100644 --- a/tests/components/config/test_automation.py +++ b/tests/components/config/test_automation.py @@ -3,7 +3,7 @@ import json from unittest.mock import patch from homeassistant.bootstrap import async_setup_component -from homeassistant.components import config +from homeassistant.components import automation, config from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 @@ -325,3 +325,81 @@ async def test_get_automation_trace(hass, hass_ws_client): assert trace["trigger"]["description"] == "event 'test_event2'" assert trace["unique_id"] == "moon" assert trace["variables"] + + +async def test_automation_trace_overflow(hass, hass_ws_client): + """Test the number of stored traces per automation is limited.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": {"event": "some_event"}, + } + moon_config = { + "id": "moon", + "trigger": {"platform": "event", "event_type": "test_event2"}, + "action": {"event": "another_event"}, + } + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + sun_config, + moon_config, + ] + }, + ) + + with patch.object(config, "SECTIONS", ["automation"]): + await async_setup_component(hass, "config", {}) + + client = await hass_ws_client() + + await client.send_json({"id": next_id(), "type": "automation/trace"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == {} + + await client.send_json( + {"id": next_id(), "type": "automation/trace", "automation_id": "sun"} + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == {"sun": []} + + # Trigger "sun" and "moon" automation once + hass.bus.async_fire("test_event") + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + + # Get traces + await client.send_json({"id": next_id(), "type": "automation/trace"}) + response = await client.receive_json() + assert response["success"] + assert len(response["result"]["moon"]) == 1 + moon_run_id = response["result"]["moon"][0]["run_id"] + assert len(response["result"]["sun"]) == 1 + + # Trigger "moon" automation enough times to overflow the number of stored traces + for _ in range(automation.STORED_TRACES): + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + + await client.send_json({"id": next_id(), "type": "automation/trace"}) + response = await client.receive_json() + assert response["success"] + assert len(response["result"]["moon"]) == automation.STORED_TRACES + assert len(response["result"]["sun"]) == 1 + assert int(response["result"]["moon"][0]["run_id"]) == int(moon_run_id) + 1 + assert ( + int(response["result"]["moon"][-1]["run_id"]) + == int(moon_run_id) + automation.STORED_TRACES + ) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 027254ee03e..18c769fee73 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1192,11 +1192,11 @@ async def test_condition_all_cached(hass): assert len(script_obj._config_cache) == 2 -async def test_repeat_count(hass, caplog): +@pytest.mark.parametrize("count", [3, script.ACTION_TRACE_NODE_MAX_LEN * 2]) +async def test_repeat_count(hass, caplog, count): """Test repeat action w/ count option.""" event = "test_event" events = async_capture_events(hass, event) - count = 3 alias = "condition step" sequence = cv.SCRIPT_SCHEMA( @@ -1215,7 +1215,8 @@ async def test_repeat_count(hass, caplog): }, } ) - script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + with script.trace_action(None): + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") await script_obj.async_run(context=Context()) await hass.async_block_till_done() @@ -1226,6 +1227,13 @@ async def test_repeat_count(hass, caplog): assert event.data.get("index") == index + 1 assert event.data.get("last") == (index == count - 1) assert caplog.text.count(f"Repeating {alias}") == count + assert_action_trace( + { + "": [{}], + "0": [{}], + "0/0/0": [{}] * min(count, script.ACTION_TRACE_NODE_MAX_LEN), + } + ) @pytest.mark.parametrize("condition", ["while", "until"])