Fix script in restart mode that is fired from the same trigger (#116247)

pull/116298/head
J. Nick Koston 2024-04-27 07:08:29 -05:00 committed by GitHub
parent a37d274b37
commit 7715bee6b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 91 additions and 11 deletions

View File

@ -1692,7 +1692,7 @@ class Script:
script_stack = script_stack_cv.get()
if (
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
and (script_stack := script_stack_cv.get()) is not None
and script_stack is not None
and id(self) in script_stack
):
script_execution_set("disallowed_recursion_detected")
@ -1706,15 +1706,19 @@ class Script:
run = cls(
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
has_existing_runs = bool(self._runs)
self._runs.append(run)
if self.script_mode == SCRIPT_MODE_RESTART:
if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs:
# When script mode is SCRIPT_MODE_RESTART, first add the new run and then
# stop any other runs. If we stop other runs first, self.is_running will
# return false after the other script runs were stopped until our task
# resumes running.
# resumes running. Its important that we check if there are existing
# runs before sleeping as otherwise if two runs are started at the exact
# same time they will cancel each other out.
self._log("Restarting")
# Important: yield to the event loop to allow the script to start in case
# the script is restarting itself.
# the script is restarting itself so it ends up in the script stack and
# the recursion check above will prevent the script from running.
await asyncio.sleep(0)
await self.async_stop(update_state=False, spare=run)
@ -1730,9 +1734,7 @@ class Script:
self._changed()
raise
async def _async_stop(
self, aws: list[asyncio.Task], update_state: bool, spare: _ScriptRun | None
) -> None:
async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None:
await asyncio.wait(aws)
if update_state:
self._changed()
@ -1749,9 +1751,7 @@ class Script:
]
if not aws:
return
await asyncio.shield(
create_eager_task(self._async_stop(aws, update_state, spare))
)
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config):
if isinstance(config, template.Template):

View File

@ -8,7 +8,7 @@ from unittest.mock import Mock, patch
import pytest
from homeassistant.components import automation
from homeassistant.components import automation, input_boolean, script
from homeassistant.components.automation import (
ATTR_SOURCE,
DOMAIN,
@ -41,6 +41,7 @@ from homeassistant.core import (
)
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.script import (
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_PARALLEL,
@ -2980,3 +2981,82 @@ async def test_automation_turns_off_other_automation(
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
assert len(calls) == 0
async def test_two_automations_call_restart_script_same_time(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test two automations that call a restart mode script at the same."""
hass.states.async_set("binary_sensor.presence", "off")
await hass.async_block_till_done()
events = []
@callback
def _save_event(event):
events.append(event)
assert await async_setup_component(
hass,
input_boolean.DOMAIN,
{
input_boolean.DOMAIN: {
"test_1": None,
}
},
)
cancel = async_track_state_change_event(hass, "input_boolean.test_1", _save_event)
assert await async_setup_component(
hass,
script.DOMAIN,
{
script.DOMAIN: {
"fire_toggle": {
"sequence": [
{
"service": "input_boolean.toggle",
"target": {"entity_id": "input_boolean.test_1"},
}
]
},
}
},
)
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "state",
"entity_id": "binary_sensor.presence",
"to": "on",
},
"action": {
"service": "script.fire_toggle",
},
"id": "automation_0",
"mode": "single",
},
{
"trigger": {
"platform": "state",
"entity_id": "binary_sensor.presence",
"to": "on",
},
"action": {
"service": "script.fire_toggle",
},
"id": "automation_1",
"mode": "single",
},
]
},
)
hass.states.async_set("binary_sensor.presence", "on")
await hass.async_block_till_done()
assert len(events) == 2
cancel()