Make FlowHandler.context a typed dict (#126291)

* Make FlowHandler.context a typed dict

* Adjust typing

* Adjust typing

* Avoid calling ConfigFlowContext constructor in hot path
pull/127925/head
Erik Montnemery 2024-10-08 12:18:45 +02:00 committed by GitHub
parent 217165208b
commit d6ee10a543
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 175 additions and 99 deletions

View File

@ -12,7 +12,6 @@ from typing import Any, cast
import jwt
from homeassistant import data_entry_flow
from homeassistant.core import (
CALLBACK_TYPE,
HassJob,
@ -20,13 +19,14 @@ from homeassistant.core import (
HomeAssistant,
callback,
)
from homeassistant.data_entry_flow import FlowHandler, FlowManager, FlowResultType
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult
from .models import AuthFlowContext, AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
from .providers.homeassistant import HassAuthProvider
@ -98,7 +98,7 @@ async def auth_manager_from_config(
class AuthManagerFlowManager(
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]]
):
"""Manage authentication flows."""
@ -113,7 +113,7 @@ class AuthManagerFlowManager(
self,
handler_key: tuple[str, str],
*,
context: dict[str, Any] | None = None,
context: AuthFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> LoginFlow:
"""Create a login flow."""
@ -124,7 +124,7 @@ class AuthManagerFlowManager(
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
result: AuthFlowResult,
) -> AuthFlowResult:
"""Return a user as result of login flow.
@ -134,7 +134,7 @@ class AuthManagerFlowManager(
"""
flow = cast(LoginFlow, flow)
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
if result["type"] != FlowResultType.CREATE_ENTRY:
return result
# we got final result

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from datetime import datetime, timedelta
from ipaddress import IPv4Address, IPv6Address
import secrets
from typing import Any, NamedTuple
import uuid
@ -13,7 +14,7 @@ from attr.setters import validate
from propcache import cached_property
from homeassistant.const import __version__
from homeassistant.data_entry_flow import FlowResult
from homeassistant.data_entry_flow import FlowContext, FlowResult
from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl
@ -23,7 +24,16 @@ TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
AuthFlowResult = FlowResult[tuple[str, str]]
class AuthFlowContext(FlowContext, total=False):
"""Typed context dict for auth flow."""
credential_only: bool
ip_address: IPv4Address | IPv6Address
redirect_uri: str
AuthFlowResult = FlowResult[AuthFlowContext, tuple[str, str]]
@attr.s(slots=True)

View File

@ -10,9 +10,10 @@ from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant import requirements
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowHandler
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module
from homeassistant.util import dt as dt_util
@ -21,7 +22,14 @@ from homeassistant.util.hass_dict import HassKey
from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
from ..models import (
AuthFlowContext,
AuthFlowResult,
Credentials,
RefreshToken,
User,
UserMeta,
)
_LOGGER = logging.getLogger(__name__)
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
@ -97,7 +105,7 @@ class AuthProvider:
# Implement by extending class
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
@ -184,7 +192,7 @@ async def load_auth_provider_module(
return module
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
"""Handler for the login flow."""
_flow_result = AuthFlowResult

View File

@ -13,7 +13,7 @@ import voluptuous as vol
from homeassistant.const import CONF_COMMAND
from homeassistant.exceptions import HomeAssistantError
from ..models import AuthFlowResult, Credentials, UserMeta
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
CONF_ARGS = "args"
@ -59,7 +59,7 @@ class CommandLineAuthProvider(AuthProvider):
super().__init__(*args, **kwargs)
self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login."""
return CommandLineLoginFlow(self)

View File

@ -17,7 +17,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.storage import Store
from ..models import AuthFlowResult, Credentials, UserMeta
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
STORAGE_VERSION = 1
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load()
self.data = data
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login."""
return HassLoginFlow(self)

View File

@ -4,14 +4,14 @@ from __future__ import annotations
from collections.abc import Mapping
import hmac
from typing import Any, cast
from typing import cast
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from ..models import AuthFlowResult, Credentials, UserMeta
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
USER_SCHEMA = vol.Schema(
@ -36,7 +36,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login."""
return ExampleLoginFlow(self)

View File

@ -25,7 +25,13 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import is_cloud_connection
from .. import InvalidAuthError
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
from ..models import (
AuthFlowContext,
AuthFlowResult,
Credentials,
RefreshToken,
UserMeta,
)
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
type IPAddress = IPv4Address | IPv6Address
@ -98,7 +104,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA."""
return False
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login."""
assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address"))

View File

@ -80,7 +80,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import AuthFlowResult, Credentials
from homeassistant.auth.models import AuthFlowContext, AuthFlowResult, Credentials
from homeassistant.components import onboarding
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
@ -322,11 +322,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
try:
result = await self._flow_mgr.async_init(
handler,
context={
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user",
"redirect_uri": redirect_uri,
},
context=AuthFlowContext(
ip_address=ip_address(request.remote), # type: ignore[arg-type]
credential_only=data.get("type") == "link_user",
redirect_uri=redirect_uri,
),
)
except data_entry_flow.UnknownHandler:
return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND)

View File

@ -11,6 +11,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowContext
import homeassistant.helpers.config_validation as cv
from homeassistant.util.hass_dict import HassKey
@ -44,7 +45,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
self,
handler_key: str,
*,
context: dict[str, Any],
context: FlowContext | None,
data: dict[str, Any],
) -> data_entry_flow.FlowHandler:
"""Create a setup flow. handler is a mfa module."""

View File

@ -463,7 +463,7 @@ async def ignore_config_flow(
)
return
context = {"source": config_entries.SOURCE_IGNORE}
context = config_entries.ConfigFlowContext(source=config_entries.SOURCE_IGNORE)
if "discovery_key" in flow["context"]:
context["discovery_key"] = flow["context"]["discovery_key"]
await hass.config_entries.flow.async_init(

View File

@ -12,7 +12,13 @@ from homeassistant.components.homeassistant_hardware import (
firmware_config_flow,
silabs_multiprotocol_addon,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow
from homeassistant.config_entries import (
ConfigEntry,
ConfigEntryBaseFlow,
ConfigFlowContext,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.core import callback
from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant
@ -33,10 +39,10 @@ else:
TranslationPlaceholderProtocol = object
class SkyConnectTranslationMixin(TranslationPlaceholderProtocol):
class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol):
"""Translation placeholder mixin for Home Assistant SkyConnect."""
context: dict[str, Any]
context: ConfigFlowContext
def _get_translation_placeholders(self) -> dict[str, str]:
"""Shared translation placeholders."""

View File

@ -53,7 +53,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
self,
handler_key: str,
*,
context: dict[str, Any] | None = None,
context: data_entry_flow.FlowContext | None = None,
data: dict[str, Any] | None = None,
) -> RepairsFlow:
"""Create a flow. platform is a repairs module."""

View File

@ -378,7 +378,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
for flow in _config_entries.flow.async_progress_by_handler(
DOMAIN, include_uninitialized=True
):
context: dict[str, Any] = flow["context"]
context = flow["context"]
if context.get("source") != SOURCE_REAUTH:
continue
entry_id: str = context["entry_id"]

View File

@ -540,7 +540,9 @@ class ZeroconfDiscovery:
continue
matcher_domain = matcher[ATTR_DOMAIN]
context = {
# Create a type annotated regular dict since this is a hot path and creating
# a regular dict is slightly cheaper than calling ConfigFlowContext
context: config_entries.ConfigFlowContext = {
"source": config_entries.SOURCE_ZEROCONF,
}
if domain:

View File

@ -29,6 +29,7 @@ from homeassistant.config_entries import (
ConfigEntryBaseFlow,
ConfigEntryState,
ConfigFlow,
ConfigFlowContext,
ConfigFlowResult,
OptionsFlow,
OptionsFlowManager,
@ -192,7 +193,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property
@abstractmethod
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
def flow_manager(self) -> FlowManager[ConfigFlowContext, ConfigFlowResult]:
"""Return the flow manager of the flow."""
async def async_step_install_addon(

View File

@ -41,7 +41,7 @@ from .core import (
HomeAssistant,
callback,
)
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowResult
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowContext, FlowResult
from .exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
@ -267,7 +267,19 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
}
class ConfigFlowResult(FlowResult, total=False):
class ConfigFlowContext(FlowContext, total=False):
"""Typed context dict for config flow."""
alternative_domain: str
configuration_url: str
confirm_only: bool
discovery_key: DiscoveryKey
entry_id: str
title_placeholders: Mapping[str, str]
unique_id: str | None
class ConfigFlowResult(FlowResult[ConfigFlowContext, str], total=False):
"""Typed result dict for config flow."""
minor_version: int
@ -1026,7 +1038,7 @@ class ConfigEntry(Generic[_DataT]):
def async_start_reauth(
self,
hass: HomeAssistant,
context: dict[str, Any] | None = None,
context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
@ -1044,7 +1056,7 @@ class ConfigEntry(Generic[_DataT]):
async def _async_init_reauth(
self,
hass: HomeAssistant,
context: dict[str, Any] | None = None,
context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
@ -1056,12 +1068,12 @@ class ConfigEntry(Generic[_DataT]):
return
result = await hass.config_entries.flow.async_init(
self.domain,
context={
"source": SOURCE_REAUTH,
"entry_id": self.entry_id,
"title_placeholders": {"name": self.title},
"unique_id": self.unique_id,
}
context=ConfigFlowContext(
source=SOURCE_REAUTH,
entry_id=self.entry_id,
title_placeholders={"name": self.title},
unique_id=self.unique_id,
)
| (context or {}),
data=self.data | (data or {}),
)
@ -1086,7 +1098,7 @@ class ConfigEntry(Generic[_DataT]):
def async_start_reconfigure(
self,
hass: HomeAssistant,
context: dict[str, Any] | None = None,
context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> None:
"""Start a reconfigure flow."""
@ -1103,7 +1115,7 @@ class ConfigEntry(Generic[_DataT]):
async def _async_init_reconfigure(
self,
hass: HomeAssistant,
context: dict[str, Any] | None = None,
context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> None:
"""Start a reconfigure flow."""
@ -1115,12 +1127,12 @@ class ConfigEntry(Generic[_DataT]):
return
await hass.config_entries.flow.async_init(
self.domain,
context={
"source": SOURCE_RECONFIGURE,
"entry_id": self.entry_id,
"title_placeholders": {"name": self.title},
"unique_id": self.unique_id,
}
context=ConfigFlowContext(
source=SOURCE_RECONFIGURE,
entry_id=self.entry_id,
title_placeholders={"name": self.title},
unique_id=self.unique_id,
)
| (context or {}),
data=self.data | (data or {}),
)
@ -1214,7 +1226,9 @@ def _report_non_awaited_platform_forwards(entry: ConfigEntry, what: str) -> None
)
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
class ConfigEntriesFlowManager(
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
):
"""Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
@ -1260,7 +1274,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return False
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
self,
handler: str,
*,
context: ConfigFlowContext | None = None,
data: Any = None,
) -> ConfigFlowResult:
"""Start a configuration flow."""
if not context or "source" not in context:
@ -1319,7 +1337,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
self,
flow_id: str,
handler: str,
context: dict,
context: ConfigFlowContext,
data: Any,
) -> tuple[ConfigFlow, ConfigFlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
@ -1357,7 +1375,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish a config flow and add an entry.
@ -1504,7 +1522,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result
async def async_create_flow(
self, handler_key: str, *, context: dict | None = None, data: Any = None
self,
handler_key: str,
*,
context: ConfigFlowContext | None = None,
data: Any = None,
) -> ConfigFlow:
"""Create a flow for specified handler.
@ -1522,7 +1544,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_post_init(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult,
) -> None:
"""After a flow is initialised trigger new flow notifications."""
@ -1560,7 +1582,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
@callback
def async_has_matching_discovery_flow(
self, handler: str, match_context: dict[str, Any], data: Any
self, handler: str, match_context: ConfigFlowContext, data: Any
) -> bool:
"""Check if an existing matching discovery flow is in progress.
@ -2385,7 +2407,9 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
class ConfigEntryBaseFlow(
data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
@ -2406,7 +2430,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
if not self.context:
return None
return cast(str | None, self.context.get("unique_id"))
return self.context.get("unique_id")
@staticmethod
@callback
@ -2779,7 +2803,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
"""Return reauth entry id."""
if self.source != SOURCE_REAUTH:
raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}")
return self.context["entry_id"] # type: ignore[no-any-return]
return self.context["entry_id"]
@callback
def _get_reauth_entry(self) -> ConfigEntry:
@ -2793,7 +2817,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
"""Return reconfigure entry id."""
if self.source != SOURCE_RECONFIGURE:
raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}")
return self.context["entry_id"] # type: ignore[no-any-return]
return self.context["entry_id"]
@callback
def _get_reconfigure_entry(self) -> ConfigEntry:
@ -2805,7 +2829,9 @@ class ConfigFlow(ConfigEntryBaseFlow):
raise UnknownEntry
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
class OptionsFlowManager(
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
):
"""Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
@ -2822,7 +2848,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
self,
handler_key: str,
*,
context: dict[str, Any] | None = None,
context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> OptionsFlow:
"""Create an options flow for a config entry.
@ -2835,7 +2861,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry.
@ -2860,7 +2886,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result
async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
) -> None:
"""Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler)

View File

@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
_FlowResultT = TypeVar(
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
)
_HandlerT = TypeVar("_HandlerT", default=str)
@ -139,10 +142,17 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
class FlowContext(TypedDict, total=False):
"""Typed context dict."""
show_advanced_options: bool
source: str
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
"""Typed result dict."""
context: dict[str, Any]
context: _FlowContextT
data_schema: vol.Schema | None
data: Mapping[str, Any]
description_placeholders: Mapping[str, str | None] | None
@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
self._progress: dict[
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
] = {}
self._handler_progress_index: defaultdict[
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
_HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set)
self._init_data_process_index: defaultdict[
type, set[FlowHandler[_FlowResultT, _HandlerT]]
type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set)
@abc.abstractmethod
@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self,
handler_key: _HandlerT,
*,
context: dict[str, Any] | None = None,
context: _FlowContextT | None = None,
data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT, _HandlerT]:
) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> _FlowResultT:
"""Finish a data entry flow.
@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""
async def async_post_init(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> None:
"""Entry has finished executing its first step asynchronously."""
@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_progress_by_handler(
self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self,
handler: _HandlerT,
*,
context: dict[str, Any] | None = None,
context: _FlowContextT | None = None,
data: Any = None,
) -> _FlowResultT:
"""Start a data entry flow."""
if context is None:
context = {}
context = cast(_FlowContextT, {})
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
async def _async_handle_step(
self,
flow: FlowHandler[_FlowResultT, _HandlerT],
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
return result
def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
)
async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_flow_handler_to_flow_result(
self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
include_uninitialized: bool,
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
]
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Handle a data entry flow."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
hass: HomeAssistant = None # type: ignore[assignment]
handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
# Set by _async_create_flow callback
init_step = "init"
@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
@property
def source(self) -> str | None:
"""Source that initialized the flow."""
return self.context.get("source", None) # type: ignore[no-any-return]
return self.context.get("source", None) # type: ignore[return-value]
@property
def show_advanced_options(self) -> bool:
"""If we should show advanced options."""
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return]
return self.context.get("show_advanced_options", False) # type: ignore[return-value]
def add_suggested_values_to_schema(
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None

View File

@ -18,7 +18,7 @@ from . import config_validation as cv
_FlowManagerT = TypeVar(
"_FlowManagerT",
bound=data_entry_flow.FlowManager[Any],
bound=data_entry_flow.FlowManager[Any, Any],
default=data_entry_flow.FlowManager,
)

View File

@ -13,7 +13,7 @@ from homeassistant.util.async_ import gather_with_limited_concurrency
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigFlowResult
from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult
FLOW_INIT_LIMIT = 20
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
@ -42,7 +42,7 @@ class DiscoveryKey:
def async_create_flow(
hass: HomeAssistant,
domain: str,
context: dict[str, Any],
context: ConfigFlowContext,
data: Any,
*,
discovery_key: DiscoveryKey | None = None,
@ -70,7 +70,7 @@ def async_create_flow(
@callback
def _async_init_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any
) -> Coroutine[None, None, ConfigFlowResult] | None:
"""Create a discovery flow."""
# Avoid spawning flows that have the same initial discovery data
@ -98,7 +98,7 @@ class PendingFlowKey(NamedTuple):
class PendingFlowValue(NamedTuple):
"""Value for pending flows."""
context: dict[str, Any]
context: ConfigFlowContext
data: Any
@ -137,7 +137,7 @@ class FlowDispatcher:
await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)
@callback
def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None:
def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None:
"""Create and add or queue a flow."""
key = PendingFlowKey(domain, context["source"])
values = PendingFlowValue(context, data)