1781 lines
		
	
	
		
			56 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			1781 lines
		
	
	
		
			56 KiB
		
	
	
	
		
			Python
		
	
	
"""Test the helper method for writing tests."""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import asyncio
 | 
						|
from collections.abc import (
 | 
						|
    AsyncGenerator,
 | 
						|
    Callable,
 | 
						|
    Coroutine,
 | 
						|
    Generator,
 | 
						|
    Mapping,
 | 
						|
    Sequence,
 | 
						|
)
 | 
						|
from contextlib import asynccontextmanager, contextmanager
 | 
						|
from datetime import UTC, datetime, timedelta
 | 
						|
from enum import Enum
 | 
						|
import functools as ft
 | 
						|
from functools import lru_cache
 | 
						|
from io import StringIO
 | 
						|
import json
 | 
						|
import logging
 | 
						|
import os
 | 
						|
import pathlib
 | 
						|
import threading
 | 
						|
import time
 | 
						|
from types import FrameType, ModuleType
 | 
						|
from typing import Any, Literal, NoReturn
 | 
						|
from unittest.mock import AsyncMock, Mock, patch
 | 
						|
 | 
						|
from aiohttp.test_utils import unused_port as get_test_instance_port  # noqa: F401
 | 
						|
import pytest
 | 
						|
from syrupy import SnapshotAssertion
 | 
						|
import voluptuous as vol
 | 
						|
 | 
						|
from homeassistant import auth, bootstrap, config_entries, loader
 | 
						|
from homeassistant.auth import (
 | 
						|
    auth_store,
 | 
						|
    models as auth_models,
 | 
						|
    permissions as auth_permissions,
 | 
						|
    providers as auth_providers,
 | 
						|
)
 | 
						|
from homeassistant.auth.permissions import system_policies
 | 
						|
from homeassistant.components import device_automation, persistent_notification as pn
 | 
						|
from homeassistant.components.device_automation import (  # noqa: F401
 | 
						|
    _async_get_device_automation_capabilities as async_get_device_automation_capabilities,
 | 
						|
)
 | 
						|
from homeassistant.config import async_process_component_config
 | 
						|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
 | 
						|
from homeassistant.const import (
 | 
						|
    DEVICE_DEFAULT_NAME,
 | 
						|
    EVENT_HOMEASSISTANT_CLOSE,
 | 
						|
    EVENT_HOMEASSISTANT_STOP,
 | 
						|
    EVENT_STATE_CHANGED,
 | 
						|
    STATE_OFF,
 | 
						|
    STATE_ON,
 | 
						|
)
 | 
						|
from homeassistant.core import (
 | 
						|
    CoreState,
 | 
						|
    Event,
 | 
						|
    HomeAssistant,
 | 
						|
    ServiceCall,
 | 
						|
    ServiceResponse,
 | 
						|
    State,
 | 
						|
    SupportsResponse,
 | 
						|
    callback,
 | 
						|
)
 | 
						|
from homeassistant.helpers import (
 | 
						|
    area_registry as ar,
 | 
						|
    category_registry as cr,
 | 
						|
    device_registry as dr,
 | 
						|
    entity,
 | 
						|
    entity_platform,
 | 
						|
    entity_registry as er,
 | 
						|
    event,
 | 
						|
    floor_registry as fr,
 | 
						|
    intent,
 | 
						|
    issue_registry as ir,
 | 
						|
    label_registry as lr,
 | 
						|
    restore_state as rs,
 | 
						|
    storage,
 | 
						|
    translation,
 | 
						|
)
 | 
						|
from homeassistant.helpers.dispatcher import (
 | 
						|
    async_dispatcher_connect,
 | 
						|
    async_dispatcher_send,
 | 
						|
)
 | 
						|
from homeassistant.helpers.entity import Entity
 | 
						|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
 | 
						|
from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, json_dumps
 | 
						|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
 | 
						|
from homeassistant.util.async_ import run_callback_threadsafe
 | 
						|
import homeassistant.util.dt as dt_util
 | 
						|
from homeassistant.util.json import (
 | 
						|
    JsonArrayType,
 | 
						|
    JsonObjectType,
 | 
						|
    JsonValueType,
 | 
						|
    json_loads,
 | 
						|
    json_loads_array,
 | 
						|
    json_loads_object,
 | 
						|
)
 | 
						|
from homeassistant.util.signal_type import SignalType
 | 
						|
import homeassistant.util.ulid as ulid_util
 | 
						|
from homeassistant.util.unit_system import METRIC_SYSTEM
 | 
						|
import homeassistant.util.yaml.loader as yaml_loader
 | 
						|
 | 
						|
from .testing_config.custom_components.test_constant_deprecation import (
 | 
						|
    import_deprecated_constant,
 | 
						|
)
 | 
						|
 | 
						|
_LOGGER = logging.getLogger(__name__)
 | 
						|
INSTANCES = []
 | 
						|
CLIENT_ID = "https://example.com/app"
 | 
						|
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
 | 
						|
 | 
						|
 | 
						|
async def async_get_device_automations(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    automation_type: device_automation.DeviceAutomationType,
 | 
						|
    device_id: str,
 | 
						|
) -> Any:
 | 
						|
    """Get a device automation for a single device id."""
 | 
						|
    automations = await device_automation.async_get_device_automations(
 | 
						|
        hass, automation_type, [device_id]
 | 
						|
    )
 | 
						|
    return automations.get(device_id)
 | 
						|
 | 
						|
 | 
						|
def threadsafe_callback_factory(func):
 | 
						|
    """Create threadsafe functions out of callbacks.
 | 
						|
 | 
						|
    Callback needs to have `hass` as first argument.
 | 
						|
    """
 | 
						|
 | 
						|
    @ft.wraps(func)
 | 
						|
    def threadsafe(*args, **kwargs):
 | 
						|
        """Call func threadsafe."""
 | 
						|
        hass = args[0]
 | 
						|
        return run_callback_threadsafe(
 | 
						|
            hass.loop, ft.partial(func, *args, **kwargs)
 | 
						|
        ).result()
 | 
						|
 | 
						|
    return threadsafe
 | 
						|
 | 
						|
 | 
						|
def threadsafe_coroutine_factory(func):
 | 
						|
    """Create threadsafe functions out of coroutine.
 | 
						|
 | 
						|
    Callback needs to have `hass` as first argument.
 | 
						|
    """
 | 
						|
 | 
						|
    @ft.wraps(func)
 | 
						|
    def threadsafe(*args, **kwargs):
 | 
						|
        """Call func threadsafe."""
 | 
						|
        hass = args[0]
 | 
						|
        return asyncio.run_coroutine_threadsafe(
 | 
						|
            func(*args, **kwargs), hass.loop
 | 
						|
        ).result()
 | 
						|
 | 
						|
    return threadsafe
 | 
						|
 | 
						|
 | 
						|
def get_test_config_dir(*add_path):
 | 
						|
    """Return a path to a test config dir."""
 | 
						|
    return os.path.join(os.path.dirname(__file__), "testing_config", *add_path)
 | 
						|
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def get_test_home_assistant() -> Generator[HomeAssistant]:
 | 
						|
    """Return a Home Assistant object pointing at test config directory."""
 | 
						|
    loop = asyncio.new_event_loop()
 | 
						|
    asyncio.set_event_loop(loop)
 | 
						|
    context_manager = async_test_home_assistant(loop)
 | 
						|
    hass = loop.run_until_complete(context_manager.__aenter__())
 | 
						|
 | 
						|
    loop_stop_event = threading.Event()
 | 
						|
 | 
						|
    def run_loop() -> None:
 | 
						|
        """Run event loop."""
 | 
						|
 | 
						|
        loop._thread_ident = threading.get_ident()
 | 
						|
        hass.loop_thread_id = loop._thread_ident
 | 
						|
        loop.run_forever()
 | 
						|
        loop_stop_event.set()
 | 
						|
 | 
						|
    orig_stop = hass.stop
 | 
						|
    hass._stopped = Mock(set=loop.stop)
 | 
						|
 | 
						|
    def start_hass(*mocks: Any) -> None:
 | 
						|
        """Start hass."""
 | 
						|
        asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result()
 | 
						|
 | 
						|
    def stop_hass() -> None:
 | 
						|
        """Stop hass."""
 | 
						|
        orig_stop()
 | 
						|
        loop_stop_event.wait()
 | 
						|
 | 
						|
    hass.start = start_hass
 | 
						|
    hass.stop = stop_hass
 | 
						|
 | 
						|
    threading.Thread(name="LoopThread", target=run_loop, daemon=False).start()
 | 
						|
 | 
						|
    try:
 | 
						|
        yield hass
 | 
						|
    finally:
 | 
						|
        loop.run_until_complete(context_manager.__aexit__(None, None, None))
 | 
						|
        loop.close()
 | 
						|
 | 
						|
 | 
						|
class StoreWithoutWriteLoad[_T: (Mapping[str, Any] | Sequence[Any])](storage.Store[_T]):
 | 
						|
    """Fake store that does not write or load. Used for testing."""
 | 
						|
 | 
						|
    async def async_save(self, *args: Any, **kwargs: Any) -> None:
 | 
						|
        """Save the data.
 | 
						|
 | 
						|
        This function is mocked out in tests.
 | 
						|
        """
 | 
						|
 | 
						|
    @callback
 | 
						|
    def async_save_delay(self, *args: Any, **kwargs: Any) -> None:
 | 
						|
        """Save data with an optional delay.
 | 
						|
 | 
						|
        This function is mocked out in tests.
 | 
						|
        """
 | 
						|
 | 
						|
 | 
						|
@asynccontextmanager
 | 
						|
async def async_test_home_assistant(
 | 
						|
    event_loop: asyncio.AbstractEventLoop | None = None,
 | 
						|
    load_registries: bool = True,
 | 
						|
    config_dir: str | None = None,
 | 
						|
) -> AsyncGenerator[HomeAssistant]:
 | 
						|
    """Return a Home Assistant object pointing at test config dir."""
 | 
						|
    hass = HomeAssistant(config_dir or get_test_config_dir())
 | 
						|
    store = auth_store.AuthStore(hass)
 | 
						|
    hass.auth = auth.AuthManager(hass, store, {}, {})
 | 
						|
    ensure_auth_manager_loaded(hass.auth)
 | 
						|
    INSTANCES.append(hass)
 | 
						|
 | 
						|
    orig_async_add_job = hass.async_add_job
 | 
						|
    orig_async_add_executor_job = hass.async_add_executor_job
 | 
						|
    orig_async_create_task_internal = hass.async_create_task_internal
 | 
						|
    orig_tz = dt_util.get_default_time_zone()
 | 
						|
 | 
						|
    def async_add_job(target, *args, eager_start: bool = False):
 | 
						|
        """Add job."""
 | 
						|
        check_target = target
 | 
						|
        while isinstance(check_target, ft.partial):
 | 
						|
            check_target = check_target.func
 | 
						|
 | 
						|
        if isinstance(check_target, Mock) and not isinstance(target, AsyncMock):
 | 
						|
            fut = asyncio.Future()
 | 
						|
            fut.set_result(target(*args))
 | 
						|
            return fut
 | 
						|
 | 
						|
        return orig_async_add_job(target, *args, eager_start=eager_start)
 | 
						|
 | 
						|
    def async_add_executor_job(target, *args):
 | 
						|
        """Add executor job."""
 | 
						|
        check_target = target
 | 
						|
        while isinstance(check_target, ft.partial):
 | 
						|
            check_target = check_target.func
 | 
						|
 | 
						|
        if isinstance(check_target, Mock):
 | 
						|
            fut = asyncio.Future()
 | 
						|
            fut.set_result(target(*args))
 | 
						|
            return fut
 | 
						|
 | 
						|
        return orig_async_add_executor_job(target, *args)
 | 
						|
 | 
						|
    def async_create_task_internal(coroutine, name=None, eager_start=True):
 | 
						|
        """Create task."""
 | 
						|
        if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
 | 
						|
            fut = asyncio.Future()
 | 
						|
            fut.set_result(None)
 | 
						|
            return fut
 | 
						|
 | 
						|
        return orig_async_create_task_internal(coroutine, name, eager_start)
 | 
						|
 | 
						|
    hass.async_add_job = async_add_job
 | 
						|
    hass.async_add_executor_job = async_add_executor_job
 | 
						|
    hass.async_create_task_internal = async_create_task_internal
 | 
						|
 | 
						|
    hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}
 | 
						|
 | 
						|
    hass.config.location_name = "test home"
 | 
						|
    hass.config.latitude = 32.87336
 | 
						|
    hass.config.longitude = -117.22743
 | 
						|
    hass.config.elevation = 0
 | 
						|
    await hass.config.async_set_time_zone("US/Pacific")
 | 
						|
    hass.config.units = METRIC_SYSTEM
 | 
						|
    hass.config.media_dirs = {"local": get_test_config_dir("media")}
 | 
						|
    hass.config.skip_pip = True
 | 
						|
    hass.config.skip_pip_packages = []
 | 
						|
 | 
						|
    hass.config_entries = config_entries.ConfigEntries(
 | 
						|
        hass,
 | 
						|
        {
 | 
						|
            "_": (
 | 
						|
                "Not empty or else some bad checks for hass config in discovery.py"
 | 
						|
                " breaks"
 | 
						|
            )
 | 
						|
        },
 | 
						|
    )
 | 
						|
    hass.bus.async_listen_once(
 | 
						|
        EVENT_HOMEASSISTANT_STOP,
 | 
						|
        hass.config_entries._async_shutdown,
 | 
						|
    )
 | 
						|
 | 
						|
    # Load the registries
 | 
						|
    entity.async_setup(hass)
 | 
						|
    loader.async_setup(hass)
 | 
						|
 | 
						|
    # setup translation cache instead of calling translation.async_setup(hass)
 | 
						|
    hass.data[translation.TRANSLATION_FLATTEN_CACHE] = translation._TranslationCache(
 | 
						|
        hass
 | 
						|
    )
 | 
						|
    if load_registries:
 | 
						|
        with (
 | 
						|
            patch.object(StoreWithoutWriteLoad, "async_load", return_value=None),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.area_registry.AreaRegistryStore",
 | 
						|
                StoreWithoutWriteLoad,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.device_registry.DeviceRegistryStore",
 | 
						|
                StoreWithoutWriteLoad,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.entity_registry.EntityRegistryStore",
 | 
						|
                StoreWithoutWriteLoad,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.storage.Store",  # Floor & label registry are different
 | 
						|
                StoreWithoutWriteLoad,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.issue_registry.IssueRegistryStore",
 | 
						|
                StoreWithoutWriteLoad,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump",
 | 
						|
                return_value=None,
 | 
						|
            ),
 | 
						|
            patch(
 | 
						|
                "homeassistant.helpers.restore_state.start.async_at_start",
 | 
						|
            ),
 | 
						|
        ):
 | 
						|
            await ar.async_load(hass)
 | 
						|
            await cr.async_load(hass)
 | 
						|
            await dr.async_load(hass)
 | 
						|
            await er.async_load(hass)
 | 
						|
            await fr.async_load(hass)
 | 
						|
            await ir.async_load(hass)
 | 
						|
            await lr.async_load(hass)
 | 
						|
            await rs.async_load(hass)
 | 
						|
        hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None
 | 
						|
 | 
						|
    hass.set_state(CoreState.running)
 | 
						|
 | 
						|
    @callback
 | 
						|
    def clear_instance(event):
 | 
						|
        """Clear global instance."""
 | 
						|
        # Give aiohttp one loop iteration to close
 | 
						|
        hass.loop.call_soon(INSTANCES.remove, hass)
 | 
						|
 | 
						|
    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance)
 | 
						|
 | 
						|
    try:
 | 
						|
        yield hass
 | 
						|
    finally:
 | 
						|
        # Restore timezone, it is set when creating the hass object
 | 
						|
        dt_util.set_default_time_zone(orig_tz)
 | 
						|
 | 
						|
 | 
						|
def async_mock_service(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    domain: str,
 | 
						|
    service: str,
 | 
						|
    schema: vol.Schema | None = None,
 | 
						|
    response: ServiceResponse = None,
 | 
						|
    supports_response: SupportsResponse | None = None,
 | 
						|
    raise_exception: Exception | None = None,
 | 
						|
) -> list[ServiceCall]:
 | 
						|
    """Set up a fake service & return a calls log list to this service."""
 | 
						|
    calls = []
 | 
						|
 | 
						|
    @callback
 | 
						|
    def mock_service_log(call):
 | 
						|
        """Mock service call."""
 | 
						|
        calls.append(call)
 | 
						|
        if raise_exception is not None:
 | 
						|
            raise raise_exception
 | 
						|
        return response
 | 
						|
 | 
						|
    if supports_response is None:
 | 
						|
        if response is not None:
 | 
						|
            supports_response = SupportsResponse.OPTIONAL
 | 
						|
        else:
 | 
						|
            supports_response = SupportsResponse.NONE
 | 
						|
 | 
						|
    hass.services.async_register(
 | 
						|
        domain,
 | 
						|
        service,
 | 
						|
        mock_service_log,
 | 
						|
        schema=schema,
 | 
						|
        supports_response=supports_response,
 | 
						|
    )
 | 
						|
 | 
						|
    return calls
 | 
						|
 | 
						|
 | 
						|
mock_service = threadsafe_callback_factory(async_mock_service)
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_mock_intent(hass, intent_typ):
 | 
						|
    """Set up a fake intent handler."""
 | 
						|
    intents = []
 | 
						|
 | 
						|
    class MockIntentHandler(intent.IntentHandler):
 | 
						|
        intent_type = intent_typ
 | 
						|
 | 
						|
        async def async_handle(self, intent_obj):
 | 
						|
            """Handle the intent."""
 | 
						|
            intents.append(intent_obj)
 | 
						|
            return intent_obj.create_response()
 | 
						|
 | 
						|
    intent.async_register(hass, MockIntentHandler())
 | 
						|
 | 
						|
    return intents
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_fire_mqtt_message(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    topic: str,
 | 
						|
    payload: bytes | str,
 | 
						|
    qos: int = 0,
 | 
						|
    retain: bool = False,
 | 
						|
) -> None:
 | 
						|
    """Fire the MQTT message."""
 | 
						|
    # Local import to avoid processing MQTT modules when running a testcase
 | 
						|
    # which does not use MQTT.
 | 
						|
 | 
						|
    # pylint: disable-next=import-outside-toplevel
 | 
						|
    from paho.mqtt.client import MQTTMessage
 | 
						|
 | 
						|
    # pylint: disable-next=import-outside-toplevel
 | 
						|
    from homeassistant.components.mqtt.models import MqttData
 | 
						|
 | 
						|
    if isinstance(payload, str):
 | 
						|
        payload = payload.encode("utf-8")
 | 
						|
 | 
						|
    msg = MQTTMessage(topic=topic.encode("utf-8"))
 | 
						|
    msg.payload = payload
 | 
						|
    msg.qos = qos
 | 
						|
    msg.retain = retain
 | 
						|
    msg.timestamp = time.monotonic()
 | 
						|
 | 
						|
    mqtt_data: MqttData = hass.data["mqtt"]
 | 
						|
    assert mqtt_data.client
 | 
						|
    mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)
 | 
						|
 | 
						|
 | 
						|
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_fire_time_changed_exact(
 | 
						|
    hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
 | 
						|
) -> None:
 | 
						|
    """Fire a time changed event at an exact microsecond.
 | 
						|
 | 
						|
    Consider that it is not possible to actually achieve an exact
 | 
						|
    microsecond in production as the event loop is not precise enough.
 | 
						|
    If your code relies on this level of precision, consider a different
 | 
						|
    approach, as this is only for testing.
 | 
						|
    """
 | 
						|
    if datetime_ is None:
 | 
						|
        utc_datetime = datetime.now(UTC)
 | 
						|
    else:
 | 
						|
        utc_datetime = dt_util.as_utc(datetime_)
 | 
						|
 | 
						|
    _async_fire_time_changed(hass, utc_datetime, fire_all)
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_fire_time_changed(
 | 
						|
    hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
 | 
						|
) -> None:
 | 
						|
    """Fire a time changed event.
 | 
						|
 | 
						|
    If called within the first 500  ms of a second, time will be bumped to exactly
 | 
						|
    500 ms to match the async_track_utc_time_change event listeners and
 | 
						|
    DataUpdateCoordinator which spreads all updates between 0.05..0.50.
 | 
						|
    Background in PR https://github.com/home-assistant/core/pull/82233
 | 
						|
 | 
						|
    As asyncio is cooperative, we can't guarantee that the event loop will
 | 
						|
    run an event at the exact time we want. If you need to fire time changed
 | 
						|
    for an exact microsecond, use async_fire_time_changed_exact.
 | 
						|
    """
 | 
						|
    if datetime_ is None:
 | 
						|
        utc_datetime = datetime.now(UTC)
 | 
						|
    else:
 | 
						|
        utc_datetime = dt_util.as_utc(datetime_)
 | 
						|
 | 
						|
    # Increase the mocked time by 0.5 s to account for up to 0.5 s delay
 | 
						|
    # added to events scheduled by update_coordinator and async_track_time_interval
 | 
						|
    utc_datetime += timedelta(microseconds=event.RANDOM_MICROSECOND_MAX)
 | 
						|
 | 
						|
    _async_fire_time_changed(hass, utc_datetime, fire_all)
 | 
						|
 | 
						|
 | 
						|
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def _async_fire_time_changed(
 | 
						|
    hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool
 | 
						|
) -> None:
 | 
						|
    timestamp = dt_util.utc_to_timestamp(utc_datetime)
 | 
						|
    for task in list(hass.loop._scheduled):
 | 
						|
        if not isinstance(task, asyncio.TimerHandle):
 | 
						|
            continue
 | 
						|
        if task.cancelled():
 | 
						|
            continue
 | 
						|
 | 
						|
        mock_seconds_into_future = timestamp - time.time()
 | 
						|
        future_seconds = task.when() - (hass.loop.time() + _MONOTONIC_RESOLUTION)
 | 
						|
 | 
						|
        if fire_all or mock_seconds_into_future >= future_seconds:
 | 
						|
            with (
 | 
						|
                patch(
 | 
						|
                    "homeassistant.helpers.event.time_tracker_utcnow",
 | 
						|
                    return_value=utc_datetime,
 | 
						|
                ),
 | 
						|
                patch(
 | 
						|
                    "homeassistant.helpers.event.time_tracker_timestamp",
 | 
						|
                    return_value=timestamp,
 | 
						|
                ),
 | 
						|
            ):
 | 
						|
                task._run()
 | 
						|
                task.cancel()
 | 
						|
 | 
						|
 | 
						|
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
 | 
						|
 | 
						|
 | 
						|
def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.Path:
 | 
						|
    """Get path of fixture."""
 | 
						|
    if integration is None and "/" in filename and not filename.startswith("helpers/"):
 | 
						|
        integration, filename = filename.split("/", 1)
 | 
						|
 | 
						|
    if integration is None:
 | 
						|
        return pathlib.Path(__file__).parent.joinpath("fixtures", filename)
 | 
						|
 | 
						|
    return pathlib.Path(__file__).parent.joinpath(
 | 
						|
        "components", integration, "fixtures", filename
 | 
						|
    )
 | 
						|
 | 
						|
 | 
						|
@lru_cache
 | 
						|
def load_fixture(filename: str, integration: str | None = None) -> str:
 | 
						|
    """Load a fixture."""
 | 
						|
    return get_fixture_path(filename, integration).read_text(encoding="utf8")
 | 
						|
 | 
						|
 | 
						|
def load_json_value_fixture(
 | 
						|
    filename: str, integration: str | None = None
 | 
						|
) -> JsonValueType:
 | 
						|
    """Load a JSON value from a fixture."""
 | 
						|
    return json_loads(load_fixture(filename, integration))
 | 
						|
 | 
						|
 | 
						|
def load_json_array_fixture(
 | 
						|
    filename: str, integration: str | None = None
 | 
						|
) -> JsonArrayType:
 | 
						|
    """Load a JSON array from a fixture."""
 | 
						|
    return json_loads_array(load_fixture(filename, integration))
 | 
						|
 | 
						|
 | 
						|
def load_json_object_fixture(
 | 
						|
    filename: str, integration: str | None = None
 | 
						|
) -> JsonObjectType:
 | 
						|
    """Load a JSON object from a fixture."""
 | 
						|
    return json_loads_object(load_fixture(filename, integration))
 | 
						|
 | 
						|
 | 
						|
def json_round_trip(obj: Any) -> Any:
 | 
						|
    """Round trip an object to JSON."""
 | 
						|
    return json_loads(json_dumps(obj))
 | 
						|
 | 
						|
 | 
						|
def mock_state_change_event(
 | 
						|
    hass: HomeAssistant, new_state: State, old_state: State | None = None
 | 
						|
) -> None:
 | 
						|
    """Mock state change event."""
 | 
						|
    event_data = {
 | 
						|
        "entity_id": new_state.entity_id,
 | 
						|
        "new_state": new_state,
 | 
						|
        "old_state": old_state,
 | 
						|
    }
 | 
						|
    hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def mock_component(hass: HomeAssistant, component: str) -> None:
 | 
						|
    """Mock a component is setup."""
 | 
						|
    if component in hass.config.components:
 | 
						|
        raise AssertionError(f"Integration {component} is already setup")
 | 
						|
 | 
						|
    hass.config.components.add(component)
 | 
						|
 | 
						|
 | 
						|
def mock_registry(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    mock_entries: dict[str, er.RegistryEntry] | None = None,
 | 
						|
) -> er.EntityRegistry:
 | 
						|
    """Mock the Entity Registry.
 | 
						|
 | 
						|
    This should only be used if you need to mock/re-stage a clean mocked
 | 
						|
    entity registry in your current hass object. It can be useful to,
 | 
						|
    for example, pre-load the registry with items.
 | 
						|
 | 
						|
    This mock will thus replace the existing registry in the running hass.
 | 
						|
 | 
						|
    If you just need to access the existing registry, use the `entity_registry`
 | 
						|
    fixture instead.
 | 
						|
    """
 | 
						|
    registry = er.EntityRegistry(hass)
 | 
						|
    if mock_entries is None:
 | 
						|
        mock_entries = {}
 | 
						|
    registry.deleted_entities = {}
 | 
						|
    registry.entities = er.EntityRegistryItems()
 | 
						|
    registry._entities_data = registry.entities.data
 | 
						|
    for key, entry in mock_entries.items():
 | 
						|
        registry.entities[key] = entry
 | 
						|
 | 
						|
    hass.data[er.DATA_REGISTRY] = registry
 | 
						|
    er.async_get.cache_clear()
 | 
						|
    return registry
 | 
						|
 | 
						|
 | 
						|
def mock_area_registry(
 | 
						|
    hass: HomeAssistant, mock_entries: dict[str, ar.AreaEntry] | None = None
 | 
						|
) -> ar.AreaRegistry:
 | 
						|
    """Mock the Area Registry.
 | 
						|
 | 
						|
    This should only be used if you need to mock/re-stage a clean mocked
 | 
						|
    area registry in your current hass object. It can be useful to,
 | 
						|
    for example, pre-load the registry with items.
 | 
						|
 | 
						|
    This mock will thus replace the existing registry in the running hass.
 | 
						|
 | 
						|
    If you just need to access the existing registry, use the `area_registry`
 | 
						|
    fixture instead.
 | 
						|
    """
 | 
						|
    registry = ar.AreaRegistry(hass)
 | 
						|
    registry.areas = ar.AreaRegistryItems()
 | 
						|
    for key, entry in mock_entries.items():
 | 
						|
        registry.areas[key] = entry
 | 
						|
 | 
						|
    hass.data[ar.DATA_REGISTRY] = registry
 | 
						|
    ar.async_get.cache_clear()
 | 
						|
    return registry
 | 
						|
 | 
						|
 | 
						|
def mock_device_registry(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    mock_entries: dict[str, dr.DeviceEntry] | None = None,
 | 
						|
) -> dr.DeviceRegistry:
 | 
						|
    """Mock the Device Registry.
 | 
						|
 | 
						|
    This should only be used if you need to mock/re-stage a clean mocked
 | 
						|
    device registry in your current hass object. It can be useful to,
 | 
						|
    for example, pre-load the registry with items.
 | 
						|
 | 
						|
    This mock will thus replace the existing registry in the running hass.
 | 
						|
 | 
						|
    If you just need to access the existing registry, use the `device_registry`
 | 
						|
    fixture instead.
 | 
						|
    """
 | 
						|
    registry = dr.DeviceRegistry(hass)
 | 
						|
    registry.devices = dr.ActiveDeviceRegistryItems()
 | 
						|
    registry._device_data = registry.devices.data
 | 
						|
    if mock_entries is None:
 | 
						|
        mock_entries = {}
 | 
						|
    for key, entry in mock_entries.items():
 | 
						|
        registry.devices[key] = entry
 | 
						|
    registry.deleted_devices = dr.DeviceRegistryItems()
 | 
						|
 | 
						|
    hass.data[dr.DATA_REGISTRY] = registry
 | 
						|
    dr.async_get.cache_clear()
 | 
						|
    return registry
 | 
						|
 | 
						|
 | 
						|
class MockGroup(auth_models.Group):
 | 
						|
    """Mock a group in Home Assistant."""
 | 
						|
 | 
						|
    def __init__(self, id: str | None = None, name: str | None = "Mock Group") -> None:
 | 
						|
        """Mock a group."""
 | 
						|
        kwargs = {"name": name, "policy": system_policies.ADMIN_POLICY}
 | 
						|
        if id is not None:
 | 
						|
            kwargs["id"] = id
 | 
						|
 | 
						|
        super().__init__(**kwargs)
 | 
						|
 | 
						|
    def add_to_hass(self, hass: HomeAssistant) -> MockGroup:
 | 
						|
        """Test helper to add entry to hass."""
 | 
						|
        return self.add_to_auth_manager(hass.auth)
 | 
						|
 | 
						|
    def add_to_auth_manager(self, auth_mgr: auth.AuthManager) -> MockGroup:
 | 
						|
        """Test helper to add entry to hass."""
 | 
						|
        ensure_auth_manager_loaded(auth_mgr)
 | 
						|
        auth_mgr._store._groups[self.id] = self
 | 
						|
        return self
 | 
						|
 | 
						|
 | 
						|
class MockUser(auth_models.User):
 | 
						|
    """Mock a user in Home Assistant."""
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        id: str | None = None,
 | 
						|
        is_owner: bool = False,
 | 
						|
        is_active: bool = True,
 | 
						|
        name: str | None = "Mock User",
 | 
						|
        system_generated: bool = False,
 | 
						|
        groups: list[auth_models.Group] | None = None,
 | 
						|
    ) -> None:
 | 
						|
        """Initialize mock user."""
 | 
						|
        kwargs = {
 | 
						|
            "is_owner": is_owner,
 | 
						|
            "is_active": is_active,
 | 
						|
            "name": name,
 | 
						|
            "system_generated": system_generated,
 | 
						|
            "groups": groups or [],
 | 
						|
            "perm_lookup": None,
 | 
						|
        }
 | 
						|
        if id is not None:
 | 
						|
            kwargs["id"] = id
 | 
						|
        super().__init__(**kwargs)
 | 
						|
 | 
						|
    def add_to_hass(self, hass: HomeAssistant) -> MockUser:
 | 
						|
        """Test helper to add entry to hass."""
 | 
						|
        return self.add_to_auth_manager(hass.auth)
 | 
						|
 | 
						|
    def add_to_auth_manager(self, auth_mgr: auth.AuthManager) -> MockUser:
 | 
						|
        """Test helper to add entry to hass."""
 | 
						|
        ensure_auth_manager_loaded(auth_mgr)
 | 
						|
        auth_mgr._store._users[self.id] = self
 | 
						|
        return self
 | 
						|
 | 
						|
    def mock_policy(self, policy: auth_permissions.PolicyType) -> None:
 | 
						|
        """Mock a policy for a user."""
 | 
						|
        self.permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
 | 
						|
 | 
						|
 | 
						|
async def register_auth_provider(
 | 
						|
    hass: HomeAssistant, config: ConfigType
 | 
						|
) -> auth_providers.AuthProvider:
 | 
						|
    """Register an auth provider."""
 | 
						|
    provider = await auth_providers.auth_provider_from_config(
 | 
						|
        hass, hass.auth._store, config
 | 
						|
    )
 | 
						|
    assert provider is not None, "Invalid config specified"
 | 
						|
    key = (provider.type, provider.id)
 | 
						|
    providers = hass.auth._providers
 | 
						|
 | 
						|
    if key in providers:
 | 
						|
        raise ValueError("Provider already registered")
 | 
						|
 | 
						|
    providers[key] = provider
 | 
						|
    return provider
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def ensure_auth_manager_loaded(auth_mgr: auth.AuthManager) -> None:
 | 
						|
    """Ensure an auth manager is considered loaded."""
 | 
						|
    store = auth_mgr._store
 | 
						|
    if store._users is None:
 | 
						|
        store._set_defaults()
 | 
						|
 | 
						|
 | 
						|
class MockModule:
 | 
						|
    """Representation of a fake module."""
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        domain: str | None = None,
 | 
						|
        *,
 | 
						|
        dependencies: list[str] | None = None,
 | 
						|
        setup: Callable[[HomeAssistant, ConfigType], bool] | None = None,
 | 
						|
        requirements: list[str] | None = None,
 | 
						|
        config_schema: vol.Schema | None = None,
 | 
						|
        platform_schema: vol.Schema | None = None,
 | 
						|
        platform_schema_base: vol.Schema | None = None,
 | 
						|
        async_setup: Callable[[HomeAssistant, ConfigType], Coroutine[Any, Any, bool]]
 | 
						|
        | None = None,
 | 
						|
        async_setup_entry: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        async_unload_entry: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        async_migrate_entry: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        async_remove_entry: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, None]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        partial_manifest: dict[str, Any] | None = None,
 | 
						|
        async_remove_config_entry_device: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry, dr.DeviceEntry], Coroutine[Any, Any, bool]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
    ) -> None:
 | 
						|
        """Initialize the mock module."""
 | 
						|
        self.__name__ = f"homeassistant.components.{domain}"
 | 
						|
        self.__file__ = f"homeassistant/components/{domain}"
 | 
						|
        self.DOMAIN = domain
 | 
						|
        self.DEPENDENCIES = dependencies or []
 | 
						|
        self.REQUIREMENTS = requirements or []
 | 
						|
        # Overlay to be used when generating manifest from this module
 | 
						|
        self._partial_manifest = partial_manifest
 | 
						|
 | 
						|
        if config_schema is not None:
 | 
						|
            self.CONFIG_SCHEMA = config_schema
 | 
						|
 | 
						|
        if platform_schema is not None:
 | 
						|
            self.PLATFORM_SCHEMA = platform_schema
 | 
						|
 | 
						|
        if platform_schema_base is not None:
 | 
						|
            self.PLATFORM_SCHEMA_BASE = platform_schema_base
 | 
						|
 | 
						|
        if setup:
 | 
						|
            # We run this in executor, wrap it in function
 | 
						|
            # pylint: disable-next=unnecessary-lambda
 | 
						|
            self.setup = lambda *args: setup(*args)
 | 
						|
 | 
						|
        if async_setup is not None:
 | 
						|
            self.async_setup = async_setup
 | 
						|
 | 
						|
        if setup is None and async_setup is None:
 | 
						|
            self.async_setup = AsyncMock(return_value=True)
 | 
						|
 | 
						|
        if async_setup_entry is not None:
 | 
						|
            self.async_setup_entry = async_setup_entry
 | 
						|
 | 
						|
        if async_unload_entry is not None:
 | 
						|
            self.async_unload_entry = async_unload_entry
 | 
						|
 | 
						|
        if async_migrate_entry is not None:
 | 
						|
            self.async_migrate_entry = async_migrate_entry
 | 
						|
 | 
						|
        if async_remove_entry is not None:
 | 
						|
            self.async_remove_entry = async_remove_entry
 | 
						|
 | 
						|
        if async_remove_config_entry_device is not None:
 | 
						|
            self.async_remove_config_entry_device = async_remove_config_entry_device
 | 
						|
 | 
						|
    def mock_manifest(self):
 | 
						|
        """Generate a mock manifest to represent this module."""
 | 
						|
        return {
 | 
						|
            **loader.manifest_from_legacy_module(self.DOMAIN, self),
 | 
						|
            **(self._partial_manifest or {}),
 | 
						|
        }
 | 
						|
 | 
						|
 | 
						|
class MockPlatform:
 | 
						|
    """Provide a fake platform."""
 | 
						|
 | 
						|
    __name__ = "homeassistant.components.light.bla"
 | 
						|
    __file__ = "homeassistant/components/blah/light"
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        *,
 | 
						|
        setup_platform: Callable[
 | 
						|
            [HomeAssistant, ConfigType, AddEntitiesCallback, DiscoveryInfoType | None],
 | 
						|
            None,
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        dependencies: list[str] | None = None,
 | 
						|
        platform_schema: vol.Schema | None = None,
 | 
						|
        async_setup_platform: Callable[
 | 
						|
            [HomeAssistant, ConfigType, AddEntitiesCallback, DiscoveryInfoType | None],
 | 
						|
            Coroutine[Any, Any, None],
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        async_setup_entry: Callable[
 | 
						|
            [HomeAssistant, ConfigEntry, AddEntitiesCallback], Coroutine[Any, Any, None]
 | 
						|
        ]
 | 
						|
        | None = None,
 | 
						|
        scan_interval: timedelta | None = None,
 | 
						|
    ) -> None:
 | 
						|
        """Initialize the platform."""
 | 
						|
        self.DEPENDENCIES = dependencies or []
 | 
						|
 | 
						|
        if platform_schema is not None:
 | 
						|
            self.PLATFORM_SCHEMA = platform_schema
 | 
						|
 | 
						|
        if scan_interval is not None:
 | 
						|
            self.SCAN_INTERVAL = scan_interval
 | 
						|
 | 
						|
        if setup_platform is not None:
 | 
						|
            # We run this in executor, wrap it in function
 | 
						|
            # pylint: disable-next=unnecessary-lambda
 | 
						|
            self.setup_platform = lambda *args: setup_platform(*args)
 | 
						|
 | 
						|
        if async_setup_platform is not None:
 | 
						|
            self.async_setup_platform = async_setup_platform
 | 
						|
 | 
						|
        if async_setup_entry is not None:
 | 
						|
            self.async_setup_entry = async_setup_entry
 | 
						|
 | 
						|
        if setup_platform is None and async_setup_platform is None:
 | 
						|
            self.async_setup_platform = AsyncMock(return_value=None)
 | 
						|
 | 
						|
 | 
						|
class MockEntityPlatform(entity_platform.EntityPlatform):
 | 
						|
    """Mock class with some mock defaults."""
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        hass: HomeAssistant,
 | 
						|
        logger=None,
 | 
						|
        domain="test_domain",
 | 
						|
        platform_name="test_platform",
 | 
						|
        platform=None,
 | 
						|
        scan_interval=timedelta(seconds=15),
 | 
						|
        entity_namespace=None,
 | 
						|
    ) -> None:
 | 
						|
        """Initialize a mock entity platform."""
 | 
						|
        if logger is None:
 | 
						|
            logger = logging.getLogger("homeassistant.helpers.entity_platform")
 | 
						|
 | 
						|
        # Otherwise the constructor will blow up.
 | 
						|
        if isinstance(platform, Mock) and isinstance(platform.PARALLEL_UPDATES, Mock):
 | 
						|
            platform.PARALLEL_UPDATES = 0
 | 
						|
 | 
						|
        super().__init__(
 | 
						|
            hass=hass,
 | 
						|
            logger=logger,
 | 
						|
            domain=domain,
 | 
						|
            platform_name=platform_name,
 | 
						|
            platform=platform,
 | 
						|
            scan_interval=scan_interval,
 | 
						|
            entity_namespace=entity_namespace,
 | 
						|
        )
 | 
						|
 | 
						|
        @callback
 | 
						|
        def _async_on_stop(_: Event) -> None:
 | 
						|
            self.async_shutdown()
 | 
						|
 | 
						|
        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_on_stop)
 | 
						|
 | 
						|
 | 
						|
class MockToggleEntity(entity.ToggleEntity):
 | 
						|
    """Provide a mock toggle device."""
 | 
						|
 | 
						|
    def __init__(self, name: str | None, state: Literal["on", "off"] | None) -> None:
 | 
						|
        """Initialize the mock entity."""
 | 
						|
        self._name = name or DEVICE_DEFAULT_NAME
 | 
						|
        self._state = state
 | 
						|
        self.calls: list[tuple[str, dict[str, Any]]] = []
 | 
						|
 | 
						|
    @property
 | 
						|
    def name(self) -> str:
 | 
						|
        """Return the name of the entity if any."""
 | 
						|
        self.calls.append(("name", {}))
 | 
						|
        return self._name
 | 
						|
 | 
						|
    @property
 | 
						|
    def state(self) -> Literal["on", "off"] | None:
 | 
						|
        """Return the state of the entity if any."""
 | 
						|
        self.calls.append(("state", {}))
 | 
						|
        return self._state
 | 
						|
 | 
						|
    @property
 | 
						|
    def is_on(self) -> bool:
 | 
						|
        """Return true if entity is on."""
 | 
						|
        self.calls.append(("is_on", {}))
 | 
						|
        return self._state == STATE_ON
 | 
						|
 | 
						|
    def turn_on(self, **kwargs: Any) -> None:
 | 
						|
        """Turn the entity on."""
 | 
						|
        self.calls.append(("turn_on", kwargs))
 | 
						|
        self._state = STATE_ON
 | 
						|
 | 
						|
    def turn_off(self, **kwargs: Any) -> None:
 | 
						|
        """Turn the entity off."""
 | 
						|
        self.calls.append(("turn_off", kwargs))
 | 
						|
        self._state = STATE_OFF
 | 
						|
 | 
						|
    def last_call(self, method: str | None = None) -> tuple[str, dict[str, Any]]:
 | 
						|
        """Return the last call."""
 | 
						|
        if not self.calls:
 | 
						|
            return None
 | 
						|
        if method is None:
 | 
						|
            return self.calls[-1]
 | 
						|
        try:
 | 
						|
            return next(call for call in reversed(self.calls) if call[0] == method)
 | 
						|
        except StopIteration:
 | 
						|
            return None
 | 
						|
 | 
						|
 | 
						|
class MockConfigEntry(config_entries.ConfigEntry):
 | 
						|
    """Helper for creating config entries that adds some defaults."""
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        *,
 | 
						|
        data=None,
 | 
						|
        disabled_by=None,
 | 
						|
        domain="test",
 | 
						|
        entry_id=None,
 | 
						|
        minor_version=1,
 | 
						|
        options=None,
 | 
						|
        pref_disable_new_entities=None,
 | 
						|
        pref_disable_polling=None,
 | 
						|
        reason=None,
 | 
						|
        source=config_entries.SOURCE_USER,
 | 
						|
        state=None,
 | 
						|
        title="Mock Title",
 | 
						|
        unique_id=None,
 | 
						|
        version=1,
 | 
						|
    ) -> None:
 | 
						|
        """Initialize a mock config entry."""
 | 
						|
        kwargs = {
 | 
						|
            "data": data or {},
 | 
						|
            "disabled_by": disabled_by,
 | 
						|
            "domain": domain,
 | 
						|
            "entry_id": entry_id or ulid_util.ulid_now(),
 | 
						|
            "minor_version": minor_version,
 | 
						|
            "options": options or {},
 | 
						|
            "pref_disable_new_entities": pref_disable_new_entities,
 | 
						|
            "pref_disable_polling": pref_disable_polling,
 | 
						|
            "title": title,
 | 
						|
            "unique_id": unique_id,
 | 
						|
            "version": version,
 | 
						|
        }
 | 
						|
        if source is not None:
 | 
						|
            kwargs["source"] = source
 | 
						|
        if state is not None:
 | 
						|
            kwargs["state"] = state
 | 
						|
        super().__init__(**kwargs)
 | 
						|
        if reason is not None:
 | 
						|
            object.__setattr__(self, "reason", reason)
 | 
						|
 | 
						|
    def add_to_hass(self, hass: HomeAssistant) -> None:
 | 
						|
        """Test helper to add entry to hass."""
 | 
						|
        hass.config_entries._entries[self.entry_id] = self
 | 
						|
 | 
						|
    def add_to_manager(self, manager: config_entries.ConfigEntries) -> None:
 | 
						|
        """Test helper to add entry to entry manager."""
 | 
						|
        manager._entries[self.entry_id] = self
 | 
						|
 | 
						|
    def mock_state(
 | 
						|
        self,
 | 
						|
        hass: HomeAssistant,
 | 
						|
        state: config_entries.ConfigEntryState,
 | 
						|
        reason: str | None = None,
 | 
						|
    ) -> None:
 | 
						|
        """Mock the state of a config entry to be used in tests.
 | 
						|
 | 
						|
        Currently this is a wrapper around _async_set_state, but it may
 | 
						|
        change in the future.
 | 
						|
 | 
						|
        It is preferable to get the config entry into the desired state
 | 
						|
        by using the normal config entry methods, and this helper
 | 
						|
        is only intended to be used in cases where that is not possible.
 | 
						|
 | 
						|
        When in doubt, this helper should not be used in new code
 | 
						|
        and is only intended for backwards compatibility with existing
 | 
						|
        tests.
 | 
						|
        """
 | 
						|
        self._async_set_state(hass, state, reason)
 | 
						|
 | 
						|
 | 
						|
def patch_yaml_files(files_dict, endswith=True):
 | 
						|
    """Patch load_yaml with a dictionary of yaml files."""
 | 
						|
    # match using endswith, start search with longest string
 | 
						|
    matchlist = sorted(files_dict.keys(), key=len) if endswith else []
 | 
						|
 | 
						|
    def mock_open_f(fname, **_):
 | 
						|
        """Mock open() in the yaml module, used by load_yaml."""
 | 
						|
        # Return the mocked file on full match
 | 
						|
        if isinstance(fname, pathlib.Path):
 | 
						|
            fname = str(fname)
 | 
						|
 | 
						|
        if fname in files_dict:
 | 
						|
            _LOGGER.debug("patch_yaml_files match %s", fname)
 | 
						|
            res = StringIO(files_dict[fname])
 | 
						|
            setattr(res, "name", fname)
 | 
						|
            return res
 | 
						|
 | 
						|
        # Match using endswith
 | 
						|
        for ends in matchlist:
 | 
						|
            if fname.endswith(ends):
 | 
						|
                _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname)
 | 
						|
                res = StringIO(files_dict[ends])
 | 
						|
                setattr(res, "name", fname)
 | 
						|
                return res
 | 
						|
 | 
						|
        # Fallback for hass.components (i.e. services.yaml)
 | 
						|
        if "homeassistant/components" in fname:
 | 
						|
            _LOGGER.debug("patch_yaml_files using real file: %s", fname)
 | 
						|
            return open(fname, encoding="utf-8")
 | 
						|
 | 
						|
        # Not found
 | 
						|
        raise FileNotFoundError(f"File not found: {fname}")
 | 
						|
 | 
						|
    return patch.object(yaml_loader, "open", mock_open_f, create=True)
 | 
						|
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def assert_setup_component(count, domain=None):
 | 
						|
    """Collect valid configuration from setup_component.
 | 
						|
 | 
						|
    - count: The amount of valid platforms that should be setup
 | 
						|
    - domain: The domain to count is optional. It can be automatically
 | 
						|
              determined most of the time
 | 
						|
 | 
						|
    Use as a context manager around setup.setup_component
 | 
						|
        with assert_setup_component(0) as result_config:
 | 
						|
            setup_component(hass, domain, start_config)
 | 
						|
            # using result_config is optional
 | 
						|
    """
 | 
						|
    config = {}
 | 
						|
 | 
						|
    async def mock_psc(hass, config_input, integration, component=None):
 | 
						|
        """Mock the prepare_setup_component to capture config."""
 | 
						|
        domain_input = integration.domain
 | 
						|
        integration_config_info = await async_process_component_config(
 | 
						|
            hass, config_input, integration, component
 | 
						|
        )
 | 
						|
        res = integration_config_info.config
 | 
						|
        config[domain_input] = None if res is None else res.get(domain_input)
 | 
						|
        _LOGGER.debug(
 | 
						|
            "Configuration for %s, Validated: %s, Original %s",
 | 
						|
            domain_input,
 | 
						|
            config[domain_input],
 | 
						|
            config_input.get(domain_input),
 | 
						|
        )
 | 
						|
        return integration_config_info
 | 
						|
 | 
						|
    assert isinstance(config, dict)
 | 
						|
    with patch("homeassistant.config.async_process_component_config", mock_psc):
 | 
						|
        yield config
 | 
						|
 | 
						|
    if domain is None:
 | 
						|
        assert (
 | 
						|
            len(config) == 1
 | 
						|
        ), f"assert_setup_component requires DOMAIN: {list(config.keys())}"
 | 
						|
        domain = list(config.keys())[0]
 | 
						|
 | 
						|
    res = config.get(domain)
 | 
						|
    res_len = 0 if res is None else len(res)
 | 
						|
    assert (
 | 
						|
        res_len == count
 | 
						|
    ), f"setup_component failed, expected {count} got {res_len}: {res}"
 | 
						|
 | 
						|
 | 
						|
def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
 | 
						|
    """Mock the DATA_RESTORE_CACHE."""
 | 
						|
    key = rs.DATA_RESTORE_STATE
 | 
						|
    data = rs.RestoreStateData(hass)
 | 
						|
    now = dt_util.utcnow()
 | 
						|
 | 
						|
    last_states = {}
 | 
						|
    for state in states:
 | 
						|
        restored_state = state.as_dict()
 | 
						|
        restored_state = {
 | 
						|
            **restored_state,
 | 
						|
            "attributes": json.loads(
 | 
						|
                json.dumps(restored_state["attributes"], cls=JSONEncoder)
 | 
						|
            ),
 | 
						|
        }
 | 
						|
        last_states[state.entity_id] = rs.StoredState.from_dict(
 | 
						|
            {"state": restored_state, "last_seen": now}
 | 
						|
        )
 | 
						|
    data.last_states = last_states
 | 
						|
    _LOGGER.debug("Restore cache: %s", data.last_states)
 | 
						|
    assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
 | 
						|
 | 
						|
    rs.async_get.cache_clear()
 | 
						|
    hass.data[key] = data
 | 
						|
 | 
						|
 | 
						|
def mock_restore_cache_with_extra_data(
 | 
						|
    hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
 | 
						|
) -> None:
 | 
						|
    """Mock the DATA_RESTORE_CACHE."""
 | 
						|
    key = rs.DATA_RESTORE_STATE
 | 
						|
    data = rs.RestoreStateData(hass)
 | 
						|
    now = dt_util.utcnow()
 | 
						|
 | 
						|
    last_states = {}
 | 
						|
    for state, extra_data in states:
 | 
						|
        restored_state = state.as_dict()
 | 
						|
        restored_state = {
 | 
						|
            **restored_state,
 | 
						|
            "attributes": json.loads(
 | 
						|
                json.dumps(restored_state["attributes"], cls=JSONEncoder)
 | 
						|
            ),
 | 
						|
        }
 | 
						|
        last_states[state.entity_id] = rs.StoredState.from_dict(
 | 
						|
            {"state": restored_state, "extra_data": extra_data, "last_seen": now}
 | 
						|
        )
 | 
						|
    data.last_states = last_states
 | 
						|
    _LOGGER.debug("Restore cache: %s", data.last_states)
 | 
						|
    assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
 | 
						|
 | 
						|
    rs.async_get.cache_clear()
 | 
						|
    hass.data[key] = data
 | 
						|
 | 
						|
 | 
						|
async def async_mock_restore_state_shutdown_restart(
 | 
						|
    hass: HomeAssistant,
 | 
						|
) -> rs.RestoreStateData:
 | 
						|
    """Mock shutting down and saving restore state and restoring."""
 | 
						|
    data = rs.async_get(hass)
 | 
						|
    await data.async_dump_states()
 | 
						|
    await async_mock_load_restore_state_from_storage(hass)
 | 
						|
    return data
 | 
						|
 | 
						|
 | 
						|
async def async_mock_load_restore_state_from_storage(
 | 
						|
    hass: HomeAssistant,
 | 
						|
) -> None:
 | 
						|
    """Mock loading restore state from storage.
 | 
						|
 | 
						|
    hass_storage must already be mocked.
 | 
						|
    """
 | 
						|
    await rs.async_get(hass).async_load()
 | 
						|
 | 
						|
 | 
						|
class MockEntity(entity.Entity):
 | 
						|
    """Mock Entity class."""
 | 
						|
 | 
						|
    def __init__(self, **values: Any) -> None:
 | 
						|
        """Initialize an entity."""
 | 
						|
        self._values = values
 | 
						|
 | 
						|
        if "entity_id" in values:
 | 
						|
            self.entity_id = values["entity_id"]
 | 
						|
 | 
						|
    @property
 | 
						|
    def available(self) -> bool:
 | 
						|
        """Return True if entity is available."""
 | 
						|
        return self._handle("available")
 | 
						|
 | 
						|
    @property
 | 
						|
    def capability_attributes(self) -> Mapping[str, Any] | None:
 | 
						|
        """Info about capabilities."""
 | 
						|
        return self._handle("capability_attributes")
 | 
						|
 | 
						|
    @property
 | 
						|
    def device_class(self) -> str | None:
 | 
						|
        """Info how device should be classified."""
 | 
						|
        return self._handle("device_class")
 | 
						|
 | 
						|
    @property
 | 
						|
    def device_info(self) -> dr.DeviceInfo | None:
 | 
						|
        """Info how it links to a device."""
 | 
						|
        return self._handle("device_info")
 | 
						|
 | 
						|
    @property
 | 
						|
    def entity_category(self) -> entity.EntityCategory | None:
 | 
						|
        """Return the entity category."""
 | 
						|
        return self._handle("entity_category")
 | 
						|
 | 
						|
    @property
 | 
						|
    def extra_state_attributes(self) -> Mapping[str, Any] | None:
 | 
						|
        """Return entity specific state attributes."""
 | 
						|
        return self._handle("extra_state_attributes")
 | 
						|
 | 
						|
    @property
 | 
						|
    def has_entity_name(self) -> bool:
 | 
						|
        """Return the has_entity_name name flag."""
 | 
						|
        return self._handle("has_entity_name")
 | 
						|
 | 
						|
    @property
 | 
						|
    def entity_registry_enabled_default(self) -> bool:
 | 
						|
        """Return if the entity should be enabled when first added to the entity registry."""
 | 
						|
        return self._handle("entity_registry_enabled_default")
 | 
						|
 | 
						|
    @property
 | 
						|
    def entity_registry_visible_default(self) -> bool:
 | 
						|
        """Return if the entity should be visible when first added to the entity registry."""
 | 
						|
        return self._handle("entity_registry_visible_default")
 | 
						|
 | 
						|
    @property
 | 
						|
    def icon(self) -> str | None:
 | 
						|
        """Return the suggested icon."""
 | 
						|
        return self._handle("icon")
 | 
						|
 | 
						|
    @property
 | 
						|
    def name(self) -> str | None:
 | 
						|
        """Return the name of the entity."""
 | 
						|
        return self._handle("name")
 | 
						|
 | 
						|
    @property
 | 
						|
    def should_poll(self) -> bool:
 | 
						|
        """Return the ste of the polling."""
 | 
						|
        return self._handle("should_poll")
 | 
						|
 | 
						|
    @property
 | 
						|
    def supported_features(self) -> int | None:
 | 
						|
        """Info about supported features."""
 | 
						|
        return self._handle("supported_features")
 | 
						|
 | 
						|
    @property
 | 
						|
    def translation_key(self) -> str | None:
 | 
						|
        """Return the translation key."""
 | 
						|
        return self._handle("translation_key")
 | 
						|
 | 
						|
    @property
 | 
						|
    def unique_id(self) -> str | None:
 | 
						|
        """Return the unique ID of the entity."""
 | 
						|
        return self._handle("unique_id")
 | 
						|
 | 
						|
    @property
 | 
						|
    def unit_of_measurement(self) -> str | None:
 | 
						|
        """Info on the units the entity state is in."""
 | 
						|
        return self._handle("unit_of_measurement")
 | 
						|
 | 
						|
    def _handle(self, attr: str) -> Any:
 | 
						|
        """Return attribute value."""
 | 
						|
        if attr in self._values:
 | 
						|
            return self._values[attr]
 | 
						|
        return getattr(super(), attr)
 | 
						|
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def mock_storage(data: dict[str, Any] | None = None) -> Generator[dict[str, Any]]:
 | 
						|
    """Mock storage.
 | 
						|
 | 
						|
    Data is a dict {'key': {'version': version, 'data': data}}
 | 
						|
 | 
						|
    Written data will be converted to JSON to ensure JSON parsing works.
 | 
						|
    """
 | 
						|
    if data is None:
 | 
						|
        data = {}
 | 
						|
 | 
						|
    orig_load = storage.Store._async_load
 | 
						|
 | 
						|
    async def mock_async_load(
 | 
						|
        store: storage.Store,
 | 
						|
    ) -> dict[str, Any] | list[Any] | None:
 | 
						|
        """Mock version of load."""
 | 
						|
        if store._data is None:
 | 
						|
            # No data to load
 | 
						|
            if store.key not in data:
 | 
						|
                # Make sure the next attempt will still load
 | 
						|
                store._load_task = None
 | 
						|
                return None
 | 
						|
 | 
						|
            mock_data = data.get(store.key)
 | 
						|
 | 
						|
            if "data" not in mock_data or "version" not in mock_data:
 | 
						|
                _LOGGER.error('Mock data needs "version" and "data"')
 | 
						|
                raise ValueError('Mock data needs "version" and "data"')
 | 
						|
 | 
						|
            store._data = mock_data
 | 
						|
 | 
						|
        # Route through original load so that we trigger migration
 | 
						|
        loaded = await orig_load(store)
 | 
						|
        _LOGGER.debug("Loading data for %s: %s", store.key, loaded)
 | 
						|
        return loaded
 | 
						|
 | 
						|
    async def mock_write_data(
 | 
						|
        store: storage.Store, path: str, data_to_write: dict[str, Any]
 | 
						|
    ) -> None:
 | 
						|
        """Mock version of write data."""
 | 
						|
        # To ensure that the data can be serialized
 | 
						|
        _LOGGER.debug("Writing data to %s: %s", store.key, data_to_write)
 | 
						|
        raise_contains_mocks(data_to_write)
 | 
						|
 | 
						|
        if "data_func" in data_to_write:
 | 
						|
            data_to_write["data"] = data_to_write.pop("data_func")()
 | 
						|
 | 
						|
        encoder = store._encoder
 | 
						|
        if encoder and encoder is not JSONEncoder:
 | 
						|
            # If they pass a custom encoder that is not the
 | 
						|
            # default JSONEncoder, we use the slow path of json.dumps
 | 
						|
            dump = ft.partial(json.dumps, cls=store._encoder)
 | 
						|
        else:
 | 
						|
            dump = _orjson_default_encoder
 | 
						|
        data[store.key] = json_loads(dump(data_to_write))
 | 
						|
 | 
						|
    async def mock_remove(store: storage.Store) -> None:
 | 
						|
        """Remove data."""
 | 
						|
        data.pop(store.key, None)
 | 
						|
 | 
						|
    with (
 | 
						|
        patch(
 | 
						|
            "homeassistant.helpers.storage.Store._async_load",
 | 
						|
            side_effect=mock_async_load,
 | 
						|
            autospec=True,
 | 
						|
        ),
 | 
						|
        patch(
 | 
						|
            "homeassistant.helpers.storage.Store._async_write_data",
 | 
						|
            side_effect=mock_write_data,
 | 
						|
            autospec=True,
 | 
						|
        ),
 | 
						|
        patch(
 | 
						|
            "homeassistant.helpers.storage.Store.async_remove",
 | 
						|
            side_effect=mock_remove,
 | 
						|
            autospec=True,
 | 
						|
        ),
 | 
						|
    ):
 | 
						|
        yield data
 | 
						|
 | 
						|
 | 
						|
async def flush_store(store: storage.Store) -> None:
 | 
						|
    """Make sure all delayed writes of a store are written."""
 | 
						|
    if store._data is None:
 | 
						|
        return
 | 
						|
 | 
						|
    store._async_cleanup_final_write_listener()
 | 
						|
    store._async_cleanup_delay_listener()
 | 
						|
    await store._async_handle_write_data()
 | 
						|
 | 
						|
 | 
						|
async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]:
 | 
						|
    """Get system health info."""
 | 
						|
    return await hass.data["system_health"][domain].info_callback(hass)
 | 
						|
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None:
 | 
						|
    """Mock a config flow handler."""
 | 
						|
    original_handler = config_entries.HANDLERS.get(domain)
 | 
						|
    config_entries.HANDLERS[domain] = config_flow
 | 
						|
    _LOGGER.info("Adding mock config flow: %s", domain)
 | 
						|
    yield
 | 
						|
    config_entries.HANDLERS.pop(domain)
 | 
						|
    if original_handler:
 | 
						|
        config_entries.HANDLERS[domain] = original_handler
 | 
						|
 | 
						|
 | 
						|
def mock_integration(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    module: MockModule,
 | 
						|
    built_in: bool = True,
 | 
						|
    top_level_files: set[str] | None = None,
 | 
						|
) -> loader.Integration:
 | 
						|
    """Mock an integration."""
 | 
						|
    integration = loader.Integration(
 | 
						|
        hass,
 | 
						|
        f"{loader.PACKAGE_BUILTIN}.{module.DOMAIN}"
 | 
						|
        if built_in
 | 
						|
        else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}",
 | 
						|
        pathlib.Path(""),
 | 
						|
        module.mock_manifest(),
 | 
						|
        top_level_files,
 | 
						|
    )
 | 
						|
 | 
						|
    def mock_import_platform(platform_name: str) -> NoReturn:
 | 
						|
        raise ImportError(
 | 
						|
            f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'",
 | 
						|
            name=f"{integration.pkg_path}.{platform_name}",
 | 
						|
        )
 | 
						|
 | 
						|
    integration._import_platform = mock_import_platform
 | 
						|
 | 
						|
    _LOGGER.info("Adding mock integration: %s", module.DOMAIN)
 | 
						|
    integration_cache = hass.data[loader.DATA_INTEGRATIONS]
 | 
						|
    integration_cache[module.DOMAIN] = integration
 | 
						|
 | 
						|
    module_cache = hass.data[loader.DATA_COMPONENTS]
 | 
						|
    module_cache[module.DOMAIN] = module
 | 
						|
 | 
						|
    return integration
 | 
						|
 | 
						|
 | 
						|
def mock_platform(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    platform_path: str,
 | 
						|
    module: Mock | MockPlatform | None = None,
 | 
						|
    built_in=True,
 | 
						|
) -> None:
 | 
						|
    """Mock a platform.
 | 
						|
 | 
						|
    platform_path is in form hue.config_flow.
 | 
						|
    """
 | 
						|
    domain, _, platform_name = platform_path.partition(".")
 | 
						|
    integration_cache = hass.data[loader.DATA_INTEGRATIONS]
 | 
						|
    module_cache = hass.data[loader.DATA_COMPONENTS]
 | 
						|
 | 
						|
    if domain not in integration_cache:
 | 
						|
        mock_integration(hass, MockModule(domain), built_in=built_in)
 | 
						|
 | 
						|
    integration_cache[domain]._top_level_files.add(f"{platform_name}.py")
 | 
						|
    _LOGGER.info("Adding mock integration platform: %s", platform_path)
 | 
						|
    module_cache[platform_path] = module or Mock()
 | 
						|
 | 
						|
 | 
						|
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
 | 
						|
    """Create a helper that captures events."""
 | 
						|
    events = []
 | 
						|
 | 
						|
    @callback
 | 
						|
    def capture_events(event: Event) -> None:
 | 
						|
        events.append(event)
 | 
						|
 | 
						|
    hass.bus.async_listen(event_name, capture_events)
 | 
						|
 | 
						|
    return events
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_mock_signal(
 | 
						|
    hass: HomeAssistant, signal: SignalType[Any] | str
 | 
						|
) -> list[tuple[Any]]:
 | 
						|
    """Catch all dispatches to a signal."""
 | 
						|
    calls = []
 | 
						|
 | 
						|
    @callback
 | 
						|
    def mock_signal_handler(*args: Any) -> None:
 | 
						|
        """Mock service call."""
 | 
						|
        calls.append(args)
 | 
						|
 | 
						|
    async_dispatcher_connect(hass, signal, mock_signal_handler)
 | 
						|
 | 
						|
    return calls
 | 
						|
 | 
						|
 | 
						|
_SENTINEL = object()
 | 
						|
 | 
						|
 | 
						|
class _HA_ANY:
 | 
						|
    """A helper object that compares equal to everything.
 | 
						|
 | 
						|
    Based on unittest.mock.ANY, but modified to not show up in pytest's equality
 | 
						|
    assertion diffs.
 | 
						|
    """
 | 
						|
 | 
						|
    _other = _SENTINEL
 | 
						|
 | 
						|
    def __eq__(self, other: object) -> bool:
 | 
						|
        """Test equal."""
 | 
						|
        self._other = other
 | 
						|
        return True
 | 
						|
 | 
						|
    def __ne__(self, other: object) -> bool:
 | 
						|
        """Test not equal."""
 | 
						|
        self._other = other
 | 
						|
        return False
 | 
						|
 | 
						|
    def __repr__(self) -> str:
 | 
						|
        """Return repr() other to not show up in pytest quality diffs."""
 | 
						|
        if self._other is _SENTINEL:
 | 
						|
            return "<ANY>"
 | 
						|
        return repr(self._other)
 | 
						|
 | 
						|
 | 
						|
ANY = _HA_ANY()
 | 
						|
 | 
						|
 | 
						|
def raise_contains_mocks(val: Any) -> None:
 | 
						|
    """Raise for mocks."""
 | 
						|
    if isinstance(val, Mock):
 | 
						|
        raise TypeError(val)
 | 
						|
 | 
						|
    if isinstance(val, dict):
 | 
						|
        for dict_value in val.values():
 | 
						|
            raise_contains_mocks(dict_value)
 | 
						|
 | 
						|
    if isinstance(val, list):
 | 
						|
        for dict_value in val:
 | 
						|
            raise_contains_mocks(dict_value)
 | 
						|
 | 
						|
 | 
						|
@callback
 | 
						|
def async_get_persistent_notifications(
 | 
						|
    hass: HomeAssistant,
 | 
						|
) -> dict[str, pn.Notification]:
 | 
						|
    """Get the current persistent notifications."""
 | 
						|
    return pn._async_get_or_create_notifications(hass)
 | 
						|
 | 
						|
 | 
						|
def async_mock_cloud_connection_status(hass: HomeAssistant, connected: bool) -> None:
 | 
						|
    """Mock a signal the cloud disconnected."""
 | 
						|
    # pylint: disable-next=import-outside-toplevel
 | 
						|
    from homeassistant.components.cloud import (
 | 
						|
        SIGNAL_CLOUD_CONNECTION_STATE,
 | 
						|
        CloudConnectionState,
 | 
						|
    )
 | 
						|
 | 
						|
    if connected:
 | 
						|
        state = CloudConnectionState.CLOUD_CONNECTED
 | 
						|
    else:
 | 
						|
        state = CloudConnectionState.CLOUD_DISCONNECTED
 | 
						|
    async_dispatcher_send(hass, SIGNAL_CLOUD_CONNECTION_STATE, state)
 | 
						|
 | 
						|
 | 
						|
def import_and_test_deprecated_constant_enum(
 | 
						|
    caplog: pytest.LogCaptureFixture,
 | 
						|
    module: ModuleType,
 | 
						|
    replacement: Enum,
 | 
						|
    constant_prefix: str,
 | 
						|
    breaks_in_ha_version: str,
 | 
						|
) -> None:
 | 
						|
    """Import and test deprecated constant replaced by a enum.
 | 
						|
 | 
						|
    - Import deprecated enum
 | 
						|
    - Assert value is the same as the replacement
 | 
						|
    - Assert a warning is logged
 | 
						|
    - Assert the deprecated constant is included in the modules.__dir__()
 | 
						|
    - Assert the deprecated constant is included in the modules.__all__()
 | 
						|
    """
 | 
						|
    import_and_test_deprecated_constant(
 | 
						|
        caplog,
 | 
						|
        module,
 | 
						|
        constant_prefix + replacement.name,
 | 
						|
        f"{replacement.__class__.__name__}.{replacement.name}",
 | 
						|
        replacement,
 | 
						|
        breaks_in_ha_version,
 | 
						|
    )
 | 
						|
 | 
						|
 | 
						|
def import_and_test_deprecated_constant(
 | 
						|
    caplog: pytest.LogCaptureFixture,
 | 
						|
    module: ModuleType,
 | 
						|
    constant_name: str,
 | 
						|
    replacement_name: str,
 | 
						|
    replacement: Any,
 | 
						|
    breaks_in_ha_version: str,
 | 
						|
) -> None:
 | 
						|
    """Import and test deprecated constant replaced by a value.
 | 
						|
 | 
						|
    - Import deprecated constant
 | 
						|
    - Assert value is the same as the replacement
 | 
						|
    - Assert a warning is logged
 | 
						|
    - Assert the deprecated constant is included in the modules.__dir__()
 | 
						|
    - Assert the deprecated constant is included in the modules.__all__()
 | 
						|
    """
 | 
						|
    value = import_deprecated_constant(module, constant_name)
 | 
						|
    assert value == replacement
 | 
						|
    assert (
 | 
						|
        module.__name__,
 | 
						|
        logging.WARNING,
 | 
						|
        (
 | 
						|
            f"{constant_name} was used from test_constant_deprecation,"
 | 
						|
            f" this is a deprecated constant which will be removed in HA Core {breaks_in_ha_version}. "
 | 
						|
            f"Use {replacement_name} instead, please report "
 | 
						|
            "it to the author of the 'test_constant_deprecation' custom integration"
 | 
						|
        ),
 | 
						|
    ) in caplog.record_tuples
 | 
						|
 | 
						|
    # verify deprecated constant is included in dir()
 | 
						|
    assert constant_name in dir(module)
 | 
						|
    assert constant_name in module.__all__
 | 
						|
 | 
						|
 | 
						|
def import_and_test_deprecated_alias(
 | 
						|
    caplog: pytest.LogCaptureFixture,
 | 
						|
    module: ModuleType,
 | 
						|
    alias_name: str,
 | 
						|
    replacement: Any,
 | 
						|
    breaks_in_ha_version: str,
 | 
						|
) -> None:
 | 
						|
    """Import and test deprecated alias replaced by a value.
 | 
						|
 | 
						|
    - Import deprecated alias
 | 
						|
    - Assert value is the same as the replacement
 | 
						|
    - Assert a warning is logged
 | 
						|
    - Assert the deprecated alias is included in the modules.__dir__()
 | 
						|
    - Assert the deprecated alias is included in the modules.__all__()
 | 
						|
    """
 | 
						|
    replacement_name = f"{replacement.__module__}.{replacement.__name__}"
 | 
						|
    value = import_deprecated_constant(module, alias_name)
 | 
						|
    assert value == replacement
 | 
						|
    assert (
 | 
						|
        module.__name__,
 | 
						|
        logging.WARNING,
 | 
						|
        (
 | 
						|
            f"{alias_name} was used from test_constant_deprecation,"
 | 
						|
            f" this is a deprecated alias which will be removed in HA Core {breaks_in_ha_version}. "
 | 
						|
            f"Use {replacement_name} instead, please report "
 | 
						|
            "it to the author of the 'test_constant_deprecation' custom integration"
 | 
						|
        ),
 | 
						|
    ) in caplog.record_tuples
 | 
						|
 | 
						|
    # verify deprecated alias is included in dir()
 | 
						|
    assert alias_name in dir(module)
 | 
						|
    assert alias_name in module.__all__
 | 
						|
 | 
						|
 | 
						|
def help_test_all(module: ModuleType) -> None:
 | 
						|
    """Test module.__all__ is correctly set."""
 | 
						|
    assert set(module.__all__) == {
 | 
						|
        itm for itm in dir(module) if not itm.startswith("_")
 | 
						|
    }
 | 
						|
 | 
						|
 | 
						|
def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType:
 | 
						|
    """Convert an extract stack to a frame list."""
 | 
						|
    stack = list(extract_stack)
 | 
						|
    _globals = globals()
 | 
						|
    for frame in stack:
 | 
						|
        frame.f_back = None
 | 
						|
        frame.f_globals = _globals
 | 
						|
        frame.f_code.co_filename = frame.filename
 | 
						|
        frame.f_lineno = int(frame.lineno)
 | 
						|
 | 
						|
    top_frame = stack.pop()
 | 
						|
    current_frame = top_frame
 | 
						|
    while stack and (next_frame := stack.pop()):
 | 
						|
        current_frame.f_back = next_frame
 | 
						|
        current_frame = next_frame
 | 
						|
 | 
						|
    return top_frame
 | 
						|
 | 
						|
 | 
						|
def setup_test_component_platform(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    domain: str,
 | 
						|
    entities: Sequence[Entity],
 | 
						|
    from_config_entry: bool = False,
 | 
						|
    built_in: bool = True,
 | 
						|
) -> MockPlatform:
 | 
						|
    """Mock a test component platform for tests."""
 | 
						|
 | 
						|
    async def _async_setup_platform(
 | 
						|
        hass: HomeAssistant,
 | 
						|
        config: ConfigType,
 | 
						|
        async_add_entities: AddEntitiesCallback,
 | 
						|
        discovery_info: DiscoveryInfoType | None = None,
 | 
						|
    ) -> None:
 | 
						|
        """Set up a test component platform."""
 | 
						|
        async_add_entities(entities)
 | 
						|
 | 
						|
    platform = MockPlatform(
 | 
						|
        async_setup_platform=_async_setup_platform,
 | 
						|
    )
 | 
						|
 | 
						|
    # avoid creating config entry setup if not needed
 | 
						|
    if from_config_entry:
 | 
						|
 | 
						|
        async def _async_setup_entry(
 | 
						|
            hass: HomeAssistant,
 | 
						|
            entry: ConfigEntry,
 | 
						|
            async_add_entities: AddEntitiesCallback,
 | 
						|
        ) -> None:
 | 
						|
            """Set up a test component platform."""
 | 
						|
            async_add_entities(entities)
 | 
						|
 | 
						|
        platform.async_setup_entry = _async_setup_entry
 | 
						|
        platform.async_setup_platform = None
 | 
						|
 | 
						|
    mock_platform(hass, f"test.{domain}", platform, built_in=built_in)
 | 
						|
    return platform
 | 
						|
 | 
						|
 | 
						|
async def snapshot_platform(
 | 
						|
    hass: HomeAssistant,
 | 
						|
    entity_registry: er.EntityRegistry,
 | 
						|
    snapshot: SnapshotAssertion,
 | 
						|
    config_entry_id: str,
 | 
						|
) -> None:
 | 
						|
    """Snapshot a platform."""
 | 
						|
    entity_entries = er.async_entries_for_config_entry(entity_registry, config_entry_id)
 | 
						|
    assert entity_entries
 | 
						|
    assert (
 | 
						|
        len({entity_entry.domain for entity_entry in entity_entries}) == 1
 | 
						|
    ), "Please limit the loaded platforms to 1 platform."
 | 
						|
    for entity_entry in entity_entries:
 | 
						|
        assert entity_entry == snapshot(name=f"{entity_entry.entity_id}-entry")
 | 
						|
        assert entity_entry.disabled_by is None, "Please enable all entities."
 | 
						|
        state = hass.states.get(entity_entry.entity_id)
 | 
						|
        assert state, f"State not found for {entity_entry.entity_id}"
 | 
						|
        assert state == snapshot(name=f"{entity_entry.entity_id}-state")
 |