Use HassKey for helpers (2) (#117013)
parent
c50a340cbc
commit
8f614fb06d
|
@ -16,6 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.importlib import async_import_module
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
|
||||
|
||||
|
@ -29,7 +30,7 @@ MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
|
|||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
DATA_REQS = "mfa_auth_module_reqs_processed"
|
||||
DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -17,13 +17,14 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.helpers.importlib import async_import_module
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from ..auth_store import AuthStore
|
||||
from ..const import MFA_SESSION_EXPIRATION
|
||||
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = "auth_prov_reqs_processed"
|
||||
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
|
||||
|
||||
AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from homeassistant.helpers.service import (
|
|||
async_extract_referenced_entity_ids,
|
||||
async_register_admin_service,
|
||||
)
|
||||
from homeassistant.helpers.signal import KEY_HA_STOP
|
||||
from homeassistant.helpers.template import async_load_custom_templates
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
|
@ -386,7 +387,7 @@ async def _async_stop(hass: ha.HomeAssistant, restart: bool) -> None:
|
|||
"""Stop home assistant."""
|
||||
exit_code = RESTART_EXIT_CODE if restart else 0
|
||||
# Track trask in hass.data. No need to cleanup, we're stopping.
|
||||
hass.data["homeassistant_stop"] = asyncio.create_task(hass.async_stop(exit_code))
|
||||
hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
|
||||
|
||||
|
||||
@ha.callback
|
||||
|
|
|
@ -20,10 +20,13 @@ from homeassistant.loader import (
|
|||
bind_hass,
|
||||
)
|
||||
from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.logging import catch_log_exception
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_INTEGRATION_PLATFORMS = "integration_platforms"
|
||||
DATA_INTEGRATION_PLATFORMS: HassKey[list[IntegrationPlatform]] = HassKey(
|
||||
"integration_platforms"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
|
@ -160,8 +163,7 @@ async def async_process_integration_platforms(
|
|||
) -> None:
|
||||
"""Process a specific platform for all current and future loaded integrations."""
|
||||
if DATA_INTEGRATION_PLATFORMS not in hass.data:
|
||||
integration_platforms: list[IntegrationPlatform] = []
|
||||
hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms
|
||||
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] = []
|
||||
hass.bus.async_listen(
|
||||
EVENT_COMPONENT_LOADED,
|
||||
partial(
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
"""Helpers to check recorder."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
DOMAIN = "recorder"
|
||||
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -14,7 +17,7 @@ class RecorderData:
|
|||
"""Recorder data stored in hass.data."""
|
||||
|
||||
recorder_platforms: dict[str, Any] = field(default_factory=dict)
|
||||
db_connected: asyncio.Future = field(default_factory=asyncio.Future)
|
||||
db_connected: asyncio.Future[bool] = field(default_factory=asyncio.Future)
|
||||
|
||||
|
||||
def async_migration_in_progress(hass: HomeAssistant) -> bool:
|
||||
|
@ -40,5 +43,4 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
|
|||
"""
|
||||
if DOMAIN not in hass.data:
|
||||
return False
|
||||
db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected
|
||||
return await db_connected
|
||||
return await hass.data[DOMAIN].db_connected
|
||||
|
|
|
@ -11,6 +11,7 @@ from homeassistant.const import ATTR_RESTORED, EVENT_HOMEASSISTANT_STOP
|
|||
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from . import start
|
||||
|
@ -20,7 +21,7 @@ from .frame import report
|
|||
from .json import JSONEncoder
|
||||
from .storage import Store
|
||||
|
||||
DATA_RESTORE_STATE = "restore_state"
|
||||
DATA_RESTORE_STATE: HassKey[RestoreStateData] = HassKey("restore_state")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -104,7 +105,7 @@ async def async_load(hass: HomeAssistant) -> None:
|
|||
@callback
|
||||
def async_get(hass: HomeAssistant) -> RestoreStateData:
|
||||
"""Get the restore state data helper."""
|
||||
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE])
|
||||
return hass.data[DATA_RESTORE_STATE]
|
||||
|
||||
|
||||
class RestoreStateData:
|
||||
|
|
|
@ -81,6 +81,7 @@ from homeassistant.core import (
|
|||
from homeassistant.util import slugify
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.dt import utcnow
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
|
||||
|
||||
from . import condition, config_validation as cv, service, template
|
||||
|
@ -133,9 +134,11 @@ DEFAULT_MAX_EXCEEDED = "WARNING"
|
|||
ATTR_CUR = "current"
|
||||
ATTR_MAX = "max"
|
||||
|
||||
DATA_SCRIPTS = "helpers.script"
|
||||
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints"
|
||||
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED = "helpers.script_not_allowed"
|
||||
DATA_SCRIPTS: HassKey[list[ScriptData]] = HassKey("helpers.script")
|
||||
DATA_SCRIPT_BREAKPOINTS: HassKey[dict[str, dict[str, set[str]]]] = HassKey(
|
||||
"helpers.script_breakpoints"
|
||||
)
|
||||
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED: HassKey[None] = HassKey("helpers.script_not_allowed")
|
||||
RUN_ID_ANY = "*"
|
||||
NODE_ANY = "*"
|
||||
|
||||
|
@ -158,6 +161,13 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
|||
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||
|
||||
|
||||
class ScriptData(TypedDict):
|
||||
"""Store data related to script instance."""
|
||||
|
||||
instance: Script
|
||||
started_before_shutdown: bool
|
||||
|
||||
|
||||
class ScriptStoppedError(Exception):
|
||||
"""Error to indicate that the script has been stopped."""
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ from homeassistant.exceptions import (
|
|||
)
|
||||
from homeassistant.loader import Integration, async_get_integrations, bind_hass
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.yaml import load_yaml_dict
|
||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
||||
|
||||
|
@ -74,8 +75,12 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
||||
SERVICE_DESCRIPTION_CACHE: HassKey[dict[tuple[str, str], dict[str, Any] | None]] = (
|
||||
HassKey("service_description_cache")
|
||||
)
|
||||
ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
|
||||
tuple[set[tuple[str, str]], dict[str, dict[str, Any]]]
|
||||
] = HassKey("all_service_descriptions_cache")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
@ -660,9 +665,7 @@ async def async_get_all_descriptions(
|
|||
hass: HomeAssistant,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Return descriptions (i.e. user documentation) for all service calls."""
|
||||
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = (
|
||||
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
)
|
||||
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
|
||||
# We don't mutate services here so we avoid calling
|
||||
# async_services which makes a copy of every services
|
||||
|
@ -686,7 +689,7 @@ async def async_get_all_descriptions(
|
|||
previous_all_services, previous_descriptions_cache = all_cache
|
||||
# If the services are the same, we can return the cache
|
||||
if previous_all_services == all_services:
|
||||
return previous_descriptions_cache # type: ignore[no-any-return]
|
||||
return previous_descriptions_cache
|
||||
|
||||
# Files we loaded for missing descriptions
|
||||
loaded: dict[str, JSON_TYPE] = {}
|
||||
|
@ -812,9 +815,7 @@ def async_set_service_schema(
|
|||
domain = domain.lower()
|
||||
service = service.lower()
|
||||
|
||||
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = (
|
||||
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
)
|
||||
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
|
||||
description = {
|
||||
"name": schema.get("name", ""),
|
||||
|
|
|
@ -7,9 +7,12 @@ import signal
|
|||
from homeassistant.const import RESTART_EXIT_CODE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
KEY_HA_STOP: HassKey[asyncio.Task[None]] = HassKey("homeassistant_stop")
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
|
@ -25,9 +28,7 @@ def async_register_signal_handling(hass: HomeAssistant) -> None:
|
|||
"""
|
||||
hass.loop.remove_signal_handler(signal.SIGTERM)
|
||||
hass.loop.remove_signal_handler(signal.SIGINT)
|
||||
hass.data["homeassistant_stop"] = asyncio.create_task(
|
||||
hass.async_stop(exit_code)
|
||||
)
|
||||
hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
|
||||
|
||||
try:
|
||||
hass.loop.add_signal_handler(signal.SIGTERM, async_signal_handle, 0)
|
||||
|
|
|
@ -32,6 +32,7 @@ from homeassistant.loader import bind_hass
|
|||
from homeassistant.util import json as json_util
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.file import WriteError
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import json as json_helper
|
||||
|
||||
|
@ -42,8 +43,8 @@ MAX_LOAD_CONCURRENTLY = 6
|
|||
STORAGE_DIR = ".storage"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STORAGE_SEMAPHORE = "storage_semaphore"
|
||||
STORAGE_MANAGER = "storage_manager"
|
||||
STORAGE_SEMAPHORE: HassKey[asyncio.Semaphore] = HassKey("storage_semaphore")
|
||||
STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
|
||||
|
||||
MANAGER_CLEANUP_DELAY = 60
|
||||
|
||||
|
|
|
@ -10,12 +10,15 @@ from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import astral
|
||||
import astral.location
|
||||
|
||||
DATA_LOCATION_CACHE = "astral_location_cache"
|
||||
DATA_LOCATION_CACHE: HassKey[
|
||||
dict[tuple[str, str, str, float, float], astral.location.Location]
|
||||
] = HassKey("astral_location_cache")
|
||||
|
||||
ELEVATION_AGNOSTIC_EVENTS = ("noon", "midnight")
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ from homeassistant.util import (
|
|||
slugify as slugify_util,
|
||||
)
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||
from homeassistant.util.read_only_dict import ReadOnlyDict
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
@ -99,9 +100,13 @@ _LOGGER = logging.getLogger(__name__)
|
|||
_SENTINEL = object()
|
||||
DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
_ENVIRONMENT = "template.environment"
|
||||
_ENVIRONMENT_LIMITED = "template.environment_limited"
|
||||
_ENVIRONMENT_STRICT = "template.environment_strict"
|
||||
_ENVIRONMENT: HassKey[TemplateEnvironment] = HassKey("template.environment")
|
||||
_ENVIRONMENT_LIMITED: HassKey[TemplateEnvironment] = HassKey(
|
||||
"template.environment_limited"
|
||||
)
|
||||
_ENVIRONMENT_STRICT: HassKey[TemplateEnvironment] = HassKey(
|
||||
"template.environment_strict"
|
||||
)
|
||||
_HASS_LOADER = "template.hass_loader"
|
||||
|
||||
_RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#")
|
||||
|
@ -511,8 +516,7 @@ class Template:
|
|||
wanted_env = _ENVIRONMENT_STRICT
|
||||
else:
|
||||
wanted_env = _ENVIRONMENT
|
||||
ret: TemplateEnvironment | None = self.hass.data.get(wanted_env)
|
||||
if ret is None:
|
||||
if (ret := self.hass.data.get(wanted_env)) is None:
|
||||
ret = self.hass.data[wanted_env] = TemplateEnvironment(
|
||||
self.hass, self._limited, self._strict, self._log_fn
|
||||
)
|
||||
|
|
|
@ -30,6 +30,7 @@ from homeassistant.core import (
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from .typing import ConfigType, TemplateVarsType
|
||||
|
||||
|
@ -42,7 +43,9 @@ _PLATFORM_ALIASES = {
|
|||
"time": "homeassistant",
|
||||
}
|
||||
|
||||
DATA_PLUGGABLE_ACTIONS = "pluggable_actions"
|
||||
DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey(
|
||||
"pluggable_actions"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProtocol(Protocol):
|
||||
|
@ -138,9 +141,8 @@ class PluggableAction:
|
|||
def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
|
||||
"""Return the pluggable actions registry."""
|
||||
if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
|
||||
return data # type: ignore[no-any-return]
|
||||
data = defaultdict(PluggableActionsEntry)
|
||||
hass.data[DATA_PLUGGABLE_ACTIONS] = data
|
||||
return data
|
||||
data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue