From d6ee10a543eba11de1f2c962a288112588586f2e Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 8 Oct 2024 12:18:45 +0200 Subject: [PATCH] Make FlowHandler.context a typed dict (#126291) * Make FlowHandler.context a typed dict * Adjust typing * Adjust typing * Avoid calling ConfigFlowContext constructor in hot path --- homeassistant/auth/__init__.py | 12 +-- homeassistant/auth/models.py | 14 ++- homeassistant/auth/providers/__init__.py | 16 +++- homeassistant/auth/providers/command_line.py | 4 +- homeassistant/auth/providers/homeassistant.py | 4 +- .../auth/providers/insecure_example.py | 6 +- .../auth/providers/trusted_networks.py | 10 +- homeassistant/components/auth/login_flow.py | 12 +-- .../components/auth/mfa_setup_flow.py | 3 +- .../components/config/config_entries.py | 2 +- .../homeassistant_sky_connect/config_flow.py | 12 ++- .../components/repairs/issue_handler.py | 2 +- .../components/tplink/config_flow.py | 2 +- homeassistant/components/zeroconf/__init__.py | 4 +- .../components/zwave_js/config_flow.py | 3 +- homeassistant/config_entries.py | 92 ++++++++++++------- homeassistant/data_entry_flow.py | 64 ++++++++----- homeassistant/helpers/data_entry_flow.py | 2 +- homeassistant/helpers/discovery_flow.py | 10 +- 19 files changed, 175 insertions(+), 99 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 19045406a15..21a4b6113d0 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -12,7 +12,6 @@ from typing import Any, cast import jwt -from homeassistant import data_entry_flow from homeassistant.core import ( CALLBACK_TYPE, HassJob, @@ -20,13 +19,14 @@ from homeassistant.core import ( HomeAssistant, callback, ) +from homeassistant.data_entry_flow import FlowHandler, FlowManager, FlowResultType from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.util import dt as dt_util from . import auth_store, jwt_wrapper, models from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config -from .models import AuthFlowResult +from .models import AuthFlowContext, AuthFlowResult from .providers import AuthProvider, LoginFlow, auth_provider_from_config from .providers.homeassistant import HassAuthProvider @@ -98,7 +98,7 @@ async def auth_manager_from_config( class AuthManagerFlowManager( - data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]] + FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]] ): """Manage authentication flows.""" @@ -113,7 +113,7 @@ class AuthManagerFlowManager( self, handler_key: tuple[str, str], *, - context: dict[str, Any] | None = None, + context: AuthFlowContext | None = None, data: dict[str, Any] | None = None, ) -> LoginFlow: """Create a login flow.""" @@ -124,7 +124,7 @@ class AuthManagerFlowManager( async def async_finish_flow( self, - flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]], + flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]], result: AuthFlowResult, ) -> AuthFlowResult: """Return a user as result of login flow. @@ -134,7 +134,7 @@ class AuthManagerFlowManager( """ flow = cast(LoginFlow, flow) - if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: + if result["type"] != FlowResultType.CREATE_ENTRY: return result # we got final result diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 0b6515ed9a5..6f45dab2b36 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from ipaddress import IPv4Address, IPv6Address import secrets from typing import Any, NamedTuple import uuid @@ -13,7 +14,7 @@ from attr.setters import validate from propcache import cached_property from homeassistant.const import __version__ -from homeassistant.data_entry_flow import FlowResult +from homeassistant.data_entry_flow import FlowContext, FlowResult from homeassistant.util import dt as dt_util from . import permissions as perm_mdl @@ -23,7 +24,16 @@ TOKEN_TYPE_NORMAL = "normal" TOKEN_TYPE_SYSTEM = "system" TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token" -AuthFlowResult = FlowResult[tuple[str, str]] + +class AuthFlowContext(FlowContext, total=False): + """Typed context dict for auth flow.""" + + credential_only: bool + ip_address: IPv4Address | IPv6Address + redirect_uri: str + + +AuthFlowResult = FlowResult[AuthFlowContext, tuple[str, str]] @attr.s(slots=True) diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index debdd0b1a05..34278c47df7 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -10,9 +10,10 @@ from typing import Any import voluptuous as vol from voluptuous.humanize import humanize_error -from homeassistant import data_entry_flow, requirements +from homeassistant import requirements from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.core import HomeAssistant, callback +from homeassistant.data_entry_flow import FlowHandler from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.importlib import async_import_module from homeassistant.util import dt as dt_util @@ -21,7 +22,14 @@ from homeassistant.util.hass_dict import HassKey from ..auth_store import AuthStore from ..const import MFA_SESSION_EXPIRATION -from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta +from ..models import ( + AuthFlowContext, + AuthFlowResult, + Credentials, + RefreshToken, + User, + UserMeta, +) _LOGGER = logging.getLogger(__name__) DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed") @@ -97,7 +105,7 @@ class AuthProvider: # Implement by extending class - async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: """Return the data flow for logging in with auth provider. Auth provider should extend LoginFlow and return an instance. @@ -184,7 +192,7 @@ async def load_auth_provider_module( return module -class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]): +class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]): """Handler for the login flow.""" _flow_result = AuthFlowResult diff --git a/homeassistant/auth/providers/command_line.py b/homeassistant/auth/providers/command_line.py index 43cde284a25..12447bc8c18 100644 --- a/homeassistant/auth/providers/command_line.py +++ b/homeassistant/auth/providers/command_line.py @@ -13,7 +13,7 @@ import voluptuous as vol from homeassistant.const import CONF_COMMAND from homeassistant.exceptions import HomeAssistantError -from ..models import AuthFlowResult, Credentials, UserMeta +from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow CONF_ARGS = "args" @@ -59,7 +59,7 @@ class CommandLineAuthProvider(AuthProvider): super().__init__(*args, **kwargs) self._user_meta: dict[str, dict[str, Any]] = {} - async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: """Return a flow to login.""" return CommandLineLoginFlow(self) diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index ec39bdbdcdc..e5dded74762 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -17,7 +17,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.storage import Store -from ..models import AuthFlowResult, Credentials, UserMeta +from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow STORAGE_VERSION = 1 @@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider): await data.async_load() self.data = data - async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: """Return a flow to login.""" return HassLoginFlow(self) diff --git a/homeassistant/auth/providers/insecure_example.py b/homeassistant/auth/providers/insecure_example.py index 8bcf7569f5a..a7dced851a3 100644 --- a/homeassistant/auth/providers/insecure_example.py +++ b/homeassistant/auth/providers/insecure_example.py @@ -4,14 +4,14 @@ from __future__ import annotations from collections.abc import Mapping import hmac -from typing import Any, cast +from typing import cast import voluptuous as vol from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError -from ..models import AuthFlowResult, Credentials, UserMeta +from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow USER_SCHEMA = vol.Schema( @@ -36,7 +36,7 @@ class InvalidAuthError(HomeAssistantError): class ExampleAuthProvider(AuthProvider): """Example auth provider based on hardcoded usernames and passwords.""" - async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: """Return a flow to login.""" return ExampleLoginFlow(self) diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index 564633073fc..f32c35d4bd5 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -25,7 +25,13 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.network import is_cloud_connection from .. import InvalidAuthError -from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta +from ..models import ( + AuthFlowContext, + AuthFlowResult, + Credentials, + RefreshToken, + UserMeta, +) from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow type IPAddress = IPv4Address | IPv6Address @@ -98,7 +104,7 @@ class TrustedNetworksAuthProvider(AuthProvider): """Trusted Networks auth provider does not support MFA.""" return False - async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: """Return a flow to login.""" assert context is not None ip_addr = cast(IPAddress, context.get("ip_address")) diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index 3664c3ca5c9..d27235123b9 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -80,7 +80,7 @@ import voluptuous_serialize from homeassistant import data_entry_flow from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError -from homeassistant.auth.models import AuthFlowResult, Credentials +from homeassistant.auth.models import AuthFlowContext, AuthFlowResult, Credentials from homeassistant.components import onboarding from homeassistant.components.http import KEY_HASS from homeassistant.components.http.auth import async_user_not_allowed_do_auth @@ -322,11 +322,11 @@ class LoginFlowIndexView(LoginFlowBaseView): try: result = await self._flow_mgr.async_init( handler, - context={ - "ip_address": ip_address(request.remote), # type: ignore[arg-type] - "credential_only": data.get("type") == "link_user", - "redirect_uri": redirect_uri, - }, + context=AuthFlowContext( + ip_address=ip_address(request.remote), # type: ignore[arg-type] + credential_only=data.get("type") == "link_user", + redirect_uri=redirect_uri, + ), ) except data_entry_flow.UnknownHandler: return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND) diff --git a/homeassistant/components/auth/mfa_setup_flow.py b/homeassistant/components/auth/mfa_setup_flow.py index 34787894c8c..c9efb081a01 100644 --- a/homeassistant/components/auth/mfa_setup_flow.py +++ b/homeassistant/components/auth/mfa_setup_flow.py @@ -11,6 +11,7 @@ import voluptuous_serialize from homeassistant import data_entry_flow from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback +from homeassistant.data_entry_flow import FlowContext import homeassistant.helpers.config_validation as cv from homeassistant.util.hass_dict import HassKey @@ -44,7 +45,7 @@ class MfaFlowManager(data_entry_flow.FlowManager): self, handler_key: str, *, - context: dict[str, Any], + context: FlowContext | None, data: dict[str, Any], ) -> data_entry_flow.FlowHandler: """Create a setup flow. handler is a mfa module.""" diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 9149ffe98e1..da50f7e93a1 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -463,7 +463,7 @@ async def ignore_config_flow( ) return - context = {"source": config_entries.SOURCE_IGNORE} + context = config_entries.ConfigFlowContext(source=config_entries.SOURCE_IGNORE) if "discovery_key" in flow["context"]: context["discovery_key"] = flow["context"]["discovery_key"] await hass.config_entries.flow.async_init( diff --git a/homeassistant/components/homeassistant_sky_connect/config_flow.py b/homeassistant/components/homeassistant_sky_connect/config_flow.py index b1776624736..5c35732312b 100644 --- a/homeassistant/components/homeassistant_sky_connect/config_flow.py +++ b/homeassistant/components/homeassistant_sky_connect/config_flow.py @@ -12,7 +12,13 @@ from homeassistant.components.homeassistant_hardware import ( firmware_config_flow, silabs_multiprotocol_addon, ) -from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow +from homeassistant.config_entries import ( + ConfigEntry, + ConfigEntryBaseFlow, + ConfigFlowContext, + ConfigFlowResult, + OptionsFlow, +) from homeassistant.core import callback from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant @@ -33,10 +39,10 @@ else: TranslationPlaceholderProtocol = object -class SkyConnectTranslationMixin(TranslationPlaceholderProtocol): +class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol): """Translation placeholder mixin for Home Assistant SkyConnect.""" - context: dict[str, Any] + context: ConfigFlowContext def _get_translation_placeholders(self) -> dict[str, str]: """Shared translation placeholders.""" diff --git a/homeassistant/components/repairs/issue_handler.py b/homeassistant/components/repairs/issue_handler.py index b0b3f82a5d6..cc7e017699d 100644 --- a/homeassistant/components/repairs/issue_handler.py +++ b/homeassistant/components/repairs/issue_handler.py @@ -53,7 +53,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager): self, handler_key: str, *, - context: dict[str, Any] | None = None, + context: data_entry_flow.FlowContext | None = None, data: dict[str, Any] | None = None, ) -> RepairsFlow: """Create a flow. platform is a repairs module.""" diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index ae7543218c7..e94cf9558f0 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -378,7 +378,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): for flow in _config_entries.flow.async_progress_by_handler( DOMAIN, include_uninitialized=True ): - context: dict[str, Any] = flow["context"] + context = flow["context"] if context.get("source") != SOURCE_REAUTH: continue entry_id: str = context["entry_id"] diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index b0a78a1ff88..449c2ccef91 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -540,7 +540,9 @@ class ZeroconfDiscovery: continue matcher_domain = matcher[ATTR_DOMAIN] - context = { + # Create a type annotated regular dict since this is a hot path and creating + # a regular dict is slightly cheaper than calling ConfigFlowContext + context: config_entries.ConfigFlowContext = { "source": config_entries.SOURCE_ZEROCONF, } if domain: diff --git a/homeassistant/components/zwave_js/config_flow.py b/homeassistant/components/zwave_js/config_flow.py index 7733e0325ec..5668f90f4c5 100644 --- a/homeassistant/components/zwave_js/config_flow.py +++ b/homeassistant/components/zwave_js/config_flow.py @@ -29,6 +29,7 @@ from homeassistant.config_entries import ( ConfigEntryBaseFlow, ConfigEntryState, ConfigFlow, + ConfigFlowContext, ConfigFlowResult, OptionsFlow, OptionsFlowManager, @@ -192,7 +193,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC): @property @abstractmethod - def flow_manager(self) -> FlowManager[ConfigFlowResult]: + def flow_manager(self) -> FlowManager[ConfigFlowContext, ConfigFlowResult]: """Return the flow manager of the flow.""" async def async_step_install_addon( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index a7b1b3b8d77..c4ead1bbf0d 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -41,7 +41,7 @@ from .core import ( HomeAssistant, callback, ) -from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowResult +from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowContext, FlowResult from .exceptions import ( ConfigEntryAuthFailed, ConfigEntryError, @@ -267,7 +267,19 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = { } -class ConfigFlowResult(FlowResult, total=False): +class ConfigFlowContext(FlowContext, total=False): + """Typed context dict for config flow.""" + + alternative_domain: str + configuration_url: str + confirm_only: bool + discovery_key: DiscoveryKey + entry_id: str + title_placeholders: Mapping[str, str] + unique_id: str | None + + +class ConfigFlowResult(FlowResult[ConfigFlowContext, str], total=False): """Typed result dict for config flow.""" minor_version: int @@ -1026,7 +1038,7 @@ class ConfigEntry(Generic[_DataT]): def async_start_reauth( self, hass: HomeAssistant, - context: dict[str, Any] | None = None, + context: ConfigFlowContext | None = None, data: dict[str, Any] | None = None, ) -> None: """Start a reauth flow.""" @@ -1044,7 +1056,7 @@ class ConfigEntry(Generic[_DataT]): async def _async_init_reauth( self, hass: HomeAssistant, - context: dict[str, Any] | None = None, + context: ConfigFlowContext | None = None, data: dict[str, Any] | None = None, ) -> None: """Start a reauth flow.""" @@ -1056,12 +1068,12 @@ class ConfigEntry(Generic[_DataT]): return result = await hass.config_entries.flow.async_init( self.domain, - context={ - "source": SOURCE_REAUTH, - "entry_id": self.entry_id, - "title_placeholders": {"name": self.title}, - "unique_id": self.unique_id, - } + context=ConfigFlowContext( + source=SOURCE_REAUTH, + entry_id=self.entry_id, + title_placeholders={"name": self.title}, + unique_id=self.unique_id, + ) | (context or {}), data=self.data | (data or {}), ) @@ -1086,7 +1098,7 @@ class ConfigEntry(Generic[_DataT]): def async_start_reconfigure( self, hass: HomeAssistant, - context: dict[str, Any] | None = None, + context: ConfigFlowContext | None = None, data: dict[str, Any] | None = None, ) -> None: """Start a reconfigure flow.""" @@ -1103,7 +1115,7 @@ class ConfigEntry(Generic[_DataT]): async def _async_init_reconfigure( self, hass: HomeAssistant, - context: dict[str, Any] | None = None, + context: ConfigFlowContext | None = None, data: dict[str, Any] | None = None, ) -> None: """Start a reconfigure flow.""" @@ -1115,12 +1127,12 @@ class ConfigEntry(Generic[_DataT]): return await hass.config_entries.flow.async_init( self.domain, - context={ - "source": SOURCE_RECONFIGURE, - "entry_id": self.entry_id, - "title_placeholders": {"name": self.title}, - "unique_id": self.unique_id, - } + context=ConfigFlowContext( + source=SOURCE_RECONFIGURE, + entry_id=self.entry_id, + title_placeholders={"name": self.title}, + unique_id=self.unique_id, + ) | (context or {}), data=self.data | (data or {}), ) @@ -1214,7 +1226,9 @@ def _report_non_awaited_platform_forwards(entry: ConfigEntry, what: str) -> None ) -class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): +class ConfigEntriesFlowManager( + data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult] +): """Manage all the config entry flows that are in progress.""" _flow_result = ConfigFlowResult @@ -1260,7 +1274,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): return False async def async_init( - self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None + self, + handler: str, + *, + context: ConfigFlowContext | None = None, + data: Any = None, ) -> ConfigFlowResult: """Start a configuration flow.""" if not context or "source" not in context: @@ -1319,7 +1337,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): self, flow_id: str, handler: str, - context: dict, + context: ConfigFlowContext, data: Any, ) -> tuple[ConfigFlow, ConfigFlowResult]: """Run the init in a task to allow it to be canceled at shutdown.""" @@ -1357,7 +1375,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_finish_flow( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], result: ConfigFlowResult, ) -> ConfigFlowResult: """Finish a config flow and add an entry. @@ -1504,7 +1522,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): return result async def async_create_flow( - self, handler_key: str, *, context: dict | None = None, data: Any = None + self, + handler_key: str, + *, + context: ConfigFlowContext | None = None, + data: Any = None, ) -> ConfigFlow: """Create a flow for specified handler. @@ -1522,7 +1544,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_post_init( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], result: ConfigFlowResult, ) -> None: """After a flow is initialised trigger new flow notifications.""" @@ -1560,7 +1582,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): @callback def async_has_matching_discovery_flow( - self, handler: str, match_context: dict[str, Any], data: Any + self, handler: str, match_context: ConfigFlowContext, data: Any ) -> bool: """Check if an existing matching discovery flow is in progress. @@ -2385,7 +2407,9 @@ def _async_abort_entries_match( raise data_entry_flow.AbortFlow("already_configured") -class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]): +class ConfigEntryBaseFlow( + data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult] +): """Base class for config and option flows.""" _flow_result = ConfigFlowResult @@ -2406,7 +2430,7 @@ class ConfigFlow(ConfigEntryBaseFlow): if not self.context: return None - return cast(str | None, self.context.get("unique_id")) + return self.context.get("unique_id") @staticmethod @callback @@ -2779,7 +2803,7 @@ class ConfigFlow(ConfigEntryBaseFlow): """Return reauth entry id.""" if self.source != SOURCE_REAUTH: raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}") - return self.context["entry_id"] # type: ignore[no-any-return] + return self.context["entry_id"] @callback def _get_reauth_entry(self) -> ConfigEntry: @@ -2793,7 +2817,7 @@ class ConfigFlow(ConfigEntryBaseFlow): """Return reconfigure entry id.""" if self.source != SOURCE_RECONFIGURE: raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}") - return self.context["entry_id"] # type: ignore[no-any-return] + return self.context["entry_id"] @callback def _get_reconfigure_entry(self) -> ConfigEntry: @@ -2805,7 +2829,9 @@ class ConfigFlow(ConfigEntryBaseFlow): raise UnknownEntry -class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): +class OptionsFlowManager( + data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult] +): """Flow to set options for a configuration entry.""" _flow_result = ConfigFlowResult @@ -2822,7 +2848,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): self, handler_key: str, *, - context: dict[str, Any] | None = None, + context: ConfigFlowContext | None = None, data: dict[str, Any] | None = None, ) -> OptionsFlow: """Create an options flow for a config entry. @@ -2835,7 +2861,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_finish_flow( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], result: ConfigFlowResult, ) -> ConfigFlowResult: """Finish an options flow and update options for configuration entry. @@ -2860,7 +2886,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): return result async def _async_setup_preview( - self, flow: data_entry_flow.FlowHandler[ConfigFlowResult] + self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult] ) -> None: """Set up preview for an option flow handler.""" entry = self._async_get_config_entry(flow.handler) diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index de08a178a70..1fb6439a8c4 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = { } -_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult") +_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext") +_FlowResultT = TypeVar( + "_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult" +) _HandlerT = TypeVar("_HandlerT", default=str) @@ -139,10 +142,17 @@ class AbortFlow(FlowError): self.description_placeholders = description_placeholders -class FlowResult(TypedDict, Generic[_HandlerT], total=False): +class FlowContext(TypedDict, total=False): + """Typed context dict.""" + + show_advanced_options: bool + source: str + + +class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False): """Typed result dict.""" - context: dict[str, Any] + context: _FlowContextT data_schema: vol.Schema | None data: Mapping[str, Any] description_placeholders: Mapping[str, str | None] | None @@ -189,7 +199,7 @@ def _map_error_to_schema_errors( schema_errors[path_part_str] = error.error_message -class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): +class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]): """Manage all the flows that are in progress.""" _flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment] @@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): """Initialize the flow manager.""" self.hass = hass self._preview: set[_HandlerT] = set() - self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {} + self._progress: dict[ + str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT] + ] = {} self._handler_progress_index: defaultdict[ - _HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]] + _HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]] ] = defaultdict(set) self._init_data_process_index: defaultdict[ - type, set[FlowHandler[_FlowResultT, _HandlerT]] + type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]] ] = defaultdict(set) @abc.abstractmethod @@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): self, handler_key: _HandlerT, *, - context: dict[str, Any] | None = None, + context: _FlowContextT | None = None, data: dict[str, Any] | None = None, - ) -> FlowHandler[_FlowResultT, _HandlerT]: + ) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]: """Create a flow for specified handler. Handler key is the domain of the component that we want to set up. @@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): @abc.abstractmethod async def async_finish_flow( - self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT + self, + flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], + result: _FlowResultT, ) -> _FlowResultT: """Finish a data entry flow. @@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): """ async def async_post_init( - self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT + self, + flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], + result: _FlowResultT, ) -> None: """Entry has finished executing its first step asynchronously.""" @@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): @callback def _async_progress_by_handler( self, handler: _HandlerT, match_context: dict[str, Any] | None - ) -> list[FlowHandler[_FlowResultT, _HandlerT]]: + ) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]: """Return the flows in progress by handler. If match_context is specified, only return flows with a context that @@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): self, handler: _HandlerT, *, - context: dict[str, Any] | None = None, + context: _FlowContextT | None = None, data: Any = None, ) -> _FlowResultT: """Start a data entry flow.""" if context is None: - context = {} + context = cast(_FlowContextT, {}) flow = await self.async_create_flow(handler, context=context, data=data) if not flow: raise UnknownFlow("Flow was not created") @@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): @callback def _async_add_flow_progress( - self, flow: FlowHandler[_FlowResultT, _HandlerT] + self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT] ) -> None: """Add a flow to in progress.""" if flow.init_data is not None: @@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): @callback def _async_remove_flow_from_index( - self, flow: FlowHandler[_FlowResultT, _HandlerT] + self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT] ) -> None: """Remove a flow from in progress.""" if flow.init_data is not None: @@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): async def _async_handle_step( self, - flow: FlowHandler[_FlowResultT, _HandlerT], + flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str, user_input: dict | BaseServiceInfo | None, ) -> _FlowResultT: @@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): return result def _raise_if_step_does_not_exist( - self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str + self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str ) -> None: """Raise if the step does not exist.""" method = f"async_step_{step_id}" @@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): ) async def _async_setup_preview( - self, flow: FlowHandler[_FlowResultT, _HandlerT] + self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT] ) -> None: """Set up preview for a flow handler.""" if flow.handler not in self._preview: @@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): @callback def _async_flow_handler_to_flow_result( self, - flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]], + flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]], include_uninitialized: bool, ) -> list[_FlowResultT]: """Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" @@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): ] -class FlowHandler(Generic[_FlowResultT, _HandlerT]): +class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]): """Handle a data entry flow.""" _flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment] @@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]): hass: HomeAssistant = None # type: ignore[assignment] handler: _HandlerT = None # type: ignore[assignment] # Ensure the attribute has a subscriptable, but immutable, default value. - context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment] + context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment] # Set by _async_create_flow callback init_step = "init" @@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]): @property def source(self) -> str | None: """Source that initialized the flow.""" - return self.context.get("source", None) # type: ignore[no-any-return] + return self.context.get("source", None) # type: ignore[return-value] @property def show_advanced_options(self) -> bool: """If we should show advanced options.""" - return self.context.get("show_advanced_options", False) # type: ignore[no-any-return] + return self.context.get("show_advanced_options", False) # type: ignore[return-value] def add_suggested_values_to_schema( self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index b2cad292e3d..adb2062a8ea 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -18,7 +18,7 @@ from . import config_validation as cv _FlowManagerT = TypeVar( "_FlowManagerT", - bound=data_entry_flow.FlowManager[Any], + bound=data_entry_flow.FlowManager[Any, Any], default=data_entry_flow.FlowManager, ) diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py index e6596a496e0..fd41c7ffb44 100644 --- a/homeassistant/helpers/discovery_flow.py +++ b/homeassistant/helpers/discovery_flow.py @@ -13,7 +13,7 @@ from homeassistant.util.async_ import gather_with_limited_concurrency from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: - from homeassistant.config_entries import ConfigFlowResult + from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult FLOW_INIT_LIMIT = 20 DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey( @@ -42,7 +42,7 @@ class DiscoveryKey: def async_create_flow( hass: HomeAssistant, domain: str, - context: dict[str, Any], + context: ConfigFlowContext, data: Any, *, discovery_key: DiscoveryKey | None = None, @@ -70,7 +70,7 @@ def async_create_flow( @callback def _async_init_flow( - hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any + hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any ) -> Coroutine[None, None, ConfigFlowResult] | None: """Create a discovery flow.""" # Avoid spawning flows that have the same initial discovery data @@ -98,7 +98,7 @@ class PendingFlowKey(NamedTuple): class PendingFlowValue(NamedTuple): """Value for pending flows.""" - context: dict[str, Any] + context: ConfigFlowContext data: Any @@ -137,7 +137,7 @@ class FlowDispatcher: await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros) @callback - def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None: + def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None: """Create and add or queue a flow.""" key = PendingFlowKey(domain, context["source"]) values = PendingFlowValue(context, data)