Improve script disallowed recursion logging (#118151)
parent
189cf88537
commit
7bbb33b415
|
@ -157,7 +157,7 @@ SCRIPT_DEBUG_CONTINUE_STOP: SignalTypeFormat[Literal["continue", "stop"]] = (
|
|||
)
|
||||
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
||||
|
||||
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||
script_stack_cv: ContextVar[list[str] | None] = ContextVar("script_stack", default=None)
|
||||
|
||||
|
||||
class ScriptData(TypedDict):
|
||||
|
@ -452,7 +452,7 @@ class _ScriptRun:
|
|||
if (script_stack := script_stack_cv.get()) is None:
|
||||
script_stack = []
|
||||
script_stack_cv.set(script_stack)
|
||||
script_stack.append(id(self._script))
|
||||
script_stack.append(self._script.unique_id)
|
||||
response = None
|
||||
|
||||
try:
|
||||
|
@ -1401,6 +1401,7 @@ class Script:
|
|||
self.sequence = sequence
|
||||
template.attach(hass, self.sequence)
|
||||
self.name = name
|
||||
self.unique_id = f"{domain}.{name}-{id(self)}"
|
||||
self.domain = domain
|
||||
self.running_description = running_description or f"{domain} script"
|
||||
self._change_listener = change_listener
|
||||
|
@ -1723,10 +1724,21 @@ class Script:
|
|||
if (
|
||||
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
|
||||
and script_stack is not None
|
||||
and id(self) in script_stack
|
||||
and self.unique_id in script_stack
|
||||
):
|
||||
script_execution_set("disallowed_recursion_detected")
|
||||
self._log("Disallowed recursion detected", level=logging.WARNING)
|
||||
formatted_stack = [
|
||||
f"- {name_id.partition('-')[0]}" for name_id in script_stack
|
||||
]
|
||||
self._log(
|
||||
"Disallowed recursion detected, "
|
||||
f"{script_stack[-1].partition('-')[0]} tried to start "
|
||||
f"{self.domain}.{self.name} which is already running "
|
||||
"in the current execution path; "
|
||||
"Traceback (most recent call last):\n"
|
||||
f"{"\n".join(formatted_stack)}",
|
||||
level=logging.WARNING,
|
||||
)
|
||||
return None
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
|
|
|
@ -6247,3 +6247,72 @@ async def test_stopping_run_before_starting(
|
|||
# would hang indefinitely.
|
||||
run = script._ScriptRun(hass, script_obj, {}, None, True)
|
||||
await run.async_stop()
|
||||
|
||||
|
||||
async def test_disallowed_recursion(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test a queued mode script disallowed recursion."""
|
||||
context = Context()
|
||||
calls = 0
|
||||
alias = "event step"
|
||||
sequence1 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_2"})
|
||||
script1_obj = script.Script(
|
||||
hass,
|
||||
sequence1,
|
||||
"Test Name1",
|
||||
"test_domain1",
|
||||
script_mode="queued",
|
||||
running_description="test script1",
|
||||
)
|
||||
|
||||
sequence2 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_3"})
|
||||
script2_obj = script.Script(
|
||||
hass,
|
||||
sequence2,
|
||||
"Test Name2",
|
||||
"test_domain2",
|
||||
script_mode="queued",
|
||||
running_description="test script2",
|
||||
)
|
||||
|
||||
sequence3 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_1"})
|
||||
script3_obj = script.Script(
|
||||
hass,
|
||||
sequence3,
|
||||
"Test Name3",
|
||||
"test_domain3",
|
||||
script_mode="queued",
|
||||
running_description="test script3",
|
||||
)
|
||||
|
||||
async def _async_service_handler_1(*args, **kwargs) -> None:
|
||||
await script1_obj.async_run(context=context)
|
||||
|
||||
hass.services.async_register("test", "call_script_1", _async_service_handler_1)
|
||||
|
||||
async def _async_service_handler_2(*args, **kwargs) -> None:
|
||||
await script2_obj.async_run(context=context)
|
||||
|
||||
hass.services.async_register("test", "call_script_2", _async_service_handler_2)
|
||||
|
||||
async def _async_service_handler_3(*args, **kwargs) -> None:
|
||||
await script3_obj.async_run(context=context)
|
||||
|
||||
hass.services.async_register("test", "call_script_3", _async_service_handler_3)
|
||||
|
||||
await script1_obj.async_run(context=context)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == 0
|
||||
assert (
|
||||
"Test Name1: Disallowed recursion detected, "
|
||||
"test_domain3.Test Name3 tried to start test_domain1.Test Name1"
|
||||
" which is already running in the current execution path; "
|
||||
"Traceback (most recent call last):"
|
||||
) in caplog.text
|
||||
assert (
|
||||
"- test_domain1.Test Name1\n"
|
||||
"- test_domain2.Test Name2\n"
|
||||
"- test_domain3.Test Name3"
|
||||
) in caplog.text
|
||||
|
|
Loading…
Reference in New Issue