Reduce script overhead by avoiding creation of many tasks (#113183)
* Reduce script overhead by avoiding creation of many tasks * no eager stop * reduce * make sure wait being cancelled is handled * make sure wait being cancelled is handled * make sure wait being cancelled is handled * preen * preen * result already raises cancelled error, remove redundant code * no need to raise it into the future * will never set an exception * Simplify long action script implementation * comment * preen * dry * dry * preen * dry * preen * no need to access protected * no need to access protected * dry * name * dry * dry * dry * dry * reduce name changes * drop one more task * stale comment * stale commentpull/113190/head
parent
e293afe46e
commit
09934d44c4
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
|
@ -15,6 +15,7 @@ import logging
|
|||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast
|
||||
|
||||
import async_interrupt
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
|
@ -157,6 +158,16 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
|||
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||
|
||||
|
||||
class ScriptStoppedError(Exception):
|
||||
"""Error to indicate that the script has been stopped."""
|
||||
|
||||
|
||||
def _set_result_unless_done(future: asyncio.Future[None]) -> None:
|
||||
"""Set result of future unless it is done."""
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
|
||||
|
||||
def action_trace_append(variables, path):
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
trace_element = TraceElement(variables, path)
|
||||
|
@ -168,7 +179,7 @@ def action_trace_append(variables, path):
|
|||
async def trace_action(
|
||||
hass: HomeAssistant,
|
||||
script_run: _ScriptRun,
|
||||
stop: asyncio.Event,
|
||||
stop: asyncio.Future[None],
|
||||
variables: dict[str, Any],
|
||||
) -> AsyncGenerator[TraceElement, None]:
|
||||
"""Trace action execution."""
|
||||
|
@ -199,13 +210,13 @@ async def trace_action(
|
|||
):
|
||||
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)
|
||||
|
||||
done = asyncio.Event()
|
||||
done = hass.loop.create_future()
|
||||
|
||||
@callback
|
||||
def async_continue_stop(command=None):
|
||||
if command == "stop":
|
||||
stop.set()
|
||||
done.set()
|
||||
_set_result_unless_done(stop)
|
||||
_set_result_unless_done(done)
|
||||
|
||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
|
||||
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
|
||||
|
@ -213,10 +224,7 @@ async def trace_action(
|
|||
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
|
||||
)
|
||||
|
||||
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.wait([stop, done], return_when=asyncio.FIRST_COMPLETED)
|
||||
remove_signal1()
|
||||
remove_signal2()
|
||||
|
||||
|
@ -393,12 +401,12 @@ class _ScriptRun:
|
|||
self._log_exceptions = log_exceptions
|
||||
self._step = -1
|
||||
self._started = False
|
||||
self._stop = asyncio.Event()
|
||||
self._stop = hass.loop.create_future()
|
||||
self._stopped = asyncio.Event()
|
||||
self._conversation_response: str | None | UndefinedType = UNDEFINED
|
||||
|
||||
def _changed(self) -> None:
|
||||
if not self._stop.is_set():
|
||||
if not self._stop.done():
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
|
@ -432,7 +440,7 @@ class _ScriptRun:
|
|||
try:
|
||||
self._log("Running %s", self._script.running_description)
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
script_execution_set("cancelled")
|
||||
break
|
||||
await self._async_step(log_exceptions=False)
|
||||
|
@ -471,7 +479,7 @@ class _ScriptRun:
|
|||
async with trace_action(
|
||||
self._hass, self, self._stop, self._variables
|
||||
) as trace_element:
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
return
|
||||
|
||||
action = cv.determine_script_action(self._action)
|
||||
|
@ -483,8 +491,8 @@ class _ScriptRun:
|
|||
trace_set_result(enabled=False)
|
||||
return
|
||||
|
||||
handler = f"_async_{action}_step"
|
||||
try:
|
||||
handler = f"_async_{action}_step"
|
||||
await getattr(self, handler)()
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
self._handle_exception(
|
||||
|
@ -502,7 +510,7 @@ class _ScriptRun:
|
|||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
self._stop.set()
|
||||
_set_result_unless_done(self._stop)
|
||||
# If the script was never started
|
||||
# the stopped event will never be
|
||||
# set because the script will never
|
||||
|
@ -576,9 +584,9 @@ class _ScriptRun:
|
|||
level=level,
|
||||
)
|
||||
|
||||
def _get_pos_time_period_template(self, key):
|
||||
def _get_pos_time_period_template(self, key: str) -> timedelta:
|
||||
try:
|
||||
return cv.positive_time_period(
|
||||
return cv.positive_time_period( # type: ignore[no-any-return]
|
||||
template.render_complex(self._action[key], self._variables)
|
||||
)
|
||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||
|
@ -593,26 +601,34 @@ class _ScriptRun:
|
|||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
delay = self._get_pos_time_period_template(CONF_DELAY)
|
||||
delay_delta = self._get_pos_time_period_template(CONF_DELAY)
|
||||
|
||||
self._step_log(f"delay {delay}")
|
||||
self._step_log(f"delay {delay_delta}")
|
||||
|
||||
delay = delay.total_seconds()
|
||||
delay = delay_delta.total_seconds()
|
||||
self._changed()
|
||||
trace_set_result(delay=delay, done=False)
|
||||
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||
delay
|
||||
)
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(delay):
|
||||
await self._stop.wait()
|
||||
except TimeoutError:
|
||||
trace_set_result(delay=delay, done=True)
|
||||
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
|
||||
finally:
|
||||
if timeout_future.done():
|
||||
trace_set_result(delay=delay, done=True)
|
||||
else:
|
||||
timeout_handle.cancel()
|
||||
|
||||
def _get_timeout_seconds_from_action(self) -> float | None:
|
||||
"""Get the timeout from the action."""
|
||||
if CONF_TIMEOUT in self._action:
|
||||
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
||||
return None
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
if CONF_TIMEOUT in self._action:
|
||||
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
||||
else:
|
||||
timeout = None
|
||||
|
||||
timeout = self._get_timeout_seconds_from_action()
|
||||
self._step_log("wait template", timeout)
|
||||
|
||||
self._variables["wait"] = {"remaining": timeout, "completed": False}
|
||||
|
@ -626,74 +642,47 @@ class _ScriptRun:
|
|||
self._variables["wait"]["completed"] = True
|
||||
return
|
||||
|
||||
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||
timeout
|
||||
)
|
||||
done = self._hass.loop.create_future()
|
||||
futures.append(done)
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
# pylint: disable=protected-access
|
||||
wait_var = self._variables["wait"]
|
||||
if to_context and to_context._when:
|
||||
wait_var["remaining"] = to_context._when - self._hass.loop.time()
|
||||
else:
|
||||
wait_var["remaining"] = timeout
|
||||
wait_var["completed"] = True
|
||||
done.set()
|
||||
self._async_set_remaining_time_var(timeout_handle)
|
||||
self._variables["wait"]["completed"] = True
|
||||
_set_result_unless_done(done)
|
||||
|
||||
to_context = None
|
||||
unsub = async_track_template(
|
||||
self._hass, wait_template, async_script_wait, self._variables
|
||||
)
|
||||
|
||||
self._changed()
|
||||
done = asyncio.Event()
|
||||
tasks = [
|
||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
||||
]
|
||||
try:
|
||||
async with asyncio.timeout(timeout) as to_context:
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
except TimeoutError as ex:
|
||||
self._variables["wait"]["remaining"] = 0.0
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
trace_set_result(wait=self._variables["wait"], timeout=True)
|
||||
raise _AbortScript from ex
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
unsub()
|
||||
await self._async_wait_with_optional_timeout(
|
||||
futures, timeout_handle, timeout_future, unsub
|
||||
)
|
||||
|
||||
def _async_set_remaining_time_var(
|
||||
self, timeout_handle: asyncio.TimerHandle | None
|
||||
) -> None:
|
||||
"""Set the remaining time variable for a wait step."""
|
||||
wait_var = self._variables["wait"]
|
||||
if timeout_handle:
|
||||
wait_var["remaining"] = timeout_handle.when() - self._hass.loop.time()
|
||||
else:
|
||||
wait_var["remaining"] = None
|
||||
|
||||
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task() -> None:
|
||||
# Stop long task and wait for it to finish.
|
||||
long_task.cancel()
|
||||
with suppress(Exception):
|
||||
await long_task
|
||||
|
||||
# Wait for long task while monitoring for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_long_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
|
||||
if long_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if long_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
return long_task.result()
|
||||
# Stopped before long task completed, so cancel it.
|
||||
await async_cancel_long_task()
|
||||
return None
|
||||
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
||||
# if stop is set, interrupt will cancel inside the context
|
||||
# manager which will cancel long_task, and raise
|
||||
# ScriptStoppedError outside the context manager
|
||||
return await long_task
|
||||
except ScriptStoppedError as ex:
|
||||
raise asyncio.CancelledError from ex
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
|
@ -735,8 +724,9 @@ class _ScriptRun:
|
|||
blocking=True,
|
||||
context=self._context,
|
||||
return_response=return_response,
|
||||
)
|
||||
),
|
||||
),
|
||||
eager_start=True,
|
||||
)
|
||||
)
|
||||
if response_variable:
|
||||
self._variables[response_variable] = response_data
|
||||
|
@ -866,7 +856,7 @@ class _ScriptRun:
|
|||
for iteration in range(1, count + 1):
|
||||
set_repeat_var(iteration, count)
|
||||
await async_run_sequence(iteration, extra_msg)
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
break
|
||||
|
||||
elif CONF_FOR_EACH in repeat:
|
||||
|
@ -894,7 +884,7 @@ class _ScriptRun:
|
|||
for iteration, item in enumerate(items, 1):
|
||||
set_repeat_var(iteration, count, item)
|
||||
extra_msg = f" of {count} with item: {repr(item)}"
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
break
|
||||
await async_run_sequence(iteration, extra_msg)
|
||||
|
||||
|
@ -905,7 +895,7 @@ class _ScriptRun:
|
|||
for iteration in itertools.count(1):
|
||||
set_repeat_var(iteration)
|
||||
try:
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
break
|
||||
if not self._test_conditions(conditions, "while"):
|
||||
break
|
||||
|
@ -923,7 +913,7 @@ class _ScriptRun:
|
|||
set_repeat_var(iteration)
|
||||
await async_run_sequence(iteration)
|
||||
try:
|
||||
if self._stop.is_set():
|
||||
if self._stop.done():
|
||||
break
|
||||
if self._test_conditions(conditions, "until") in [True, None]:
|
||||
break
|
||||
|
@ -983,12 +973,35 @@ class _ScriptRun:
|
|||
with trace_path("else"):
|
||||
await self._async_run_script(if_data["if_else"])
|
||||
|
||||
def _async_futures_with_timeout(
|
||||
self,
|
||||
timeout: float | None,
|
||||
) -> tuple[
|
||||
list[asyncio.Future[None]],
|
||||
asyncio.TimerHandle | None,
|
||||
asyncio.Future[None] | None,
|
||||
]:
|
||||
"""Return a list of futures to wait for.
|
||||
|
||||
The list will contain the stop future.
|
||||
|
||||
If timeout is set, a timeout future and handle will be created
|
||||
and will be added to the list of futures.
|
||||
"""
|
||||
timeout_handle: asyncio.TimerHandle | None = None
|
||||
timeout_future: asyncio.Future[None] | None = None
|
||||
futures: list[asyncio.Future[None]] = [self._stop]
|
||||
if timeout:
|
||||
timeout_future = self._hass.loop.create_future()
|
||||
timeout_handle = self._hass.loop.call_later(
|
||||
timeout, _set_result_unless_done, timeout_future
|
||||
)
|
||||
futures.append(timeout_future)
|
||||
return futures, timeout_handle, timeout_future
|
||||
|
||||
async def _async_wait_for_trigger_step(self):
|
||||
"""Wait for a trigger event."""
|
||||
if CONF_TIMEOUT in self._action:
|
||||
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
||||
else:
|
||||
timeout = None
|
||||
timeout = self._get_timeout_seconds_from_action()
|
||||
|
||||
self._step_log("wait for trigger", timeout)
|
||||
|
||||
|
@ -996,22 +1009,20 @@ class _ScriptRun:
|
|||
self._variables["wait"] = {"remaining": timeout, "trigger": None}
|
||||
trace_set_result(wait=self._variables["wait"])
|
||||
|
||||
done = asyncio.Event()
|
||||
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
|
||||
timeout
|
||||
)
|
||||
done = self._hass.loop.create_future()
|
||||
futures.append(done)
|
||||
|
||||
async def async_done(variables, context=None):
|
||||
# pylint: disable=protected-access
|
||||
wait_var = self._variables["wait"]
|
||||
if to_context and to_context._when:
|
||||
wait_var["remaining"] = to_context._when - self._hass.loop.time()
|
||||
else:
|
||||
wait_var["remaining"] = timeout
|
||||
wait_var["trigger"] = variables["trigger"]
|
||||
done.set()
|
||||
self._async_set_remaining_time_var(timeout_handle)
|
||||
self._variables["wait"]["trigger"] = variables["trigger"]
|
||||
_set_result_unless_done(done)
|
||||
|
||||
def log_cb(level, msg, **kwargs):
|
||||
self._log(msg, level=level, **kwargs)
|
||||
|
||||
to_context = None
|
||||
remove_triggers = await async_initialize_triggers(
|
||||
self._hass,
|
||||
self._action[CONF_WAIT_FOR_TRIGGER],
|
||||
|
@ -1023,24 +1034,31 @@ class _ScriptRun:
|
|||
)
|
||||
if not remove_triggers:
|
||||
return
|
||||
|
||||
self._changed()
|
||||
tasks = [
|
||||
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
|
||||
]
|
||||
await self._async_wait_with_optional_timeout(
|
||||
futures, timeout_handle, timeout_future, remove_triggers
|
||||
)
|
||||
|
||||
async def _async_wait_with_optional_timeout(
|
||||
self,
|
||||
futures: list[asyncio.Future[None]],
|
||||
timeout_handle: asyncio.TimerHandle | None,
|
||||
timeout_future: asyncio.Future[None] | None,
|
||||
unsub: Callable[[], None],
|
||||
) -> None:
|
||||
try:
|
||||
async with asyncio.timeout(timeout) as to_context:
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
except TimeoutError as ex:
|
||||
self._variables["wait"]["remaining"] = 0.0
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
trace_set_result(wait=self._variables["wait"], timeout=True)
|
||||
raise _AbortScript from ex
|
||||
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
|
||||
if timeout_future and timeout_future.done():
|
||||
self._variables["wait"]["remaining"] = 0.0
|
||||
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
trace_set_result(wait=self._variables["wait"], timeout=True)
|
||||
raise _AbortScript from TimeoutError()
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
remove_triggers()
|
||||
if timeout_future and not timeout_future.done() and timeout_handle:
|
||||
timeout_handle.cancel()
|
||||
|
||||
unsub()
|
||||
|
||||
async def _async_variables_step(self):
|
||||
"""Set a variable value."""
|
||||
|
@ -1107,7 +1125,7 @@ class _ScriptRun:
|
|||
"""Execute a script."""
|
||||
result = await self._async_run_long_action(
|
||||
self._hass.async_create_task(
|
||||
script.async_run(self._variables, self._context)
|
||||
script.async_run(self._variables, self._context), eager_start=True
|
||||
)
|
||||
)
|
||||
if result and result.conversation_response is not UNDEFINED:
|
||||
|
@ -1123,29 +1141,17 @@ class _QueuedScriptRun(_ScriptRun):
|
|||
"""Run script."""
|
||||
# Wait for previous run, if any, to finish by attempting to acquire the script's
|
||||
# shared lock. At the same time monitor if we've been told to stop.
|
||||
lock_task = self._hass.async_create_task(
|
||||
self._script._queue_lck.acquire() # pylint: disable=protected-access
|
||||
)
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
||||
await self._script._queue_lck.acquire() # pylint: disable=protected-access
|
||||
except ScriptStoppedError as ex:
|
||||
# If we've been told to stop, then just finish up.
|
||||
self._finish()
|
||||
raise
|
||||
else:
|
||||
self.lock_acquired = lock_task.done() and not lock_task.cancelled()
|
||||
finally:
|
||||
lock_task.cancel()
|
||||
stop_task.cancel()
|
||||
raise asyncio.CancelledError from ex
|
||||
|
||||
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
|
||||
# lock so we can go ahead and start the run.
|
||||
if self._stop.is_set():
|
||||
self._finish()
|
||||
else:
|
||||
await super().async_run()
|
||||
self.lock_acquired = True
|
||||
# We've acquired the lock so we can go ahead and start the run.
|
||||
await super().async_run()
|
||||
|
||||
def _finish(self) -> None:
|
||||
if self.lock_acquired:
|
||||
|
|
Loading…
Reference in New Issue