Schedule tasks eagerly when called from hass.add_job (#113014)

pull/113042/head
J. Nick Koston 2024-03-10 21:19:49 -10:00 committed by GitHub
parent cede16fc40
commit 3387892f59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 25 deletions

View File

@ -524,30 +524,43 @@ class HomeAssistant:
if target is None:
raise ValueError("Don't call add_job with None")
if asyncio.iscoroutine(target):
self.loop.call_soon_threadsafe(self.async_add_job, target)
self.loop.call_soon_threadsafe(
functools.partial(self.async_add_job, target, eager_start=True)
)
return
if TYPE_CHECKING:
target = cast(Callable[..., Any], target)
self.loop.call_soon_threadsafe(self.async_add_job, target, *args)
self.loop.call_soon_threadsafe(
functools.partial(self.async_add_job, target, *args, eager_start=True)
)
@overload
@callback
def async_add_job(
self, target: Callable[..., Coroutine[Any, Any, _R]], *args: Any
self,
target: Callable[..., Coroutine[Any, Any, _R]],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None:
...
@overload
@callback
def async_add_job(
self, target: Callable[..., Coroutine[Any, Any, _R] | _R], *args: Any
self,
target: Callable[..., Coroutine[Any, Any, _R] | _R],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None:
...
@overload
@callback
def async_add_job(
self, target: Coroutine[Any, Any, _R], *args: Any
self,
target: Coroutine[Any, Any, _R],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None:
...
@ -556,6 +569,7 @@ class HomeAssistant:
self,
target: Callable[..., Coroutine[Any, Any, _R] | _R] | Coroutine[Any, Any, _R],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None:
"""Add a job to be executed by the event loop or by an executor.
@ -571,7 +585,7 @@ class HomeAssistant:
raise ValueError("Don't call async_add_job with None")
if asyncio.iscoroutine(target):
return self.async_create_task(target)
return self.async_create_task(target, eager_start=eager_start)
# This code path is performance sensitive and uses
# if TYPE_CHECKING to avoid the overhead of constructing
@ -579,7 +593,7 @@ class HomeAssistant:
# https://github.com/home-assistant/core/pull/71960
if TYPE_CHECKING:
target = cast(Callable[..., Coroutine[Any, Any, _R] | _R], target)
return self.async_add_hass_job(HassJob(target), *args)
return self.async_add_hass_job(HassJob(target), *args, eager_start=eager_start)
@overload
@callback

View File

@ -234,7 +234,7 @@ async def async_test_home_assistant(
orig_async_create_task = hass.async_create_task
orig_tz = dt_util.DEFAULT_TIME_ZONE
def async_add_job(target, *args):
def async_add_job(target, *args, eager_start: bool = False):
"""Add job."""
check_target = target
while isinstance(check_target, ft.partial):
@ -245,7 +245,7 @@ async def async_test_home_assistant(
fut.set_result(target(*args))
return fut
return orig_async_add_job(target, *args)
return orig_async_add_job(target, *args, eager_start=eager_start)
def async_add_executor_job(target, *args):
"""Add executor job."""

View File

@ -736,6 +736,20 @@ async def test_pending_scheduler(hass: HomeAssistant) -> None:
assert len(call_count) == 3
def test_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
"""Add a coro to pending tasks."""
async def test_coro():
"""Test Coro."""
pass
for _ in range(2):
hass.add_job(test_coro())
# Ensure add_job does not run immediately
assert len(hass._tasks) == 0
async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
"""Add a coro to pending tasks."""
call_count = []
@ -745,18 +759,12 @@ async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
call_count.append("call")
for _ in range(2):
hass.add_job(test_coro())
async def wait_finish_callback():
"""Wait until all stuff is scheduled."""
await asyncio.sleep(0)
await asyncio.sleep(0)
await wait_finish_callback()
hass.async_add_job(test_coro())
assert len(hass._tasks) == 2
await hass.async_block_till_done()
assert len(call_count) == 2
assert len(hass._tasks) == 0
async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None:
@ -768,18 +776,12 @@ async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None
call_count.append("call")
for _ in range(2):
hass.create_task(test_coro())
async def wait_finish_callback():
"""Wait until all stuff is scheduled."""
await asyncio.sleep(0)
await asyncio.sleep(0)
await wait_finish_callback()
hass.async_create_task(test_coro())
assert len(hass._tasks) == 2
await hass.async_block_till_done()
assert len(call_count) == 2
assert len(hass._tasks) == 0
async def test_async_add_job_pending_tasks_executor(hass: HomeAssistant) -> None: