Refactor async_get_hass to rely on threading.local instead of a ContextVar (#96005)
* Test for async_get_hass * Add Fixpull/96119/head
parent
372687fe81
commit
18ee9f4725
|
@ -16,7 +16,6 @@ from collections.abc import (
|
|||
)
|
||||
import concurrent.futures
|
||||
from contextlib import suppress
|
||||
from contextvars import ContextVar
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
|
@ -155,8 +154,6 @@ MAX_EXPECTED_ENTITY_IDS = 16384
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_cv_hass: ContextVar[HomeAssistant] = ContextVar("hass")
|
||||
|
||||
|
||||
@functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
|
||||
def split_entity_id(entity_id: str) -> tuple[str, str]:
|
||||
|
@ -199,16 +196,27 @@ def is_callback(func: Callable[..., Any]) -> bool:
|
|||
return getattr(func, "_hass_callback", False) is True
|
||||
|
||||
|
||||
class _Hass(threading.local):
|
||||
"""Container which makes a HomeAssistant instance available to the event loop."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
|
||||
|
||||
_hass = _Hass()
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_hass() -> HomeAssistant:
|
||||
"""Return the HomeAssistant instance.
|
||||
|
||||
Raises LookupError if no HomeAssistant instance is available.
|
||||
Raises HomeAssistantError when called from the wrong thread.
|
||||
|
||||
This should be used where it's very cumbersome or downright impossible to pass
|
||||
hass to the code which needs it.
|
||||
"""
|
||||
return _cv_hass.get()
|
||||
if not _hass.hass:
|
||||
raise HomeAssistantError("async_get_hass called from the wrong thread")
|
||||
return _hass.hass
|
||||
|
||||
|
||||
@enum.unique
|
||||
|
@ -292,9 +300,9 @@ class HomeAssistant:
|
|||
config_entries: ConfigEntries = None # type: ignore[assignment]
|
||||
|
||||
def __new__(cls) -> HomeAssistant:
|
||||
"""Set the _cv_hass context variable."""
|
||||
"""Set the _hass thread local data."""
|
||||
hass = super().__new__(cls)
|
||||
_cv_hass.set(hass)
|
||||
_hass.hass = hass
|
||||
return hass
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
|
|
@ -93,7 +93,7 @@ from homeassistant.core import (
|
|||
split_entity_id,
|
||||
valid_entity_id,
|
||||
)
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.generated import currencies
|
||||
from homeassistant.generated.countries import COUNTRIES
|
||||
from homeassistant.generated.languages import LANGUAGES
|
||||
|
@ -609,7 +609,7 @@ def template(value: Any | None) -> template_helper.Template:
|
|||
raise vol.Invalid("template value should be a string")
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
with contextlib.suppress(LookupError):
|
||||
with contextlib.suppress(HomeAssistantError):
|
||||
hass = async_get_hass()
|
||||
|
||||
template_value = template_helper.Template(str(value), hass)
|
||||
|
@ -631,7 +631,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
|
|||
raise vol.Invalid("template value does not contain a dynamic template")
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
with contextlib.suppress(LookupError):
|
||||
with contextlib.suppress(HomeAssistantError):
|
||||
hass = async_get_hass()
|
||||
|
||||
template_value = template_helper.Template(str(value), hass)
|
||||
|
@ -1098,7 +1098,7 @@ def _no_yaml_config_schema(
|
|||
# pylint: disable-next=import-outside-toplevel
|
||||
from .issue_registry import IssueSeverity, async_create_issue
|
||||
|
||||
with contextlib.suppress(LookupError):
|
||||
with contextlib.suppress(HomeAssistantError):
|
||||
hass = async_get_hass()
|
||||
async_create_issue(
|
||||
hass,
|
||||
|
|
|
@ -490,17 +490,7 @@ def hass_fixture_setup() -> list[bool]:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def hass(_hass: HomeAssistant) -> HomeAssistant:
|
||||
"""Fixture to provide a test instance of Home Assistant."""
|
||||
# This wraps the async _hass fixture inside a sync fixture, to ensure
|
||||
# the `hass` context variable is set in the execution context in which
|
||||
# the test itself is executed
|
||||
ha._cv_hass.set(_hass)
|
||||
return _hass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def _hass(
|
||||
async def hass(
|
||||
hass_fixture_setup: list[bool],
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
load_registries: bool,
|
||||
|
|
|
@ -12,6 +12,7 @@ import voluptuous as vol
|
|||
|
||||
import homeassistant
|
||||
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
issue_registry as ir,
|
||||
|
@ -383,7 +384,7 @@ def test_service() -> None:
|
|||
schema("homeassistant.turn_on")
|
||||
|
||||
|
||||
def test_service_schema() -> None:
|
||||
def test_service_schema(hass: HomeAssistant) -> None:
|
||||
"""Test service_schema validation."""
|
||||
options = (
|
||||
{},
|
||||
|
@ -1550,10 +1551,10 @@ def test_config_entry_only_schema_cant_find_module() -> None:
|
|||
def test_config_entry_only_schema_no_hass(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test if the the hass context var is not set in our context."""
|
||||
"""Test if the the hass context is not set in our context."""
|
||||
with patch(
|
||||
"homeassistant.helpers.config_validation.async_get_hass",
|
||||
side_effect=LookupError,
|
||||
side_effect=HomeAssistantError,
|
||||
):
|
||||
cv.config_entry_only_config_schema("test_domain")(
|
||||
{"test_domain": {"foo": "bar"}}
|
||||
|
|
|
@ -9,10 +9,12 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
from tempfile import TemporaryDirectory
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import async_timeout
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -40,6 +42,7 @@ from homeassistant.core import (
|
|||
ServiceResponse,
|
||||
State,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
|
@ -202,6 +205,184 @@ def test_async_run_hass_job_delegates_non_async() -> None:
|
|||
assert len(hass.async_add_hass_job.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_async_get_hass_can_be_called(hass: HomeAssistant) -> None:
|
||||
"""Test calling async_get_hass via different paths.
|
||||
|
||||
The test asserts async_get_hass can be called from:
|
||||
- Coroutines and callbacks
|
||||
- Callbacks scheduled from callbacks, coroutines and threads
|
||||
- Coroutines scheduled from callbacks, coroutines and threads
|
||||
|
||||
The test also asserts async_get_hass can not be called from threads
|
||||
other than the event loop.
|
||||
"""
|
||||
task_finished = asyncio.Event()
|
||||
|
||||
def can_call_async_get_hass() -> bool:
|
||||
"""Test if it's possible to call async_get_hass."""
|
||||
try:
|
||||
if ha.async_get_hass() is hass:
|
||||
return True
|
||||
raise Exception
|
||||
except HomeAssistantError:
|
||||
return False
|
||||
|
||||
raise Exception
|
||||
|
||||
# Test scheduling a coroutine which calls async_get_hass via hass.async_create_task
|
||||
async def _async_create_task() -> None:
|
||||
task_finished.set()
|
||||
assert can_call_async_get_hass()
|
||||
|
||||
hass.async_create_task(_async_create_task(), "create_task")
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a callback which calls async_get_hass via hass.async_add_job
|
||||
@callback
|
||||
def _add_job() -> None:
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
hass.async_add_job(_add_job)
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a callback which calls async_get_hass from a callback
|
||||
@callback
|
||||
def _schedule_callback_from_callback() -> None:
|
||||
@callback
|
||||
def _callback():
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the scheduled callback itself can call async_get_hass
|
||||
assert can_call_async_get_hass()
|
||||
hass.async_add_job(_callback)
|
||||
|
||||
_schedule_callback_from_callback()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a coroutine which calls async_get_hass from a callback
|
||||
@callback
|
||||
def _schedule_coroutine_from_callback() -> None:
|
||||
async def _coroutine():
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the scheduled callback itself can call async_get_hass
|
||||
assert can_call_async_get_hass()
|
||||
hass.async_add_job(_coroutine())
|
||||
|
||||
_schedule_coroutine_from_callback()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a callback which calls async_get_hass from a coroutine
|
||||
async def _schedule_callback_from_coroutine() -> None:
|
||||
@callback
|
||||
def _callback():
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the coroutine itself can call async_get_hass
|
||||
assert can_call_async_get_hass()
|
||||
hass.async_add_job(_callback)
|
||||
|
||||
await _schedule_callback_from_coroutine()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a coroutine which calls async_get_hass from a coroutine
|
||||
async def _schedule_callback_from_coroutine() -> None:
|
||||
async def _coroutine():
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the coroutine itself can call async_get_hass
|
||||
assert can_call_async_get_hass()
|
||||
await hass.async_create_task(_coroutine())
|
||||
|
||||
await _schedule_callback_from_coroutine()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a callback which calls async_get_hass from an executor
|
||||
def _async_add_executor_job_add_job() -> None:
|
||||
@callback
|
||||
def _async_add_job():
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the executor itself can not call async_get_hass
|
||||
assert not can_call_async_get_hass()
|
||||
hass.add_job(_async_add_job)
|
||||
|
||||
await hass.async_add_executor_job(_async_add_executor_job_add_job)
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a coroutine which calls async_get_hass from an executor
|
||||
def _async_add_executor_job_create_task() -> None:
|
||||
async def _async_create_task() -> None:
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
# Test the executor itself can not call async_get_hass
|
||||
assert not can_call_async_get_hass()
|
||||
hass.create_task(_async_create_task())
|
||||
|
||||
await hass.async_add_executor_job(_async_add_executor_job_create_task)
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
|
||||
# Test scheduling a callback which calls async_get_hass from a worker thread
|
||||
class MyJobAddJob(threading.Thread):
|
||||
@callback
|
||||
def _my_threaded_job_add_job(self) -> None:
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
def run(self) -> None:
|
||||
# Test the worker thread itself can not call async_get_hass
|
||||
assert not can_call_async_get_hass()
|
||||
hass.add_job(self._my_threaded_job_add_job)
|
||||
|
||||
my_job_add_job = MyJobAddJob()
|
||||
my_job_add_job.start()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
my_job_add_job.join()
|
||||
|
||||
# Test scheduling a coroutine which calls async_get_hass from a worker thread
|
||||
class MyJobCreateTask(threading.Thread):
|
||||
async def _my_threaded_job_create_task(self) -> None:
|
||||
assert can_call_async_get_hass()
|
||||
task_finished.set()
|
||||
|
||||
def run(self) -> None:
|
||||
# Test the worker thread itself can not call async_get_hass
|
||||
assert not can_call_async_get_hass()
|
||||
hass.create_task(self._my_threaded_job_create_task())
|
||||
|
||||
my_job_create_task = MyJobCreateTask()
|
||||
my_job_create_task.start()
|
||||
async with async_timeout.timeout(1):
|
||||
await task_finished.wait()
|
||||
task_finished.clear()
|
||||
my_job_create_task.join()
|
||||
|
||||
|
||||
async def test_stage_shutdown(hass: HomeAssistant) -> None:
|
||||
"""Simulate a shutdown, test calling stuff."""
|
||||
test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
|
||||
|
|
Loading…
Reference in New Issue