Use HassKey for helpers (2) (#117013)

pull/117024/head
Marc Mueller 2024-05-07 18:24:13 +02:00 committed by GitHub
parent c50a340cbc
commit 8f614fb06d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 69 additions and 39 deletions

View File

@ -16,6 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module
from homeassistant.util.decorator import Registry
from homeassistant.util.hass_dict import HassKey
MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
@ -29,7 +30,7 @@ MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA,
)
DATA_REQS = "mfa_auth_module_reqs_processed"
DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
_LOGGER = logging.getLogger(__name__)

View File

@ -17,13 +17,14 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
from homeassistant.util.hass_dict import HassKey
from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed"
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()

View File

@ -32,6 +32,7 @@ from homeassistant.helpers.service import (
async_extract_referenced_entity_ids,
async_register_admin_service,
)
from homeassistant.helpers.signal import KEY_HA_STOP
from homeassistant.helpers.template import async_load_custom_templates
from homeassistant.helpers.typing import ConfigType
@ -386,7 +387,7 @@ async def _async_stop(hass: ha.HomeAssistant, restart: bool) -> None:
"""Stop home assistant."""
exit_code = RESTART_EXIT_CODE if restart else 0
# Track trask in hass.data. No need to cleanup, we're stopping.
hass.data["homeassistant_stop"] = asyncio.create_task(hass.async_stop(exit_code))
hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
@ha.callback

View File

@ -20,10 +20,13 @@ from homeassistant.loader import (
bind_hass,
)
from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.logging import catch_log_exception
_LOGGER = logging.getLogger(__name__)
DATA_INTEGRATION_PLATFORMS = "integration_platforms"
DATA_INTEGRATION_PLATFORMS: HassKey[list[IntegrationPlatform]] = HassKey(
"integration_platforms"
)
@dataclass(slots=True, frozen=True)
@ -160,8 +163,7 @@ async def async_process_integration_platforms(
) -> None:
"""Process a specific platform for all current and future loaded integrations."""
if DATA_INTEGRATION_PLATFORMS not in hass.data:
integration_platforms: list[IntegrationPlatform] = []
hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] = []
hass.bus.async_listen(
EVENT_COMPONENT_LOADED,
partial(

View File

@ -1,12 +1,15 @@
"""Helpers to check recorder."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.hass_dict import HassKey
DOMAIN = "recorder"
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
@dataclass(slots=True)
@ -14,7 +17,7 @@ class RecorderData:
"""Recorder data stored in hass.data."""
recorder_platforms: dict[str, Any] = field(default_factory=dict)
db_connected: asyncio.Future = field(default_factory=asyncio.Future)
db_connected: asyncio.Future[bool] = field(default_factory=asyncio.Future)
def async_migration_in_progress(hass: HomeAssistant) -> bool:
@ -40,5 +43,4 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
"""
if DOMAIN not in hass.data:
return False
db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected
return await db_connected
return await hass.data[DOMAIN].db_connected

View File

@ -11,6 +11,7 @@ from homeassistant.const import ATTR_RESTORED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError
import homeassistant.util.dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import json_loads
from . import start
@ -20,7 +21,7 @@ from .frame import report
from .json import JSONEncoder
from .storage import Store
DATA_RESTORE_STATE = "restore_state"
DATA_RESTORE_STATE: HassKey[RestoreStateData] = HassKey("restore_state")
_LOGGER = logging.getLogger(__name__)
@ -104,7 +105,7 @@ async def async_load(hass: HomeAssistant) -> None:
@callback
def async_get(hass: HomeAssistant) -> RestoreStateData:
"""Get the restore state data helper."""
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE])
return hass.data[DATA_RESTORE_STATE]
class RestoreStateData:

View File

@ -81,6 +81,7 @@ from homeassistant.core import (
from homeassistant.util import slugify
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.dt import utcnow
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
from . import condition, config_validation as cv, service, template
@ -133,9 +134,11 @@ DEFAULT_MAX_EXCEEDED = "WARNING"
ATTR_CUR = "current"
ATTR_MAX = "max"
DATA_SCRIPTS = "helpers.script"
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints"
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED = "helpers.script_not_allowed"
DATA_SCRIPTS: HassKey[list[ScriptData]] = HassKey("helpers.script")
DATA_SCRIPT_BREAKPOINTS: HassKey[dict[str, dict[str, set[str]]]] = HassKey(
"helpers.script_breakpoints"
)
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED: HassKey[None] = HassKey("helpers.script_not_allowed")
RUN_ID_ANY = "*"
NODE_ANY = "*"
@ -158,6 +161,13 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
class ScriptData(TypedDict):
"""Store data related to script instance."""
instance: Script
started_before_shutdown: bool
class ScriptStoppedError(Exception):
"""Error to indicate that the script has been stopped."""

View File

@ -47,6 +47,7 @@ from homeassistant.exceptions import (
)
from homeassistant.loader import Integration, async_get_integrations, bind_hass
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
@ -74,8 +75,12 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
SERVICE_DESCRIPTION_CACHE: HassKey[dict[tuple[str, str], dict[str, Any] | None]] = (
HassKey("service_description_cache")
)
ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[set[tuple[str, str]], dict[str, dict[str, Any]]]
] = HassKey("all_service_descriptions_cache")
_T = TypeVar("_T")
@ -660,9 +665,7 @@ async def async_get_all_descriptions(
hass: HomeAssistant,
) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = (
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
)
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
# We don't mutate services here so we avoid calling
# async_services which makes a copy of every services
@ -686,7 +689,7 @@ async def async_get_all_descriptions(
previous_all_services, previous_descriptions_cache = all_cache
# If the services are the same, we can return the cache
if previous_all_services == all_services:
return previous_descriptions_cache # type: ignore[no-any-return]
return previous_descriptions_cache
# Files we loaded for missing descriptions
loaded: dict[str, JSON_TYPE] = {}
@ -812,9 +815,7 @@ def async_set_service_schema(
domain = domain.lower()
service = service.lower()
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = (
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
)
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
description = {
"name": schema.get("name", ""),

View File

@ -7,9 +7,12 @@ import signal
from homeassistant.const import RESTART_EXIT_CODE
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
_LOGGER = logging.getLogger(__name__)
KEY_HA_STOP: HassKey[asyncio.Task[None]] = HassKey("homeassistant_stop")
@callback
@bind_hass
@ -25,9 +28,7 @@ def async_register_signal_handling(hass: HomeAssistant) -> None:
"""
hass.loop.remove_signal_handler(signal.SIGTERM)
hass.loop.remove_signal_handler(signal.SIGINT)
hass.data["homeassistant_stop"] = asyncio.create_task(
hass.async_stop(exit_code)
)
hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
try:
hass.loop.add_signal_handler(signal.SIGTERM, async_signal_handle, 0)

View File

@ -32,6 +32,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import json as json_util
import homeassistant.util.dt as dt_util
from homeassistant.util.file import WriteError
from homeassistant.util.hass_dict import HassKey
from . import json as json_helper
@ -42,8 +43,8 @@ MAX_LOAD_CONCURRENTLY = 6
STORAGE_DIR = ".storage"
_LOGGER = logging.getLogger(__name__)
STORAGE_SEMAPHORE = "storage_semaphore"
STORAGE_MANAGER = "storage_manager"
STORAGE_SEMAPHORE: HassKey[asyncio.Semaphore] = HassKey("storage_semaphore")
STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
MANAGER_CLEANUP_DELAY = 60

View File

@ -10,12 +10,15 @@ from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
import astral
import astral.location
DATA_LOCATION_CACHE = "astral_location_cache"
DATA_LOCATION_CACHE: HassKey[
dict[tuple[str, str, str, float, float], astral.location.Location]
] = HassKey("astral_location_cache")
ELEVATION_AGNOSTIC_EVENTS = ("noon", "midnight")

View File

@ -76,6 +76,7 @@ from homeassistant.util import (
slugify as slugify_util,
)
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
from homeassistant.util.read_only_dict import ReadOnlyDict
from homeassistant.util.thread import ThreadWithException
@ -99,9 +100,13 @@ _LOGGER = logging.getLogger(__name__)
_SENTINEL = object()
DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S"
_ENVIRONMENT = "template.environment"
_ENVIRONMENT_LIMITED = "template.environment_limited"
_ENVIRONMENT_STRICT = "template.environment_strict"
_ENVIRONMENT: HassKey[TemplateEnvironment] = HassKey("template.environment")
_ENVIRONMENT_LIMITED: HassKey[TemplateEnvironment] = HassKey(
"template.environment_limited"
)
_ENVIRONMENT_STRICT: HassKey[TemplateEnvironment] = HassKey(
"template.environment_strict"
)
_HASS_LOADER = "template.hass_loader"
_RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#")
@ -511,8 +516,7 @@ class Template:
wanted_env = _ENVIRONMENT_STRICT
else:
wanted_env = _ENVIRONMENT
ret: TemplateEnvironment | None = self.hass.data.get(wanted_env)
if ret is None:
if (ret := self.hass.data.get(wanted_env)) is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment(
self.hass, self._limited, self._strict, self._log_fn
)

View File

@ -30,6 +30,7 @@ from homeassistant.core import (
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from .typing import ConfigType, TemplateVarsType
@ -42,7 +43,9 @@ _PLATFORM_ALIASES = {
"time": "homeassistant",
}
DATA_PLUGGABLE_ACTIONS = "pluggable_actions"
DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey(
"pluggable_actions"
)
class TriggerProtocol(Protocol):
@ -138,9 +141,8 @@ class PluggableAction:
def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
"""Return the pluggable actions registry."""
if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
return data # type: ignore[no-any-return]
data = defaultdict(PluggableActionsEntry)
hass.data[DATA_PLUGGABLE_ACTIONS] = data
return data
data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry)
return data
@staticmethod