Reduce script overhead by avoiding creation of many tasks (#113183)

* Reduce script overhead by avoiding creation of many tasks

* no eager stop

* reduce

* make sure wait being cancelled is handled

* make sure wait being cancelled is handled

* make sure wait being cancelled is handled

* preen

* preen

* result already raises cancelled error, remove redundant code

* no need to raise it into the future

* will never set an exception

* Simplify long action script implementation

* comment

* preen

* dry

* dry

* preen

* dry

* preen

* no need to access protected

* no need to access protected

* dry

* name

* dry

* dry

* dry

* dry

* reduce name changes

* drop one more task

* stale comment

* stale comment
pull/113190/head
J. Nick Koston 2024-03-14 14:28:27 -10:00 committed by GitHub
parent e293afe46e
commit 09934d44c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 147 additions and 141 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress
from contextlib import asynccontextmanager
from contextvars import ContextVar
from copy import copy
from dataclasses import dataclass
@ -15,6 +15,7 @@ import logging
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast
import async_interrupt
import voluptuous as vol
from homeassistant import exceptions
@ -157,6 +158,16 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
class ScriptStoppedError(Exception):
"""Error to indicate that the script has been stopped."""
def _set_result_unless_done(future: asyncio.Future[None]) -> None:
"""Set result of future unless it is done."""
if not future.done():
future.set_result(None)
def action_trace_append(variables, path):
"""Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables, path)
@ -168,7 +179,7 @@ def action_trace_append(variables, path):
async def trace_action(
hass: HomeAssistant,
script_run: _ScriptRun,
stop: asyncio.Event,
stop: asyncio.Future[None],
variables: dict[str, Any],
) -> AsyncGenerator[TraceElement, None]:
"""Trace action execution."""
@ -199,13 +210,13 @@ async def trace_action(
):
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, key, run_id, path)
done = asyncio.Event()
done = hass.loop.create_future()
@callback
def async_continue_stop(command=None):
if command == "stop":
stop.set()
done.set()
_set_result_unless_done(stop)
_set_result_unless_done(done)
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(key, run_id)
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
@ -213,10 +224,7 @@ async def trace_action(
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
)
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in tasks:
task.cancel()
await asyncio.wait([stop, done], return_when=asyncio.FIRST_COMPLETED)
remove_signal1()
remove_signal2()
@ -393,12 +401,12 @@ class _ScriptRun:
self._log_exceptions = log_exceptions
self._step = -1
self._started = False
self._stop = asyncio.Event()
self._stop = hass.loop.create_future()
self._stopped = asyncio.Event()
self._conversation_response: str | None | UndefinedType = UNDEFINED
def _changed(self) -> None:
if not self._stop.is_set():
if not self._stop.done():
self._script._changed() # pylint: disable=protected-access
async def _async_get_condition(self, config):
@ -432,7 +440,7 @@ class _ScriptRun:
try:
self._log("Running %s", self._script.running_description)
for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set():
if self._stop.done():
script_execution_set("cancelled")
break
await self._async_step(log_exceptions=False)
@ -471,7 +479,7 @@ class _ScriptRun:
async with trace_action(
self._hass, self, self._stop, self._variables
) as trace_element:
if self._stop.is_set():
if self._stop.done():
return
action = cv.determine_script_action(self._action)
@ -483,8 +491,8 @@ class _ScriptRun:
trace_set_result(enabled=False)
return
handler = f"_async_{action}_step"
try:
handler = f"_async_{action}_step"
await getattr(self, handler)()
except Exception as ex: # pylint: disable=broad-except
self._handle_exception(
@ -502,7 +510,7 @@ class _ScriptRun:
async def async_stop(self) -> None:
"""Stop script run."""
self._stop.set()
_set_result_unless_done(self._stop)
# If the script was never started
# the stopped event will never be
# set because the script will never
@ -576,9 +584,9 @@ class _ScriptRun:
level=level,
)
def _get_pos_time_period_template(self, key):
def _get_pos_time_period_template(self, key: str) -> timedelta:
try:
return cv.positive_time_period(
return cv.positive_time_period( # type: ignore[no-any-return]
template.render_complex(self._action[key], self._variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
@ -593,26 +601,34 @@ class _ScriptRun:
async def _async_delay_step(self):
"""Handle delay."""
delay = self._get_pos_time_period_template(CONF_DELAY)
delay_delta = self._get_pos_time_period_template(CONF_DELAY)
self._step_log(f"delay {delay}")
self._step_log(f"delay {delay_delta}")
delay = delay.total_seconds()
delay = delay_delta.total_seconds()
self._changed()
trace_set_result(delay=delay, done=False)
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
delay
)
try:
async with asyncio.timeout(delay):
await self._stop.wait()
except TimeoutError:
trace_set_result(delay=delay, done=True)
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
finally:
if timeout_future.done():
trace_set_result(delay=delay, done=True)
else:
timeout_handle.cancel()
def _get_timeout_seconds_from_action(self) -> float | None:
"""Get the timeout from the action."""
if CONF_TIMEOUT in self._action:
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
return None
async def _async_wait_template_step(self):
"""Handle a wait template."""
if CONF_TIMEOUT in self._action:
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
timeout = None
timeout = self._get_timeout_seconds_from_action()
self._step_log("wait template", timeout)
self._variables["wait"] = {"remaining": timeout, "completed": False}
@ -626,74 +642,47 @@ class _ScriptRun:
self._variables["wait"]["completed"] = True
return
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
timeout
)
done = self._hass.loop.create_future()
futures.append(done)
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
# pylint: disable=protected-access
wait_var = self._variables["wait"]
if to_context and to_context._when:
wait_var["remaining"] = to_context._when - self._hass.loop.time()
else:
wait_var["remaining"] = timeout
wait_var["completed"] = True
done.set()
self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["completed"] = True
_set_result_unless_done(done)
to_context = None
unsub = async_track_template(
self._hass, wait_template, async_script_wait, self._variables
)
self._changed()
done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with asyncio.timeout(timeout) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except TimeoutError as ex:
self._variables["wait"]["remaining"] = 0.0
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
trace_set_result(wait=self._variables["wait"], timeout=True)
raise _AbortScript from ex
finally:
for task in tasks:
task.cancel()
unsub()
await self._async_wait_with_optional_timeout(
futures, timeout_handle, timeout_future, unsub
)
def _async_set_remaining_time_var(
self, timeout_handle: asyncio.TimerHandle | None
) -> None:
"""Set the remaining time variable for a wait step."""
wait_var = self._variables["wait"]
if timeout_handle:
wait_var["remaining"] = timeout_handle.when() - self._hass.loop.time()
else:
wait_var["remaining"] = None
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task() -> None:
# Stop long task and wait for it to finish.
long_task.cancel()
with suppress(Exception):
await long_task
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
return long_task.result()
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
return None
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
# if stop is set, interrupt will cancel inside the context
# manager which will cancel long_task, and raise
# ScriptStoppedError outside the context manager
return await long_task
except ScriptStoppedError as ex:
raise asyncio.CancelledError from ex
async def _async_call_service_step(self):
"""Call the service specified in the action."""
@ -735,8 +724,9 @@ class _ScriptRun:
blocking=True,
context=self._context,
return_response=return_response,
)
),
),
eager_start=True,
)
)
if response_variable:
self._variables[response_variable] = response_data
@ -866,7 +856,7 @@ class _ScriptRun:
for iteration in range(1, count + 1):
set_repeat_var(iteration, count)
await async_run_sequence(iteration, extra_msg)
if self._stop.is_set():
if self._stop.done():
break
elif CONF_FOR_EACH in repeat:
@ -894,7 +884,7 @@ class _ScriptRun:
for iteration, item in enumerate(items, 1):
set_repeat_var(iteration, count, item)
extra_msg = f" of {count} with item: {repr(item)}"
if self._stop.is_set():
if self._stop.done():
break
await async_run_sequence(iteration, extra_msg)
@ -905,7 +895,7 @@ class _ScriptRun:
for iteration in itertools.count(1):
set_repeat_var(iteration)
try:
if self._stop.is_set():
if self._stop.done():
break
if not self._test_conditions(conditions, "while"):
break
@ -923,7 +913,7 @@ class _ScriptRun:
set_repeat_var(iteration)
await async_run_sequence(iteration)
try:
if self._stop.is_set():
if self._stop.done():
break
if self._test_conditions(conditions, "until") in [True, None]:
break
@ -983,12 +973,35 @@ class _ScriptRun:
with trace_path("else"):
await self._async_run_script(if_data["if_else"])
def _async_futures_with_timeout(
self,
timeout: float | None,
) -> tuple[
list[asyncio.Future[None]],
asyncio.TimerHandle | None,
asyncio.Future[None] | None,
]:
"""Return a list of futures to wait for.
The list will contain the stop future.
If timeout is set, a timeout future and handle will be created
and will be added to the list of futures.
"""
timeout_handle: asyncio.TimerHandle | None = None
timeout_future: asyncio.Future[None] | None = None
futures: list[asyncio.Future[None]] = [self._stop]
if timeout:
timeout_future = self._hass.loop.create_future()
timeout_handle = self._hass.loop.call_later(
timeout, _set_result_unless_done, timeout_future
)
futures.append(timeout_future)
return futures, timeout_handle, timeout_future
async def _async_wait_for_trigger_step(self):
"""Wait for a trigger event."""
if CONF_TIMEOUT in self._action:
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
timeout = None
timeout = self._get_timeout_seconds_from_action()
self._step_log("wait for trigger", timeout)
@ -996,22 +1009,20 @@ class _ScriptRun:
self._variables["wait"] = {"remaining": timeout, "trigger": None}
trace_set_result(wait=self._variables["wait"])
done = asyncio.Event()
futures, timeout_handle, timeout_future = self._async_futures_with_timeout(
timeout
)
done = self._hass.loop.create_future()
futures.append(done)
async def async_done(variables, context=None):
# pylint: disable=protected-access
wait_var = self._variables["wait"]
if to_context and to_context._when:
wait_var["remaining"] = to_context._when - self._hass.loop.time()
else:
wait_var["remaining"] = timeout
wait_var["trigger"] = variables["trigger"]
done.set()
self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["trigger"] = variables["trigger"]
_set_result_unless_done(done)
def log_cb(level, msg, **kwargs):
self._log(msg, level=level, **kwargs)
to_context = None
remove_triggers = await async_initialize_triggers(
self._hass,
self._action[CONF_WAIT_FOR_TRIGGER],
@ -1023,24 +1034,31 @@ class _ScriptRun:
)
if not remove_triggers:
return
self._changed()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
await self._async_wait_with_optional_timeout(
futures, timeout_handle, timeout_future, remove_triggers
)
async def _async_wait_with_optional_timeout(
self,
futures: list[asyncio.Future[None]],
timeout_handle: asyncio.TimerHandle | None,
timeout_future: asyncio.Future[None] | None,
unsub: Callable[[], None],
) -> None:
try:
async with asyncio.timeout(timeout) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except TimeoutError as ex:
self._variables["wait"]["remaining"] = 0.0
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
trace_set_result(wait=self._variables["wait"], timeout=True)
raise _AbortScript from ex
await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
if timeout_future and timeout_future.done():
self._variables["wait"]["remaining"] = 0.0
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
trace_set_result(wait=self._variables["wait"], timeout=True)
raise _AbortScript from TimeoutError()
finally:
for task in tasks:
task.cancel()
remove_triggers()
if timeout_future and not timeout_future.done() and timeout_handle:
timeout_handle.cancel()
unsub()
async def _async_variables_step(self):
"""Set a variable value."""
@ -1107,7 +1125,7 @@ class _ScriptRun:
"""Execute a script."""
result = await self._async_run_long_action(
self._hass.async_create_task(
script.async_run(self._variables, self._context)
script.async_run(self._variables, self._context), eager_start=True
)
)
if result and result.conversation_response is not UNDEFINED:
@ -1123,29 +1141,17 @@ class _QueuedScriptRun(_ScriptRun):
"""Run script."""
# Wait for previous run, if any, to finish by attempting to acquire the script's
# shared lock. At the same time monitor if we've been told to stop.
lock_task = self._hass.async_create_task(
self._script._queue_lck.acquire() # pylint: disable=protected-access
)
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
await self._script._queue_lck.acquire() # pylint: disable=protected-access
except ScriptStoppedError as ex:
# If we've been told to stop, then just finish up.
self._finish()
raise
else:
self.lock_acquired = lock_task.done() and not lock_task.cancelled()
finally:
lock_task.cancel()
stop_task.cancel()
raise asyncio.CancelledError from ex
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
# lock so we can go ahead and start the run.
if self._stop.is_set():
self._finish()
else:
await super().async_run()
self.lock_acquired = True
# We've acquired the lock so we can go ahead and start the run.
await super().async_run()
def _finish(self) -> None:
if self.lock_acquired: