Fix parallel script execution in queued mode (#118153)

pull/118171/head
J. Nick Koston 2024-05-26 01:05:31 -10:00 committed by GitHub
parent f12f82caac
commit 6697cf07a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 0 deletions

View File

@ -609,6 +609,15 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
) )
coro = self._async_run(variables, context) coro = self._async_run(variables, context)
if wait: if wait:
# If we are executing in parallel, we need to copy the script stack so
# that if this script is called in parallel, it will not be seen in the
# stack of the other parallel calls and hit the disallowed recursion
# check as each parallel call would otherwise be appending to the same
# stack. We do not wipe the stack in this case because we still want to
# be able to detect if there is a disallowed recursion.
if script_stack := script_stack_cv.get():
script_stack_cv.set(script_stack.copy())
script_result = await coro script_result = await coro
return script_result.service_response if script_result else None return script_result.service_response if script_result else None

View File

@ -1741,3 +1741,46 @@ async def test_responses_no_response(hass: HomeAssistant) -> None:
) )
is None is None
) )
async def test_script_queued_mode(hass: HomeAssistant) -> None:
"""Test calling a queued mode script called in parallel."""
calls = 0
async def async_service_handler(*args, **kwargs) -> None:
"""Service that simulates doing background I/O."""
nonlocal calls
calls += 1
await asyncio.sleep(0)
hass.services.async_register("test", "simulated_remote", async_service_handler)
assert await async_setup_component(
hass,
script.DOMAIN,
{
script.DOMAIN: {
"test_main": {
"sequence": [
{
"parallel": [
{"service": "script.test_sub"},
{"service": "script.test_sub"},
{"service": "script.test_sub"},
{"service": "script.test_sub"},
]
}
]
},
"test_sub": {
"mode": "queued",
"sequence": [
{"service": "test.simulated_remote"},
],
},
}
},
)
await hass.async_block_till_done()
await hass.services.async_call("script", "test_main", blocking=True)
assert calls == 4