From 8f614fb06d1d101f442c22c19f9aa0a62fa8aee9 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 May 2024 18:24:13 +0200 Subject: [PATCH] Use HassKey for helpers (2) (#117013) --- homeassistant/auth/mfa_modules/__init__.py | 3 ++- homeassistant/auth/providers/__init__.py | 3 ++- .../components/homeassistant/__init__.py | 3 ++- homeassistant/helpers/integration_platform.py | 8 +++++--- homeassistant/helpers/recorder.py | 10 ++++++---- homeassistant/helpers/restore_state.py | 5 +++-- homeassistant/helpers/script.py | 16 +++++++++++++--- homeassistant/helpers/service.py | 19 ++++++++++--------- homeassistant/helpers/signal.py | 7 ++++--- homeassistant/helpers/storage.py | 5 +++-- homeassistant/helpers/sun.py | 5 ++++- homeassistant/helpers/template.py | 14 +++++++++----- homeassistant/helpers/trigger.py | 10 ++++++---- 13 files changed, 69 insertions(+), 39 deletions(-) diff --git a/homeassistant/auth/mfa_modules/__init__.py b/homeassistant/auth/mfa_modules/__init__.py index fd4072ea88a..d57a274c7ff 100644 --- a/homeassistant/auth/mfa_modules/__init__.py +++ b/homeassistant/auth/mfa_modules/__init__.py @@ -16,6 +16,7 @@ from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.importlib import async_import_module from homeassistant.util.decorator import Registry +from homeassistant.util.hass_dict import HassKey MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry() @@ -29,7 +30,7 @@ MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema( extra=vol.ALLOW_EXTRA, ) -DATA_REQS = "mfa_auth_module_reqs_processed" +DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed") _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 63028f54d2e..debdd0b1a05 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -17,13 +17,14 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.importlib import async_import_module from homeassistant.util import dt as dt_util from homeassistant.util.decorator import Registry +from homeassistant.util.hass_dict import HassKey from ..auth_store import AuthStore from ..const import MFA_SESSION_EXPIRATION from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta _LOGGER = logging.getLogger(__name__) -DATA_REQS = "auth_prov_reqs_processed" +DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed") AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry() diff --git a/homeassistant/components/homeassistant/__init__.py b/homeassistant/components/homeassistant/__init__.py index 6d32f175a8a..cc948fcc663 100644 --- a/homeassistant/components/homeassistant/__init__.py +++ b/homeassistant/components/homeassistant/__init__.py @@ -32,6 +32,7 @@ from homeassistant.helpers.service import ( async_extract_referenced_entity_ids, async_register_admin_service, ) +from homeassistant.helpers.signal import KEY_HA_STOP from homeassistant.helpers.template import async_load_custom_templates from homeassistant.helpers.typing import ConfigType @@ -386,7 +387,7 @@ async def _async_stop(hass: ha.HomeAssistant, restart: bool) -> None: """Stop home assistant.""" exit_code = RESTART_EXIT_CODE if restart else 0 # Track trask in hass.data. No need to cleanup, we're stopping. - hass.data["homeassistant_stop"] = asyncio.create_task(hass.async_stop(exit_code)) + hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code)) @ha.callback diff --git a/homeassistant/helpers/integration_platform.py b/homeassistant/helpers/integration_platform.py index fbd26019b64..a3eb19657e8 100644 --- a/homeassistant/helpers/integration_platform.py +++ b/homeassistant/helpers/integration_platform.py @@ -20,10 +20,13 @@ from homeassistant.loader import ( bind_hass, ) from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded +from homeassistant.util.hass_dict import HassKey from homeassistant.util.logging import catch_log_exception _LOGGER = logging.getLogger(__name__) -DATA_INTEGRATION_PLATFORMS = "integration_platforms" +DATA_INTEGRATION_PLATFORMS: HassKey[list[IntegrationPlatform]] = HassKey( + "integration_platforms" +) @dataclass(slots=True, frozen=True) @@ -160,8 +163,7 @@ async def async_process_integration_platforms( ) -> None: """Process a specific platform for all current and future loaded integrations.""" if DATA_INTEGRATION_PLATFORMS not in hass.data: - integration_platforms: list[IntegrationPlatform] = [] - hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms + integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] = [] hass.bus.async_listen( EVENT_COMPONENT_LOADED, partial( diff --git a/homeassistant/helpers/recorder.py b/homeassistant/helpers/recorder.py index 74ebbe5c67a..6155fc9b320 100644 --- a/homeassistant/helpers/recorder.py +++ b/homeassistant/helpers/recorder.py @@ -1,12 +1,15 @@ """Helpers to check recorder.""" +from __future__ import annotations + import asyncio from dataclasses import dataclass, field from typing import Any from homeassistant.core import HomeAssistant, callback +from homeassistant.util.hass_dict import HassKey -DOMAIN = "recorder" +DOMAIN: HassKey[RecorderData] = HassKey("recorder") @dataclass(slots=True) @@ -14,7 +17,7 @@ class RecorderData: """Recorder data stored in hass.data.""" recorder_platforms: dict[str, Any] = field(default_factory=dict) - db_connected: asyncio.Future = field(default_factory=asyncio.Future) + db_connected: asyncio.Future[bool] = field(default_factory=asyncio.Future) def async_migration_in_progress(hass: HomeAssistant) -> bool: @@ -40,5 +43,4 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool: """ if DOMAIN not in hass.data: return False - db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected - return await db_connected + return await hass.data[DOMAIN].db_connected diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index 2b3afc2f57b..cf492ab38bd 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -11,6 +11,7 @@ from homeassistant.const import ATTR_RESTORED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant, State, callback, valid_entity_id from homeassistant.exceptions import HomeAssistantError import homeassistant.util.dt as dt_util +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import json_loads from . import start @@ -20,7 +21,7 @@ from .frame import report from .json import JSONEncoder from .storage import Store -DATA_RESTORE_STATE = "restore_state" +DATA_RESTORE_STATE: HassKey[RestoreStateData] = HassKey("restore_state") _LOGGER = logging.getLogger(__name__) @@ -104,7 +105,7 @@ async def async_load(hass: HomeAssistant) -> None: @callback def async_get(hass: HomeAssistant) -> RestoreStateData: """Get the restore state data helper.""" - return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE]) + return hass.data[DATA_RESTORE_STATE] class RestoreStateData: diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index c246597cb07..cc5027b9f21 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -81,6 +81,7 @@ from homeassistant.core import ( from homeassistant.util import slugify from homeassistant.util.async_ import create_eager_task from homeassistant.util.dt import utcnow +from homeassistant.util.hass_dict import HassKey from homeassistant.util.signal_type import SignalType, SignalTypeFormat from . import condition, config_validation as cv, service, template @@ -133,9 +134,11 @@ DEFAULT_MAX_EXCEEDED = "WARNING" ATTR_CUR = "current" ATTR_MAX = "max" -DATA_SCRIPTS = "helpers.script" -DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints" -DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED = "helpers.script_not_allowed" +DATA_SCRIPTS: HassKey[list[ScriptData]] = HassKey("helpers.script") +DATA_SCRIPT_BREAKPOINTS: HassKey[dict[str, dict[str, set[str]]]] = HassKey( + "helpers.script_breakpoints" +) +DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED: HassKey[None] = HassKey("helpers.script_not_allowed") RUN_ID_ANY = "*" NODE_ANY = "*" @@ -158,6 +161,13 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) +class ScriptData(TypedDict): + """Store data related to script instance.""" + + instance: Script + started_before_shutdown: bool + + class ScriptStoppedError(Exception): """Error to indicate that the script has been stopped.""" diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 66c9f7db3e6..1f3d59e761c 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -47,6 +47,7 @@ from homeassistant.exceptions import ( ) from homeassistant.loader import Integration, async_get_integrations, bind_hass from homeassistant.util.async_ import create_eager_task +from homeassistant.util.hass_dict import HassKey from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml.loader import JSON_TYPE @@ -74,8 +75,12 @@ CONF_SERVICE_ENTITY_ID = "entity_id" _LOGGER = logging.getLogger(__name__) -SERVICE_DESCRIPTION_CACHE = "service_description_cache" -ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache" +SERVICE_DESCRIPTION_CACHE: HassKey[dict[tuple[str, str], dict[str, Any] | None]] = ( + HassKey("service_description_cache") +) +ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[ + tuple[set[tuple[str, str]], dict[str, dict[str, Any]]] +] = HassKey("all_service_descriptions_cache") _T = TypeVar("_T") @@ -660,9 +665,7 @@ async def async_get_all_descriptions( hass: HomeAssistant, ) -> dict[str, dict[str, Any]]: """Return descriptions (i.e. user documentation) for all service calls.""" - descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = ( - hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) - ) + descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) # We don't mutate services here so we avoid calling # async_services which makes a copy of every services @@ -686,7 +689,7 @@ async def async_get_all_descriptions( previous_all_services, previous_descriptions_cache = all_cache # If the services are the same, we can return the cache if previous_all_services == all_services: - return previous_descriptions_cache # type: ignore[no-any-return] + return previous_descriptions_cache # Files we loaded for missing descriptions loaded: dict[str, JSON_TYPE] = {} @@ -812,9 +815,7 @@ def async_set_service_schema( domain = domain.lower() service = service.lower() - descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = ( - hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) - ) + descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) description = { "name": schema.get("name", ""), diff --git a/homeassistant/helpers/signal.py b/homeassistant/helpers/signal.py index baaa36e83ce..4a4b9bead47 100644 --- a/homeassistant/helpers/signal.py +++ b/homeassistant/helpers/signal.py @@ -7,9 +7,12 @@ import signal from homeassistant.const import RESTART_EXIT_CODE from homeassistant.core import HomeAssistant, callback from homeassistant.loader import bind_hass +from homeassistant.util.hass_dict import HassKey _LOGGER = logging.getLogger(__name__) +KEY_HA_STOP: HassKey[asyncio.Task[None]] = HassKey("homeassistant_stop") + @callback @bind_hass @@ -25,9 +28,7 @@ def async_register_signal_handling(hass: HomeAssistant) -> None: """ hass.loop.remove_signal_handler(signal.SIGTERM) hass.loop.remove_signal_handler(signal.SIGINT) - hass.data["homeassistant_stop"] = asyncio.create_task( - hass.async_stop(exit_code) - ) + hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code)) try: hass.loop.add_signal_handler(signal.SIGTERM, async_signal_handle, 0) diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 1013115fd01..41c8cc32fd0 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -32,6 +32,7 @@ from homeassistant.loader import bind_hass from homeassistant.util import json as json_util import homeassistant.util.dt as dt_util from homeassistant.util.file import WriteError +from homeassistant.util.hass_dict import HassKey from . import json as json_helper @@ -42,8 +43,8 @@ MAX_LOAD_CONCURRENTLY = 6 STORAGE_DIR = ".storage" _LOGGER = logging.getLogger(__name__) -STORAGE_SEMAPHORE = "storage_semaphore" -STORAGE_MANAGER = "storage_manager" +STORAGE_SEMAPHORE: HassKey[asyncio.Semaphore] = HassKey("storage_semaphore") +STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager") MANAGER_CLEANUP_DELAY = 60 diff --git a/homeassistant/helpers/sun.py b/homeassistant/helpers/sun.py index a490a7a8213..82f78cd10e2 100644 --- a/homeassistant/helpers/sun.py +++ b/homeassistant/helpers/sun.py @@ -10,12 +10,15 @@ from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET from homeassistant.core import HomeAssistant, callback from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util +from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: import astral import astral.location -DATA_LOCATION_CACHE = "astral_location_cache" +DATA_LOCATION_CACHE: HassKey[ + dict[tuple[str, str, str, float, float], astral.location.Location] +] = HassKey("astral_location_cache") ELEVATION_AGNOSTIC_EVENTS = ("noon", "midnight") diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 9e4f116e546..de264760ff5 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -76,6 +76,7 @@ from homeassistant.util import ( slugify as slugify_util, ) from homeassistant.util.async_ import run_callback_threadsafe +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads from homeassistant.util.read_only_dict import ReadOnlyDict from homeassistant.util.thread import ThreadWithException @@ -99,9 +100,13 @@ _LOGGER = logging.getLogger(__name__) _SENTINEL = object() DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" -_ENVIRONMENT = "template.environment" -_ENVIRONMENT_LIMITED = "template.environment_limited" -_ENVIRONMENT_STRICT = "template.environment_strict" +_ENVIRONMENT: HassKey[TemplateEnvironment] = HassKey("template.environment") +_ENVIRONMENT_LIMITED: HassKey[TemplateEnvironment] = HassKey( + "template.environment_limited" +) +_ENVIRONMENT_STRICT: HassKey[TemplateEnvironment] = HassKey( + "template.environment_strict" +) _HASS_LOADER = "template.hass_loader" _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#") @@ -511,8 +516,7 @@ class Template: wanted_env = _ENVIRONMENT_STRICT else: wanted_env = _ENVIRONMENT - ret: TemplateEnvironment | None = self.hass.data.get(wanted_env) - if ret is None: + if (ret := self.hass.data.get(wanted_env)) is None: ret = self.hass.data[wanted_env] = TemplateEnvironment( self.hass, self._limited, self._strict, self._log_fn ) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index cb14102cb04..5c2b372bb7d 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -30,6 +30,7 @@ from homeassistant.core import ( from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.util.async_ import create_eager_task +from homeassistant.util.hass_dict import HassKey from .typing import ConfigType, TemplateVarsType @@ -42,7 +43,9 @@ _PLATFORM_ALIASES = { "time": "homeassistant", } -DATA_PLUGGABLE_ACTIONS = "pluggable_actions" +DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey( + "pluggable_actions" +) class TriggerProtocol(Protocol): @@ -138,9 +141,8 @@ class PluggableAction: def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]: """Return the pluggable actions registry.""" if data := hass.data.get(DATA_PLUGGABLE_ACTIONS): - return data # type: ignore[no-any-return] - data = defaultdict(PluggableActionsEntry) - hass.data[DATA_PLUGGABLE_ACTIONS] = data + return data + data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry) return data @staticmethod