"""Helper to help coordinating calls.""" from __future__ import annotations import asyncio from collections.abc import Callable, Coroutine import functools from typing import Any, Literal, assert_type, cast, overload from homeassistant.core import HomeAssistant from homeassistant.loader import bind_hass from homeassistant.util.hass_dict import HassKey type _FuncType[_T] = Callable[[HomeAssistant], _T] type _Coro[_T] = Coroutine[Any, Any, _T] @overload def singleton[_T]( data_key: HassKey[_T], *, async_: Literal[True] ) -> Callable[[_FuncType[_Coro[_T]]], _FuncType[_Coro[_T]]]: ... @overload def singleton[_T]( data_key: HassKey[_T], ) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... @overload def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... def singleton[_S, _T, _U]( data_key: Any, *, async_: bool = False ) -> Callable[[_FuncType[_S]], _FuncType[_S]]: """Decorate a function that should be called once per instance. Result will be cached and simultaneous calls will be handled. """ @overload def wrapper(func: _FuncType[_Coro[_T]]) -> _FuncType[_Coro[_T]]: ... @overload def wrapper(func: _FuncType[_U]) -> _FuncType[_U]: ... def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]: """Wrap a function with caching logic.""" if not asyncio.iscoroutinefunction(func): @functools.lru_cache(maxsize=1) @bind_hass @functools.wraps(func) def wrapped(hass: HomeAssistant) -> _U: if data_key not in hass.data: hass.data[data_key] = func(hass) return cast(_U, hass.data[data_key]) return wrapped @bind_hass @functools.wraps(func) async def async_wrapped(hass: HomeAssistant) -> _T: if data_key not in hass.data: evt = hass.data[data_key] = asyncio.Event() result = await func(hass) hass.data[data_key] = result evt.set() return cast(_T, result) obj_or_evt = hass.data[data_key] if isinstance(obj_or_evt, asyncio.Event): await obj_or_evt.wait() return cast(_T, hass.data[data_key]) return cast(_T, obj_or_evt) return async_wrapped return wrapper async def _test_singleton_typing(hass: HomeAssistant) -> None: """Test singleton overloads work as intended. This is tested during the mypy run. Do not move it to 'tests'! """ # Test HassKey key = HassKey[int]("key") @singleton(key) def func(hass: HomeAssistant) -> int: return 2 @singleton(key, async_=True) async def async_func(hass: HomeAssistant) -> int: return 2 assert_type(func(hass), int) assert_type(await async_func(hass), int) # Test invalid use of 'async_' with sync function @singleton(key, async_=True) # type: ignore[arg-type] def func_error(hass: HomeAssistant) -> int: return 2 # Test string key other_key = "key" @singleton(other_key) def func2(hass: HomeAssistant) -> str: return "" @singleton(other_key) async def async_func2(hass: HomeAssistant) -> str: return "" assert_type(func2(hass), str) assert_type(await async_func2(hass), str)