diff --git a/homeassistant/core.py b/homeassistant/core.py index dbc8769bb6f..82ea7228157 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -16,7 +16,6 @@ from collections.abc import ( ) import concurrent.futures from contextlib import suppress -from contextvars import ContextVar import datetime import enum import functools @@ -155,8 +154,6 @@ MAX_EXPECTED_ENTITY_IDS = 16384 _LOGGER = logging.getLogger(__name__) -_cv_hass: ContextVar[HomeAssistant] = ContextVar("hass") - @functools.lru_cache(MAX_EXPECTED_ENTITY_IDS) def split_entity_id(entity_id: str) -> tuple[str, str]: @@ -199,16 +196,27 @@ def is_callback(func: Callable[..., Any]) -> bool: return getattr(func, "_hass_callback", False) is True +class _Hass(threading.local): + """Container which makes a HomeAssistant instance available to the event loop.""" + + hass: HomeAssistant | None = None + + +_hass = _Hass() + + @callback def async_get_hass() -> HomeAssistant: """Return the HomeAssistant instance. - Raises LookupError if no HomeAssistant instance is available. + Raises HomeAssistantError when called from the wrong thread. This should be used where it's very cumbersome or downright impossible to pass hass to the code which needs it. """ - return _cv_hass.get() + if not _hass.hass: + raise HomeAssistantError("async_get_hass called from the wrong thread") + return _hass.hass @enum.unique @@ -292,9 +300,9 @@ class HomeAssistant: config_entries: ConfigEntries = None # type: ignore[assignment] def __new__(cls) -> HomeAssistant: - """Set the _cv_hass context variable.""" + """Set the _hass thread local data.""" hass = super().__new__(cls) - _cv_hass.set(hass) + _hass.hass = hass return hass def __init__(self) -> None: diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index cea8a866f5c..e8f1e58615c 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -93,7 +93,7 @@ from homeassistant.core import ( split_entity_id, valid_entity_id, ) -from homeassistant.exceptions import TemplateError +from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.generated import currencies from homeassistant.generated.countries import COUNTRIES from homeassistant.generated.languages import LANGUAGES @@ -609,7 +609,7 @@ def template(value: Any | None) -> template_helper.Template: raise vol.Invalid("template value should be a string") hass: HomeAssistant | None = None - with contextlib.suppress(LookupError): + with contextlib.suppress(HomeAssistantError): hass = async_get_hass() template_value = template_helper.Template(str(value), hass) @@ -631,7 +631,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template: raise vol.Invalid("template value does not contain a dynamic template") hass: HomeAssistant | None = None - with contextlib.suppress(LookupError): + with contextlib.suppress(HomeAssistantError): hass = async_get_hass() template_value = template_helper.Template(str(value), hass) @@ -1098,7 +1098,7 @@ def _no_yaml_config_schema( # pylint: disable-next=import-outside-toplevel from .issue_registry import IssueSeverity, async_create_issue - with contextlib.suppress(LookupError): + with contextlib.suppress(HomeAssistantError): hass = async_get_hass() async_create_issue( hass, diff --git a/tests/conftest.py b/tests/conftest.py index 56014d7a556..922e42c7a7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -490,17 +490,7 @@ def hass_fixture_setup() -> list[bool]: @pytest.fixture -def hass(_hass: HomeAssistant) -> HomeAssistant: - """Fixture to provide a test instance of Home Assistant.""" - # This wraps the async _hass fixture inside a sync fixture, to ensure - # the `hass` context variable is set in the execution context in which - # the test itself is executed - ha._cv_hass.set(_hass) - return _hass - - -@pytest.fixture -async def _hass( +async def hass( hass_fixture_setup: list[bool], event_loop: asyncio.AbstractEventLoop, load_registries: bool, diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 458774b748c..5ea6df42349 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -12,6 +12,7 @@ import voluptuous as vol import homeassistant from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( config_validation as cv, issue_registry as ir, @@ -383,7 +384,7 @@ def test_service() -> None: schema("homeassistant.turn_on") -def test_service_schema() -> None: +def test_service_schema(hass: HomeAssistant) -> None: """Test service_schema validation.""" options = ( {}, @@ -1550,10 +1551,10 @@ def test_config_entry_only_schema_cant_find_module() -> None: def test_config_entry_only_schema_no_hass( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: - """Test if the the hass context var is not set in our context.""" + """Test if the the hass context is not set in our context.""" with patch( "homeassistant.helpers.config_validation.async_get_hass", - side_effect=LookupError, + side_effect=HomeAssistantError, ): cv.config_entry_only_config_schema("test_domain")( {"test_domain": {"foo": "bar"}} diff --git a/tests/test_core.py b/tests/test_core.py index 8b63eab7b42..7e0766c8ac5 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,10 +9,12 @@ import gc import logging import os from tempfile import TemporaryDirectory +import threading import time from typing import Any from unittest.mock import MagicMock, Mock, PropertyMock, patch +import async_timeout import pytest import voluptuous as vol @@ -40,6 +42,7 @@ from homeassistant.core import ( ServiceResponse, State, SupportsResponse, + callback, ) from homeassistant.exceptions import ( HomeAssistantError, @@ -202,6 +205,184 @@ def test_async_run_hass_job_delegates_non_async() -> None: assert len(hass.async_add_hass_job.mock_calls) == 1 +async def test_async_get_hass_can_be_called(hass: HomeAssistant) -> None: + """Test calling async_get_hass via different paths. + + The test asserts async_get_hass can be called from: + - Coroutines and callbacks + - Callbacks scheduled from callbacks, coroutines and threads + - Coroutines scheduled from callbacks, coroutines and threads + + The test also asserts async_get_hass can not be called from threads + other than the event loop. + """ + task_finished = asyncio.Event() + + def can_call_async_get_hass() -> bool: + """Test if it's possible to call async_get_hass.""" + try: + if ha.async_get_hass() is hass: + return True + raise Exception + except HomeAssistantError: + return False + + raise Exception + + # Test scheduling a coroutine which calls async_get_hass via hass.async_create_task + async def _async_create_task() -> None: + task_finished.set() + assert can_call_async_get_hass() + + hass.async_create_task(_async_create_task(), "create_task") + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a callback which calls async_get_hass via hass.async_add_job + @callback + def _add_job() -> None: + assert can_call_async_get_hass() + task_finished.set() + + hass.async_add_job(_add_job) + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a callback which calls async_get_hass from a callback + @callback + def _schedule_callback_from_callback() -> None: + @callback + def _callback(): + assert can_call_async_get_hass() + task_finished.set() + + # Test the scheduled callback itself can call async_get_hass + assert can_call_async_get_hass() + hass.async_add_job(_callback) + + _schedule_callback_from_callback() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a coroutine which calls async_get_hass from a callback + @callback + def _schedule_coroutine_from_callback() -> None: + async def _coroutine(): + assert can_call_async_get_hass() + task_finished.set() + + # Test the scheduled callback itself can call async_get_hass + assert can_call_async_get_hass() + hass.async_add_job(_coroutine()) + + _schedule_coroutine_from_callback() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a callback which calls async_get_hass from a coroutine + async def _schedule_callback_from_coroutine() -> None: + @callback + def _callback(): + assert can_call_async_get_hass() + task_finished.set() + + # Test the coroutine itself can call async_get_hass + assert can_call_async_get_hass() + hass.async_add_job(_callback) + + await _schedule_callback_from_coroutine() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a coroutine which calls async_get_hass from a coroutine + async def _schedule_callback_from_coroutine() -> None: + async def _coroutine(): + assert can_call_async_get_hass() + task_finished.set() + + # Test the coroutine itself can call async_get_hass + assert can_call_async_get_hass() + await hass.async_create_task(_coroutine()) + + await _schedule_callback_from_coroutine() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a callback which calls async_get_hass from an executor + def _async_add_executor_job_add_job() -> None: + @callback + def _async_add_job(): + assert can_call_async_get_hass() + task_finished.set() + + # Test the executor itself can not call async_get_hass + assert not can_call_async_get_hass() + hass.add_job(_async_add_job) + + await hass.async_add_executor_job(_async_add_executor_job_add_job) + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a coroutine which calls async_get_hass from an executor + def _async_add_executor_job_create_task() -> None: + async def _async_create_task() -> None: + assert can_call_async_get_hass() + task_finished.set() + + # Test the executor itself can not call async_get_hass + assert not can_call_async_get_hass() + hass.create_task(_async_create_task()) + + await hass.async_add_executor_job(_async_add_executor_job_create_task) + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + + # Test scheduling a callback which calls async_get_hass from a worker thread + class MyJobAddJob(threading.Thread): + @callback + def _my_threaded_job_add_job(self) -> None: + assert can_call_async_get_hass() + task_finished.set() + + def run(self) -> None: + # Test the worker thread itself can not call async_get_hass + assert not can_call_async_get_hass() + hass.add_job(self._my_threaded_job_add_job) + + my_job_add_job = MyJobAddJob() + my_job_add_job.start() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + my_job_add_job.join() + + # Test scheduling a coroutine which calls async_get_hass from a worker thread + class MyJobCreateTask(threading.Thread): + async def _my_threaded_job_create_task(self) -> None: + assert can_call_async_get_hass() + task_finished.set() + + def run(self) -> None: + # Test the worker thread itself can not call async_get_hass + assert not can_call_async_get_hass() + hass.create_task(self._my_threaded_job_create_task()) + + my_job_create_task = MyJobCreateTask() + my_job_create_task.start() + async with async_timeout.timeout(1): + await task_finished.wait() + task_finished.clear() + my_job_create_task.join() + + async def test_stage_shutdown(hass: HomeAssistant) -> None: """Simulate a shutdown, test calling stuff.""" test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)