Improve script disallowed recursion logging (#118151)

pull/118171/head
J. Nick Koston 2024-05-26 00:58:34 -10:00 committed by GitHub
parent 189cf88537
commit 7bbb33b415
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 4 deletions

View File

@ -157,7 +157,7 @@ SCRIPT_DEBUG_CONTINUE_STOP: SignalTypeFormat[Literal["continue", "stop"]] = (
)
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
script_stack_cv: ContextVar[list[str] | None] = ContextVar("script_stack", default=None)
class ScriptData(TypedDict):
@ -452,7 +452,7 @@ class _ScriptRun:
if (script_stack := script_stack_cv.get()) is None:
script_stack = []
script_stack_cv.set(script_stack)
script_stack.append(id(self._script))
script_stack.append(self._script.unique_id)
response = None
try:
@ -1401,6 +1401,7 @@ class Script:
self.sequence = sequence
template.attach(hass, self.sequence)
self.name = name
self.unique_id = f"{domain}.{name}-{id(self)}"
self.domain = domain
self.running_description = running_description or f"{domain} script"
self._change_listener = change_listener
@ -1723,10 +1724,21 @@ class Script:
if (
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
and script_stack is not None
and id(self) in script_stack
and self.unique_id in script_stack
):
script_execution_set("disallowed_recursion_detected")
self._log("Disallowed recursion detected", level=logging.WARNING)
formatted_stack = [
f"- {name_id.partition('-')[0]}" for name_id in script_stack
]
self._log(
"Disallowed recursion detected, "
f"{script_stack[-1].partition('-')[0]} tried to start "
f"{self.domain}.{self.name} which is already running "
"in the current execution path; "
"Traceback (most recent call last):\n"
f"{"\n".join(formatted_stack)}",
level=logging.WARNING,
)
return None
if self.script_mode != SCRIPT_MODE_QUEUED:

View File

@ -6247,3 +6247,72 @@ async def test_stopping_run_before_starting(
# would hang indefinitely.
run = script._ScriptRun(hass, script_obj, {}, None, True)
await run.async_stop()
async def test_disallowed_recursion(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test a queued mode script disallowed recursion."""
context = Context()
calls = 0
alias = "event step"
sequence1 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_2"})
script1_obj = script.Script(
hass,
sequence1,
"Test Name1",
"test_domain1",
script_mode="queued",
running_description="test script1",
)
sequence2 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_3"})
script2_obj = script.Script(
hass,
sequence2,
"Test Name2",
"test_domain2",
script_mode="queued",
running_description="test script2",
)
sequence3 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_1"})
script3_obj = script.Script(
hass,
sequence3,
"Test Name3",
"test_domain3",
script_mode="queued",
running_description="test script3",
)
async def _async_service_handler_1(*args, **kwargs) -> None:
await script1_obj.async_run(context=context)
hass.services.async_register("test", "call_script_1", _async_service_handler_1)
async def _async_service_handler_2(*args, **kwargs) -> None:
await script2_obj.async_run(context=context)
hass.services.async_register("test", "call_script_2", _async_service_handler_2)
async def _async_service_handler_3(*args, **kwargs) -> None:
await script3_obj.async_run(context=context)
hass.services.async_register("test", "call_script_3", _async_service_handler_3)
await script1_obj.async_run(context=context)
await hass.async_block_till_done()
assert calls == 0
assert (
"Test Name1: Disallowed recursion detected, "
"test_domain3.Test Name3 tried to start test_domain1.Test Name1"
" which is already running in the current execution path; "
"Traceback (most recent call last):"
) in caplog.text
assert (
"- test_domain1.Test Name1\n"
"- test_domain2.Test Name2\n"
"- test_domain3.Test Name3"
) in caplog.text