Replace fire_coroutine_threadsafe with asyncio.run_coroutine_threadsafe (#88572)

fire_coroutine_threadsafe did not hold a reference to the asyncio
task which meant the task had the risk of being prematurely
garbage collected
pull/87823/head^2
J. Nick Koston 2023-02-21 20:16:18 -06:00 committed by GitHub
parent e54eb7e2c8
commit 5bc0636905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 70 deletions

View File

@ -14,6 +14,7 @@ from collections.abc import (
Iterable,
Mapping,
)
import concurrent.futures
from contextlib import suppress
from contextvars import ContextVar
import datetime
@ -79,11 +80,7 @@ from .exceptions import (
)
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
from .util import dt as dt_util, location, ulid as ulid_util
from .util.async_ import (
fire_coroutine_threadsafe,
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager
from .util.unit_system import (
@ -294,6 +291,7 @@ class HomeAssistant:
self._stopped: asyncio.Event | None = None
# Timeout handler for Core/Helper namespace
self.timeout: TimeoutManager = TimeoutManager()
self._stop_future: concurrent.futures.Future[None] | None = None
@property
def is_running(self) -> bool:
@ -312,12 +310,14 @@ class HomeAssistant:
For regular use, use "await hass.run()".
"""
# Register the async start
fire_coroutine_threadsafe(self.async_start(), self.loop)
_future = asyncio.run_coroutine_threadsafe(self.async_start(), self.loop)
# Run forever
# Block until stopped
_LOGGER.info("Starting Home Assistant core loop")
self.loop.run_forever()
# The future is never retrieved but we still hold a reference to it
# to prevent the task from being garbage collected prematurely.
del _future
return self.exit_code
async def async_run(self, *, attach_signals: bool = True) -> int:
@ -682,7 +682,11 @@ class HomeAssistant:
"""Stop Home Assistant and shuts down all threads."""
if self.state == CoreState.not_running: # just ignore
return
fire_coroutine_threadsafe(self.async_stop(), self.loop)
# The future is never retrieved, and we only hold a reference
# to it to prevent it from being garbage collected.
self._stop_future = asyncio.run_coroutine_threadsafe(
self.async_stop(), self.loop
)
async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
"""Stop Home Assistant and shuts down all threads.

View File

@ -1,9 +1,9 @@
"""Asyncio utilities."""
from __future__ import annotations
from asyncio import Semaphore, coroutines, ensure_future, gather, get_running_loop
from asyncio import Semaphore, gather, get_running_loop
from asyncio.events import AbstractEventLoop
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable
import concurrent.futures
import functools
import logging
@ -20,29 +20,6 @@ _R = TypeVar("_R")
_P = ParamSpec("_P")
def fire_coroutine_threadsafe(
coro: Coroutine[Any, Any, Any], loop: AbstractEventLoop
) -> None:
"""Submit a coroutine object to a given event loop.
This method does not provide a way to retrieve the result and
is intended for fire-and-forget use. This reduces the
work involved to fire the function on the loop.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError("Cannot be called from within the event loop")
if not coroutines.iscoroutine(coro):
raise TypeError(f"A coroutine object is required: {coro}")
def callback() -> None:
"""Handle the firing of a coroutine."""
ensure_future(coro, loop=loop)
loop.call_soon_threadsafe(callback)
def run_callback_threadsafe(
loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any
) -> concurrent.futures.Future[_T]:

View File

@ -10,43 +10,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.util import async_ as hasync
@patch("asyncio.coroutines.iscoroutine")
@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_fire_coroutine_threadsafe_from_inside_event_loop(
mock_ident, _, mock_iscoroutine
) -> None:
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
with pytest.raises(RuntimeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
mock_iscoroutine.return_value = False
with pytest.raises(TypeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None: