diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 7a1f0bdb8da..0f625086235 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import AsyncGenerator, Callable, Mapping, Sequence -from contextlib import asynccontextmanager, suppress +from contextlib import asynccontextmanager from contextvars import ContextVar from copy import copy from dataclasses import dataclass @@ -15,6 +15,7 @@ import logging from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast +import async_interrupt import voluptuous as vol from homeassistant import exceptions @@ -157,6 +158,16 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) +class ScriptStoppedError(Exception): + """Error to indicate that the script has been stopped.""" + + +def _set_result_unless_done(future: asyncio.Future[None]) -> None: + """Set result of future unless it is done.""" + if not future.done(): + future.set_result(None) + + def action_trace_append(variables, path): """Append a TraceElement to trace[path].""" trace_element = TraceElement(variables, path) @@ -168,7 +179,7 @@ def action_trace_append(variables, path): async def trace_action( hass: HomeAssistant, script_run: _ScriptRun, - stop: asyncio.Event, + stop: asyncio.Future[None], variables: dict[str, Any], ) -> AsyncGenerator[TraceElement, None]: """Trace action execution.""" @@ -199,13 +210,13 @@ async def trace_action( ): async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path) - done = asyncio.Event() + done = hass.loop.create_future() @callback def async_continue_stop(command=None): if command == "stop": - stop.set() - done.set() + _set_result_unless_done(stop) + _set_result_unless_done(done) signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id) remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop) @@ -213,10 +224,7 @@ async def trace_action( hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop ) - tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)] - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - for task in tasks: - task.cancel() + await asyncio.wait([stop, done], return_when=asyncio.FIRST_COMPLETED) remove_signal1() remove_signal2() @@ -393,12 +401,12 @@ class _ScriptRun: self._log_exceptions = log_exceptions self._step = -1 self._started = False - self._stop = asyncio.Event() + self._stop = hass.loop.create_future() self._stopped = asyncio.Event() self._conversation_response: str | None | UndefinedType = UNDEFINED def _changed(self) -> None: - if not self._stop.is_set(): + if not self._stop.done(): self._script._changed() # pylint: disable=protected-access async def _async_get_condition(self, config): @@ -432,7 +440,7 @@ class _ScriptRun: try: self._log("Running %s", self._script.running_description) for self._step, self._action in enumerate(self._script.sequence): - if self._stop.is_set(): + if self._stop.done(): script_execution_set("cancelled") break await self._async_step(log_exceptions=False) @@ -471,7 +479,7 @@ class _ScriptRun: async with trace_action( self._hass, self, self._stop, self._variables ) as trace_element: - if self._stop.is_set(): + if self._stop.done(): return action = cv.determine_script_action(self._action) @@ -483,8 +491,8 @@ class _ScriptRun: trace_set_result(enabled=False) return + handler = f"_async_{action}_step" try: - handler = f"_async_{action}_step" await getattr(self, handler)() except Exception as ex: # pylint: disable=broad-except self._handle_exception( @@ -502,7 +510,7 @@ class _ScriptRun: async def async_stop(self) -> None: """Stop script run.""" - self._stop.set() + _set_result_unless_done(self._stop) # If the script was never started # the stopped event will never be # set because the script will never @@ -576,9 +584,9 @@ class _ScriptRun: level=level, ) - def _get_pos_time_period_template(self, key): + def _get_pos_time_period_template(self, key: str) -> timedelta: try: - return cv.positive_time_period( + return cv.positive_time_period( # type: ignore[no-any-return] template.render_complex(self._action[key], self._variables) ) except (exceptions.TemplateError, vol.Invalid) as ex: @@ -593,26 +601,34 @@ class _ScriptRun: async def _async_delay_step(self): """Handle delay.""" - delay = self._get_pos_time_period_template(CONF_DELAY) + delay_delta = self._get_pos_time_period_template(CONF_DELAY) - self._step_log(f"delay {delay}") + self._step_log(f"delay {delay_delta}") - delay = delay.total_seconds() + delay = delay_delta.total_seconds() self._changed() trace_set_result(delay=delay, done=False) + futures, timeout_handle, timeout_future = self._async_futures_with_timeout( + delay + ) + try: - async with asyncio.timeout(delay): - await self._stop.wait() - except TimeoutError: - trace_set_result(delay=delay, done=True) + await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED) + finally: + if timeout_future.done(): + trace_set_result(delay=delay, done=True) + else: + timeout_handle.cancel() + + def _get_timeout_seconds_from_action(self) -> float | None: + """Get the timeout from the action.""" + if CONF_TIMEOUT in self._action: + return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() + return None async def _async_wait_template_step(self): """Handle a wait template.""" - if CONF_TIMEOUT in self._action: - timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() - else: - timeout = None - + timeout = self._get_timeout_seconds_from_action() self._step_log("wait template", timeout) self._variables["wait"] = {"remaining": timeout, "completed": False} @@ -626,74 +642,47 @@ class _ScriptRun: self._variables["wait"]["completed"] = True return + futures, timeout_handle, timeout_future = self._async_futures_with_timeout( + timeout + ) + done = self._hass.loop.create_future() + futures.append(done) + @callback def async_script_wait(entity_id, from_s, to_s): """Handle script after template condition is true.""" - # pylint: disable=protected-access - wait_var = self._variables["wait"] - if to_context and to_context._when: - wait_var["remaining"] = to_context._when - self._hass.loop.time() - else: - wait_var["remaining"] = timeout - wait_var["completed"] = True - done.set() + self._async_set_remaining_time_var(timeout_handle) + self._variables["wait"]["completed"] = True + _set_result_unless_done(done) - to_context = None unsub = async_track_template( self._hass, wait_template, async_script_wait, self._variables ) - self._changed() - done = asyncio.Event() - tasks = [ - self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) - ] - try: - async with asyncio.timeout(timeout) as to_context: - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - except TimeoutError as ex: - self._variables["wait"]["remaining"] = 0.0 - if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): - self._log(_TIMEOUT_MSG) - trace_set_result(wait=self._variables["wait"], timeout=True) - raise _AbortScript from ex - finally: - for task in tasks: - task.cancel() - unsub() + await self._async_wait_with_optional_timeout( + futures, timeout_handle, timeout_future, unsub + ) + + def _async_set_remaining_time_var( + self, timeout_handle: asyncio.TimerHandle | None + ) -> None: + """Set the remaining time variable for a wait step.""" + wait_var = self._variables["wait"] + if timeout_handle: + wait_var["remaining"] = timeout_handle.when() - self._hass.loop.time() + else: + wait_var["remaining"] = None async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None: """Run a long task while monitoring for stop request.""" - - async def async_cancel_long_task() -> None: - # Stop long task and wait for it to finish. - long_task.cancel() - with suppress(Exception): - await long_task - - # Wait for long task while monitoring for a stop request. - stop_task = self._hass.async_create_task(self._stop.wait()) try: - await asyncio.wait( - {long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED - ) - # If our task is cancelled, then cancel long task, too. Note that if long task - # is cancelled otherwise the CancelledError exception will not be raised to - # here due to the call to asyncio.wait(). Rather we'll check for that below. - except asyncio.CancelledError: - await async_cancel_long_task() - raise - finally: - stop_task.cancel() - - if long_task.cancelled(): - raise asyncio.CancelledError - if long_task.done(): - # Propagate any exceptions that occurred. - return long_task.result() - # Stopped before long task completed, so cancel it. - await async_cancel_long_task() - return None + async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None): + # if stop is set, interrupt will cancel inside the context + # manager which will cancel long_task, and raise + # ScriptStoppedError outside the context manager + return await long_task + except ScriptStoppedError as ex: + raise asyncio.CancelledError from ex async def _async_call_service_step(self): """Call the service specified in the action.""" @@ -735,8 +724,9 @@ class _ScriptRun: blocking=True, context=self._context, return_response=return_response, - ) - ), + ), + eager_start=True, + ) ) if response_variable: self._variables[response_variable] = response_data @@ -866,7 +856,7 @@ class _ScriptRun: for iteration in range(1, count + 1): set_repeat_var(iteration, count) await async_run_sequence(iteration, extra_msg) - if self._stop.is_set(): + if self._stop.done(): break elif CONF_FOR_EACH in repeat: @@ -894,7 +884,7 @@ class _ScriptRun: for iteration, item in enumerate(items, 1): set_repeat_var(iteration, count, item) extra_msg = f" of {count} with item: {repr(item)}" - if self._stop.is_set(): + if self._stop.done(): break await async_run_sequence(iteration, extra_msg) @@ -905,7 +895,7 @@ class _ScriptRun: for iteration in itertools.count(1): set_repeat_var(iteration) try: - if self._stop.is_set(): + if self._stop.done(): break if not self._test_conditions(conditions, "while"): break @@ -923,7 +913,7 @@ class _ScriptRun: set_repeat_var(iteration) await async_run_sequence(iteration) try: - if self._stop.is_set(): + if self._stop.done(): break if self._test_conditions(conditions, "until") in [True, None]: break @@ -983,12 +973,35 @@ class _ScriptRun: with trace_path("else"): await self._async_run_script(if_data["if_else"]) + def _async_futures_with_timeout( + self, + timeout: float | None, + ) -> tuple[ + list[asyncio.Future[None]], + asyncio.TimerHandle | None, + asyncio.Future[None] | None, + ]: + """Return a list of futures to wait for. + + The list will contain the stop future. + + If timeout is set, a timeout future and handle will be created + and will be added to the list of futures. + """ + timeout_handle: asyncio.TimerHandle | None = None + timeout_future: asyncio.Future[None] | None = None + futures: list[asyncio.Future[None]] = [self._stop] + if timeout: + timeout_future = self._hass.loop.create_future() + timeout_handle = self._hass.loop.call_later( + timeout, _set_result_unless_done, timeout_future + ) + futures.append(timeout_future) + return futures, timeout_handle, timeout_future + async def _async_wait_for_trigger_step(self): """Wait for a trigger event.""" - if CONF_TIMEOUT in self._action: - timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() - else: - timeout = None + timeout = self._get_timeout_seconds_from_action() self._step_log("wait for trigger", timeout) @@ -996,22 +1009,20 @@ class _ScriptRun: self._variables["wait"] = {"remaining": timeout, "trigger": None} trace_set_result(wait=self._variables["wait"]) - done = asyncio.Event() + futures, timeout_handle, timeout_future = self._async_futures_with_timeout( + timeout + ) + done = self._hass.loop.create_future() + futures.append(done) async def async_done(variables, context=None): - # pylint: disable=protected-access - wait_var = self._variables["wait"] - if to_context and to_context._when: - wait_var["remaining"] = to_context._when - self._hass.loop.time() - else: - wait_var["remaining"] = timeout - wait_var["trigger"] = variables["trigger"] - done.set() + self._async_set_remaining_time_var(timeout_handle) + self._variables["wait"]["trigger"] = variables["trigger"] + _set_result_unless_done(done) def log_cb(level, msg, **kwargs): self._log(msg, level=level, **kwargs) - to_context = None remove_triggers = await async_initialize_triggers( self._hass, self._action[CONF_WAIT_FOR_TRIGGER], @@ -1023,24 +1034,31 @@ class _ScriptRun: ) if not remove_triggers: return - self._changed() - tasks = [ - self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) - ] + await self._async_wait_with_optional_timeout( + futures, timeout_handle, timeout_future, remove_triggers + ) + + async def _async_wait_with_optional_timeout( + self, + futures: list[asyncio.Future[None]], + timeout_handle: asyncio.TimerHandle | None, + timeout_future: asyncio.Future[None] | None, + unsub: Callable[[], None], + ) -> None: try: - async with asyncio.timeout(timeout) as to_context: - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - except TimeoutError as ex: - self._variables["wait"]["remaining"] = 0.0 - if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): - self._log(_TIMEOUT_MSG) - trace_set_result(wait=self._variables["wait"], timeout=True) - raise _AbortScript from ex + await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED) + if timeout_future and timeout_future.done(): + self._variables["wait"]["remaining"] = 0.0 + if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): + self._log(_TIMEOUT_MSG) + trace_set_result(wait=self._variables["wait"], timeout=True) + raise _AbortScript from TimeoutError() finally: - for task in tasks: - task.cancel() - remove_triggers() + if timeout_future and not timeout_future.done() and timeout_handle: + timeout_handle.cancel() + + unsub() async def _async_variables_step(self): """Set a variable value.""" @@ -1107,7 +1125,7 @@ class _ScriptRun: """Execute a script.""" result = await self._async_run_long_action( self._hass.async_create_task( - script.async_run(self._variables, self._context) + script.async_run(self._variables, self._context), eager_start=True ) ) if result and result.conversation_response is not UNDEFINED: @@ -1123,29 +1141,17 @@ class _QueuedScriptRun(_ScriptRun): """Run script.""" # Wait for previous run, if any, to finish by attempting to acquire the script's # shared lock. At the same time monitor if we've been told to stop. - lock_task = self._hass.async_create_task( - self._script._queue_lck.acquire() # pylint: disable=protected-access - ) - stop_task = self._hass.async_create_task(self._stop.wait()) try: - await asyncio.wait( - {lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED - ) - except asyncio.CancelledError: + async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None): + await self._script._queue_lck.acquire() # pylint: disable=protected-access + except ScriptStoppedError as ex: + # If we've been told to stop, then just finish up. self._finish() - raise - else: - self.lock_acquired = lock_task.done() and not lock_task.cancelled() - finally: - lock_task.cancel() - stop_task.cancel() + raise asyncio.CancelledError from ex - # If we've been told to stop, then just finish up. Otherwise, we've acquired the - # lock so we can go ahead and start the run. - if self._stop.is_set(): - self._finish() - else: - await super().async_run() + self.lock_acquired = True + # We've acquired the lock so we can go ahead and start the run. + await super().async_run() def _finish(self) -> None: if self.lock_acquired: