Don't prevent automations from triggering themselves (#68178)
parent
b99934f62f
commit
46f27fdefd
|
@ -54,6 +54,7 @@ from homeassistant.helpers.script import (
|
|||
CONF_MAX,
|
||||
CONF_MAX_EXCEEDED,
|
||||
Script,
|
||||
script_stack_cv,
|
||||
)
|
||||
from homeassistant.helpers.script_variables import ScriptVariables
|
||||
from homeassistant.helpers.service import (
|
||||
|
@ -505,6 +506,10 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
|
||||
)
|
||||
|
||||
# Make a new empty script stack; automations are allowed
|
||||
# to recursively trigger themselves
|
||||
script_stack_cv.set([])
|
||||
|
||||
try:
|
||||
with trace_path("action"):
|
||||
await self.action_script.async_run(
|
||||
|
|
|
@ -1247,7 +1247,7 @@ class Script:
|
|||
and id(self) in script_stack
|
||||
):
|
||||
script_execution_set("disallowed_recursion_detected")
|
||||
_LOGGER.warning("Disallowed recursion detected")
|
||||
self._log("Disallowed recursion detected", level=logging.WARNING)
|
||||
return
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""The tests for the automation component."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
@ -25,14 +26,30 @@ from homeassistant.const import (
|
|||
STATE_OFF,
|
||||
STATE_ON,
|
||||
)
|
||||
from homeassistant.core import Context, CoreState, State, callback
|
||||
from homeassistant.core import (
|
||||
Context,
|
||||
CoreState,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
State,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
from homeassistant.helpers.script import (
|
||||
SCRIPT_MODE_CHOICES,
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUED,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
_async_stop_scripts_at_shutdown,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import (
|
||||
assert_setup_component,
|
||||
async_capture_events,
|
||||
async_fire_time_changed,
|
||||
async_mock_service,
|
||||
mock_restore_cache,
|
||||
)
|
||||
|
@ -1570,3 +1587,191 @@ async def test_trigger_condition_explicit_id(hass, calls):
|
|||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[-1].data.get("param") == "two"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"automation_mode,automation_runs",
|
||||
(
|
||||
(SCRIPT_MODE_PARALLEL, 2),
|
||||
(SCRIPT_MODE_QUEUED, 2),
|
||||
(SCRIPT_MODE_RESTART, 2),
|
||||
(SCRIPT_MODE_SINGLE, 1),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"script_mode,script_warning_msg",
|
||||
(
|
||||
(SCRIPT_MODE_PARALLEL, "script1: Maximum number of runs exceeded"),
|
||||
(SCRIPT_MODE_QUEUED, "script1: Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_RESTART, "script1: Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_SINGLE, "script1: Already running"),
|
||||
),
|
||||
)
|
||||
async def test_recursive_automation_starting_script(
|
||||
hass: HomeAssistant,
|
||||
automation_mode,
|
||||
automation_runs,
|
||||
script_mode,
|
||||
script_warning_msg,
|
||||
caplog,
|
||||
):
|
||||
"""Test starting automations does not interfere with script deadlock prevention."""
|
||||
|
||||
# Fail if additional script modes are added to
|
||||
# make sure we cover all script modes in tests
|
||||
assert SCRIPT_MODE_CHOICES == [
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUED,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
]
|
||||
|
||||
stop_scripts_at_shutdown_called = asyncio.Event()
|
||||
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
|
||||
|
||||
async def mock_stop_scripts_at_shutdown(*args):
|
||||
await real_stop_scripts_at_shutdown(*args)
|
||||
stop_scripts_at_shutdown_called.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
|
||||
wraps=mock_stop_scripts_at_shutdown,
|
||||
):
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"script1": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"event": "trigger_automation"},
|
||||
{
|
||||
"wait_template": f"{{{{ float(states('sensor.test'), 0) >= {automation_runs} }}}}"
|
||||
},
|
||||
{"service": "script.script1"},
|
||||
{"service": "test.script_done"},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"mode": automation_mode,
|
||||
"trigger": [
|
||||
{"platform": "event", "event_type": "trigger_automation"},
|
||||
],
|
||||
"action": [
|
||||
{"service": "test.automation_started"},
|
||||
{"service": "script.script1"},
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
script_done_event = asyncio.Event()
|
||||
script_done = []
|
||||
automation_started = []
|
||||
automation_triggered = []
|
||||
|
||||
async def async_service_handler(service: ServiceCall):
|
||||
if service.service == "automation_started":
|
||||
automation_started.append(service)
|
||||
elif service.service == "script_done":
|
||||
script_done.append(service)
|
||||
if len(script_done) == 1:
|
||||
script_done_event.set()
|
||||
|
||||
async def async_automation_triggered(event):
|
||||
"""Listen to automation_triggered event from the automation integration."""
|
||||
automation_triggered.append(event)
|
||||
hass.states.async_set("sensor.test", str(len(automation_triggered)))
|
||||
|
||||
hass.services.async_register("test", "script_done", async_service_handler)
|
||||
hass.services.async_register(
|
||||
"test", "automation_started", async_service_handler
|
||||
)
|
||||
hass.bus.async_listen("automation_triggered", async_automation_triggered)
|
||||
|
||||
hass.bus.async_fire("trigger_automation")
|
||||
await asyncio.wait_for(script_done_event.wait(), 1)
|
||||
|
||||
# Trigger 1st stage script shutdown
|
||||
hass.state = CoreState.stopping
|
||||
hass.bus.async_fire("homeassistant_stop")
|
||||
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
|
||||
|
||||
# Trigger 2nd stage script shutdown
|
||||
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=60))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert script_warning_msg in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("automation_mode", SCRIPT_MODE_CHOICES)
|
||||
async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog):
|
||||
"""Test automation triggering itself.
|
||||
|
||||
- Illegal recursion detection should not be triggered
|
||||
- Home Assistant should not hang on shut down
|
||||
"""
|
||||
stop_scripts_at_shutdown_called = asyncio.Event()
|
||||
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
|
||||
|
||||
async def stop_scripts_at_shutdown(*args):
|
||||
await real_stop_scripts_at_shutdown(*args)
|
||||
stop_scripts_at_shutdown_called.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
|
||||
wraps=stop_scripts_at_shutdown,
|
||||
):
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"mode": automation_mode,
|
||||
"trigger": [
|
||||
{"platform": "event", "event_type": "trigger_automation"},
|
||||
],
|
||||
"action": [
|
||||
{"event": "trigger_automation"},
|
||||
{"service": "test.automation_done"},
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
service_called = asyncio.Event()
|
||||
service_called_late = []
|
||||
|
||||
async def async_service_handler(service):
|
||||
if service.service == "automation_done":
|
||||
service_called.set()
|
||||
if service.service == "automation_started_late":
|
||||
service_called_late.append(service)
|
||||
|
||||
hass.services.async_register("test", "automation_done", async_service_handler)
|
||||
hass.services.async_register(
|
||||
"test", "automation_started_late", async_service_handler
|
||||
)
|
||||
|
||||
hass.bus.async_fire("trigger_automation")
|
||||
await asyncio.wait_for(service_called.wait(), 1)
|
||||
|
||||
# Trigger 1st stage script shutdown
|
||||
hass.state = CoreState.stopping
|
||||
hass.bus.async_fire("homeassistant_stop")
|
||||
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
|
||||
|
||||
# Trigger 2nd stage script shutdown
|
||||
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Disallowed recursion detected" not in caplog.text
|
||||
|
|
|
@ -840,8 +840,6 @@ async def test_recursive_script(hass, script_mode, warning_msg, caplog):
|
|||
service_called.set()
|
||||
|
||||
hass.services.async_register("test", "script", async_service_handler)
|
||||
hass.states.async_set("input_boolean.test", "on")
|
||||
hass.states.async_set("input_boolean.test2", "off")
|
||||
|
||||
await hass.services.async_call("script", "script1")
|
||||
await asyncio.wait_for(service_called.wait(), 1)
|
||||
|
@ -908,8 +906,6 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog)
|
|||
service_called.set()
|
||||
|
||||
hass.services.async_register("test", "script", async_service_handler)
|
||||
hass.states.async_set("input_boolean.test", "on")
|
||||
hass.states.async_set("input_boolean.test2", "off")
|
||||
|
||||
await hass.services.async_call("script", "script1")
|
||||
await asyncio.wait_for(service_called.wait(), 1)
|
||||
|
|
Loading…
Reference in New Issue