Don't prevent automations from triggering themselves (#68178)

pull/68205/head
Erik Montnemery 2022-03-15 18:48:54 +01:00 committed by GitHub
parent b99934f62f
commit 46f27fdefd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 212 additions and 6 deletions

View File

@ -54,6 +54,7 @@ from homeassistant.helpers.script import (
CONF_MAX,
CONF_MAX_EXCEEDED,
Script,
script_stack_cv,
)
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import (
@ -505,6 +506,10 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
)
# Make a new empty script stack; automations are allowed
# to recursively trigger themselves
script_stack_cv.set([])
try:
with trace_path("action"):
await self.action_script.async_run(

View File

@ -1247,7 +1247,7 @@ class Script:
and id(self) in script_stack
):
script_execution_set("disallowed_recursion_detected")
_LOGGER.warning("Disallowed recursion detected")
self._log("Disallowed recursion detected", level=logging.WARNING)
return
if self.script_mode != SCRIPT_MODE_QUEUED:

View File

@ -1,5 +1,6 @@
"""The tests for the automation component."""
import asyncio
from datetime import timedelta
import logging
from unittest.mock import Mock, patch
@ -25,14 +26,30 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import Context, CoreState, State, callback
from homeassistant.core import (
Context,
CoreState,
HomeAssistant,
ServiceCall,
State,
callback,
)
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.script import (
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
_async_stop_scripts_at_shutdown,
)
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from tests.common import (
assert_setup_component,
async_capture_events,
async_fire_time_changed,
async_mock_service,
mock_restore_cache,
)
@ -1570,3 +1587,191 @@ async def test_trigger_condition_explicit_id(hass, calls):
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[-1].data.get("param") == "two"
@pytest.mark.parametrize(
"automation_mode,automation_runs",
(
(SCRIPT_MODE_PARALLEL, 2),
(SCRIPT_MODE_QUEUED, 2),
(SCRIPT_MODE_RESTART, 2),
(SCRIPT_MODE_SINGLE, 1),
),
)
@pytest.mark.parametrize(
"script_mode,script_warning_msg",
(
(SCRIPT_MODE_PARALLEL, "script1: Maximum number of runs exceeded"),
(SCRIPT_MODE_QUEUED, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_RESTART, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_SINGLE, "script1: Already running"),
),
)
async def test_recursive_automation_starting_script(
hass: HomeAssistant,
automation_mode,
automation_runs,
script_mode,
script_warning_msg,
caplog,
):
"""Test starting automations does not interfere with script deadlock prevention."""
# Fail if additional script modes are added to
# make sure we cover all script modes in tests
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
async def mock_stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()
with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=mock_stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"mode": script_mode,
"sequence": [
{"event": "trigger_automation"},
{
"wait_template": f"{{{{ float(states('sensor.test'), 0) >= {automation_runs} }}}}"
},
{"service": "script.script1"},
{"service": "test.script_done"},
],
},
}
},
)
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"service": "test.automation_started"},
{"service": "script.script1"},
],
}
},
)
script_done_event = asyncio.Event()
script_done = []
automation_started = []
automation_triggered = []
async def async_service_handler(service: ServiceCall):
if service.service == "automation_started":
automation_started.append(service)
elif service.service == "script_done":
script_done.append(service)
if len(script_done) == 1:
script_done_event.set()
async def async_automation_triggered(event):
"""Listen to automation_triggered event from the automation integration."""
automation_triggered.append(event)
hass.states.async_set("sensor.test", str(len(automation_triggered)))
hass.services.async_register("test", "script_done", async_service_handler)
hass.services.async_register(
"test", "automation_started", async_service_handler
)
hass.bus.async_listen("automation_triggered", async_automation_triggered)
hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(script_done_event.wait(), 1)
# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=60))
await hass.async_block_till_done()
assert script_warning_msg in caplog.text
@pytest.mark.parametrize("automation_mode", SCRIPT_MODE_CHOICES)
async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog):
"""Test automation triggering itself.
- Illegal recursion detection should not be triggered
- Home Assistant should not hang on shut down
"""
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
async def stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()
with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"event": "trigger_automation"},
{"service": "test.automation_done"},
],
}
},
)
service_called = asyncio.Event()
service_called_late = []
async def async_service_handler(service):
if service.service == "automation_done":
service_called.set()
if service.service == "automation_started_late":
service_called_late.append(service)
hass.services.async_register("test", "automation_done", async_service_handler)
hass.services.async_register(
"test", "automation_started_late", async_service_handler
)
hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(service_called.wait(), 1)
# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
await hass.async_block_till_done()
assert "Disallowed recursion detected" not in caplog.text

View File

@ -840,8 +840,6 @@ async def test_recursive_script(hass, script_mode, warning_msg, caplog):
service_called.set()
hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
@ -908,8 +906,6 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog)
service_called.set()
hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)