Make FlowHandler.context a typed dict (#126291)
* Make FlowHandler.context a typed dict * Adjust typing * Adjust typing * Avoid calling ConfigFlowContext constructor in hot pathpull/127925/head
parent
217165208b
commit
d6ee10a543
|
@ -12,7 +12,6 @@ from typing import Any, cast
|
|||
|
||||
import jwt
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
HassJob,
|
||||
|
@ -20,13 +19,14 @@ from homeassistant.core import (
|
|||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowHandler, FlowManager, FlowResultType
|
||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import auth_store, jwt_wrapper, models
|
||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .models import AuthFlowResult
|
||||
from .models import AuthFlowContext, AuthFlowResult
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
from .providers.homeassistant import HassAuthProvider
|
||||
|
||||
|
@ -98,7 +98,7 @@ async def auth_manager_from_config(
|
|||
|
||||
|
||||
class AuthManagerFlowManager(
|
||||
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
|
||||
FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]]
|
||||
):
|
||||
"""Manage authentication flows."""
|
||||
|
||||
|
@ -113,7 +113,7 @@ class AuthManagerFlowManager(
|
|||
self,
|
||||
handler_key: tuple[str, str],
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: AuthFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> LoginFlow:
|
||||
"""Create a login flow."""
|
||||
|
@ -124,7 +124,7 @@ class AuthManagerFlowManager(
|
|||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
|
||||
flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
|
||||
result: AuthFlowResult,
|
||||
) -> AuthFlowResult:
|
||||
"""Return a user as result of login flow.
|
||||
|
@ -134,7 +134,7 @@ class AuthManagerFlowManager(
|
|||
"""
|
||||
flow = cast(LoginFlow, flow)
|
||||
|
||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
if result["type"] != FlowResultType.CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# we got final result
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
import secrets
|
||||
from typing import Any, NamedTuple
|
||||
import uuid
|
||||
|
@ -13,7 +14,7 @@ from attr.setters import validate
|
|||
from propcache import cached_property
|
||||
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.data_entry_flow import FlowContext, FlowResult
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import permissions as perm_mdl
|
||||
|
@ -23,7 +24,16 @@ TOKEN_TYPE_NORMAL = "normal"
|
|||
TOKEN_TYPE_SYSTEM = "system"
|
||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
||||
|
||||
AuthFlowResult = FlowResult[tuple[str, str]]
|
||||
|
||||
class AuthFlowContext(FlowContext, total=False):
|
||||
"""Typed context dict for auth flow."""
|
||||
|
||||
credential_only: bool
|
||||
ip_address: IPv4Address | IPv6Address
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
AuthFlowResult = FlowResult[AuthFlowContext, tuple[str, str]]
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
|
|
@ -10,9 +10,10 @@ from typing import Any
|
|||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant import requirements
|
||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowHandler
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.importlib import async_import_module
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -21,7 +22,14 @@ from homeassistant.util.hass_dict import HassKey
|
|||
|
||||
from ..auth_store import AuthStore
|
||||
from ..const import MFA_SESSION_EXPIRATION
|
||||
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
|
||||
from ..models import (
|
||||
AuthFlowContext,
|
||||
AuthFlowResult,
|
||||
Credentials,
|
||||
RefreshToken,
|
||||
User,
|
||||
UserMeta,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
|
||||
|
@ -97,7 +105,7 @@ class AuthProvider:
|
|||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
"""Return the data flow for logging in with auth provider.
|
||||
|
||||
Auth provider should extend LoginFlow and return an instance.
|
||||
|
@ -184,7 +192,7 @@ async def load_auth_provider_module(
|
|||
return module
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
|
||||
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
_flow_result = AuthFlowResult
|
||||
|
|
|
@ -13,7 +13,7 @@ import voluptuous as vol
|
|||
from homeassistant.const import CONF_COMMAND
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
CONF_ARGS = "args"
|
||||
|
@ -59,7 +59,7 @@ class CommandLineAuthProvider(AuthProvider):
|
|||
super().__init__(*args, **kwargs)
|
||||
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return CommandLineLoginFlow(self)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.helpers.storage import Store
|
||||
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
|
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
|
|||
await data.async_load()
|
||||
self.data = data
|
||||
|
||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return HassLoginFlow(self)
|
||||
|
||||
|
|
|
@ -4,14 +4,14 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Mapping
|
||||
import hmac
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
USER_SCHEMA = vol.Schema(
|
||||
|
@ -36,7 +36,7 @@ class InvalidAuthError(HomeAssistantError):
|
|||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return ExampleLoginFlow(self)
|
||||
|
||||
|
|
|
@ -25,7 +25,13 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.network import is_cloud_connection
|
||||
|
||||
from .. import InvalidAuthError
|
||||
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
|
||||
from ..models import (
|
||||
AuthFlowContext,
|
||||
AuthFlowResult,
|
||||
Credentials,
|
||||
RefreshToken,
|
||||
UserMeta,
|
||||
)
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
type IPAddress = IPv4Address | IPv6Address
|
||||
|
@ -98,7 +104,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
"""Trusted Networks auth provider does not support MFA."""
|
||||
return False
|
||||
|
||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||
|
|
|
@ -80,7 +80,7 @@ import voluptuous_serialize
|
|||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
||||
from homeassistant.auth.models import AuthFlowResult, Credentials
|
||||
from homeassistant.auth.models import AuthFlowContext, AuthFlowResult, Credentials
|
||||
from homeassistant.components import onboarding
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||
|
@ -322,11 +322,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler,
|
||||
context={
|
||||
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
|
||||
"credential_only": data.get("type") == "link_user",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
context=AuthFlowContext(
|
||||
ip_address=ip_address(request.remote), # type: ignore[arg-type]
|
||||
credential_only=data.get("type") == "link_user",
|
||||
redirect_uri=redirect_uri,
|
||||
),
|
||||
)
|
||||
except data_entry_flow.UnknownHandler:
|
||||
return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND)
|
||||
|
|
|
@ -11,6 +11,7 @@ import voluptuous_serialize
|
|||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowContext
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
|
@ -44,7 +45,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
|
|||
self,
|
||||
handler_key: str,
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
context: FlowContext | None,
|
||||
data: dict[str, Any],
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
"""Create a setup flow. handler is a mfa module."""
|
||||
|
|
|
@ -463,7 +463,7 @@ async def ignore_config_flow(
|
|||
)
|
||||
return
|
||||
|
||||
context = {"source": config_entries.SOURCE_IGNORE}
|
||||
context = config_entries.ConfigFlowContext(source=config_entries.SOURCE_IGNORE)
|
||||
if "discovery_key" in flow["context"]:
|
||||
context["discovery_key"] = flow["context"]["discovery_key"]
|
||||
await hass.config_entries.flow.async_init(
|
||||
|
|
|
@ -12,7 +12,13 @@ from homeassistant.components.homeassistant_hardware import (
|
|||
firmware_config_flow,
|
||||
silabs_multiprotocol_addon,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow
|
||||
from homeassistant.config_entries import (
|
||||
ConfigEntry,
|
||||
ConfigEntryBaseFlow,
|
||||
ConfigFlowContext,
|
||||
ConfigFlowResult,
|
||||
OptionsFlow,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant
|
||||
|
@ -33,10 +39,10 @@ else:
|
|||
TranslationPlaceholderProtocol = object
|
||||
|
||||
|
||||
class SkyConnectTranslationMixin(TranslationPlaceholderProtocol):
|
||||
class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol):
|
||||
"""Translation placeholder mixin for Home Assistant SkyConnect."""
|
||||
|
||||
context: dict[str, Any]
|
||||
context: ConfigFlowContext
|
||||
|
||||
def _get_translation_placeholders(self) -> dict[str, str]:
|
||||
"""Shared translation placeholders."""
|
||||
|
|
|
@ -53,7 +53,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
|
|||
self,
|
||||
handler_key: str,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: data_entry_flow.FlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> RepairsFlow:
|
||||
"""Create a flow. platform is a repairs module."""
|
||||
|
|
|
@ -378,7 +378,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
for flow in _config_entries.flow.async_progress_by_handler(
|
||||
DOMAIN, include_uninitialized=True
|
||||
):
|
||||
context: dict[str, Any] = flow["context"]
|
||||
context = flow["context"]
|
||||
if context.get("source") != SOURCE_REAUTH:
|
||||
continue
|
||||
entry_id: str = context["entry_id"]
|
||||
|
|
|
@ -540,7 +540,9 @@ class ZeroconfDiscovery:
|
|||
continue
|
||||
|
||||
matcher_domain = matcher[ATTR_DOMAIN]
|
||||
context = {
|
||||
# Create a type annotated regular dict since this is a hot path and creating
|
||||
# a regular dict is slightly cheaper than calling ConfigFlowContext
|
||||
context: config_entries.ConfigFlowContext = {
|
||||
"source": config_entries.SOURCE_ZEROCONF,
|
||||
}
|
||||
if domain:
|
||||
|
|
|
@ -29,6 +29,7 @@ from homeassistant.config_entries import (
|
|||
ConfigEntryBaseFlow,
|
||||
ConfigEntryState,
|
||||
ConfigFlow,
|
||||
ConfigFlowContext,
|
||||
ConfigFlowResult,
|
||||
OptionsFlow,
|
||||
OptionsFlowManager,
|
||||
|
@ -192,7 +193,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
|
||||
def flow_manager(self) -> FlowManager[ConfigFlowContext, ConfigFlowResult]:
|
||||
"""Return the flow manager of the flow."""
|
||||
|
||||
async def async_step_install_addon(
|
||||
|
|
|
@ -41,7 +41,7 @@ from .core import (
|
|||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowResult
|
||||
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowContext, FlowResult
|
||||
from .exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
ConfigEntryError,
|
||||
|
@ -267,7 +267,19 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
|
|||
}
|
||||
|
||||
|
||||
class ConfigFlowResult(FlowResult, total=False):
|
||||
class ConfigFlowContext(FlowContext, total=False):
|
||||
"""Typed context dict for config flow."""
|
||||
|
||||
alternative_domain: str
|
||||
configuration_url: str
|
||||
confirm_only: bool
|
||||
discovery_key: DiscoveryKey
|
||||
entry_id: str
|
||||
title_placeholders: Mapping[str, str]
|
||||
unique_id: str | None
|
||||
|
||||
|
||||
class ConfigFlowResult(FlowResult[ConfigFlowContext, str], total=False):
|
||||
"""Typed result dict for config flow."""
|
||||
|
||||
minor_version: int
|
||||
|
@ -1026,7 +1038,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
def async_start_reauth(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reauth flow."""
|
||||
|
@ -1044,7 +1056,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
async def _async_init_reauth(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reauth flow."""
|
||||
|
@ -1056,12 +1068,12 @@ class ConfigEntry(Generic[_DataT]):
|
|||
return
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
self.domain,
|
||||
context={
|
||||
"source": SOURCE_REAUTH,
|
||||
"entry_id": self.entry_id,
|
||||
"title_placeholders": {"name": self.title},
|
||||
"unique_id": self.unique_id,
|
||||
}
|
||||
context=ConfigFlowContext(
|
||||
source=SOURCE_REAUTH,
|
||||
entry_id=self.entry_id,
|
||||
title_placeholders={"name": self.title},
|
||||
unique_id=self.unique_id,
|
||||
)
|
||||
| (context or {}),
|
||||
data=self.data | (data or {}),
|
||||
)
|
||||
|
@ -1086,7 +1098,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
def async_start_reconfigure(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reconfigure flow."""
|
||||
|
@ -1103,7 +1115,7 @@ class ConfigEntry(Generic[_DataT]):
|
|||
async def _async_init_reconfigure(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reconfigure flow."""
|
||||
|
@ -1115,12 +1127,12 @@ class ConfigEntry(Generic[_DataT]):
|
|||
return
|
||||
await hass.config_entries.flow.async_init(
|
||||
self.domain,
|
||||
context={
|
||||
"source": SOURCE_RECONFIGURE,
|
||||
"entry_id": self.entry_id,
|
||||
"title_placeholders": {"name": self.title},
|
||||
"unique_id": self.unique_id,
|
||||
}
|
||||
context=ConfigFlowContext(
|
||||
source=SOURCE_RECONFIGURE,
|
||||
entry_id=self.entry_id,
|
||||
title_placeholders={"name": self.title},
|
||||
unique_id=self.unique_id,
|
||||
)
|
||||
| (context or {}),
|
||||
data=self.data | (data or {}),
|
||||
)
|
||||
|
@ -1214,7 +1226,9 @@ def _report_non_awaited_platform_forwards(entry: ConfigEntry, what: str) -> None
|
|||
)
|
||||
|
||||
|
||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||
class ConfigEntriesFlowManager(
|
||||
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
|
||||
):
|
||||
"""Manage all the config entry flows that are in progress."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -1260,7 +1274,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
return False
|
||||
|
||||
async def async_init(
|
||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
self,
|
||||
handler: str,
|
||||
*,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: Any = None,
|
||||
) -> ConfigFlowResult:
|
||||
"""Start a configuration flow."""
|
||||
if not context or "source" not in context:
|
||||
|
@ -1319,7 +1337,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
self,
|
||||
flow_id: str,
|
||||
handler: str,
|
||||
context: dict,
|
||||
context: ConfigFlowContext,
|
||||
data: Any,
|
||||
) -> tuple[ConfigFlow, ConfigFlowResult]:
|
||||
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||
|
@ -1357,7 +1375,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||
result: ConfigFlowResult,
|
||||
) -> ConfigFlowResult:
|
||||
"""Finish a config flow and add an entry.
|
||||
|
@ -1504,7 +1522,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
return result
|
||||
|
||||
async def async_create_flow(
|
||||
self, handler_key: str, *, context: dict | None = None, data: Any = None
|
||||
self,
|
||||
handler_key: str,
|
||||
*,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: Any = None,
|
||||
) -> ConfigFlow:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
|
@ -1522,7 +1544,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_post_init(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||
result: ConfigFlowResult,
|
||||
) -> None:
|
||||
"""After a flow is initialised trigger new flow notifications."""
|
||||
|
@ -1560,7 +1582,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
@callback
|
||||
def async_has_matching_discovery_flow(
|
||||
self, handler: str, match_context: dict[str, Any], data: Any
|
||||
self, handler: str, match_context: ConfigFlowContext, data: Any
|
||||
) -> bool:
|
||||
"""Check if an existing matching discovery flow is in progress.
|
||||
|
||||
|
@ -2385,7 +2407,9 @@ def _async_abort_entries_match(
|
|||
raise data_entry_flow.AbortFlow("already_configured")
|
||||
|
||||
|
||||
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
|
||||
class ConfigEntryBaseFlow(
|
||||
data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
|
||||
):
|
||||
"""Base class for config and option flows."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -2406,7 +2430,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
if not self.context:
|
||||
return None
|
||||
|
||||
return cast(str | None, self.context.get("unique_id"))
|
||||
return self.context.get("unique_id")
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
|
@ -2779,7 +2803,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
"""Return reauth entry id."""
|
||||
if self.source != SOURCE_REAUTH:
|
||||
raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}")
|
||||
return self.context["entry_id"] # type: ignore[no-any-return]
|
||||
return self.context["entry_id"]
|
||||
|
||||
@callback
|
||||
def _get_reauth_entry(self) -> ConfigEntry:
|
||||
|
@ -2793,7 +2817,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
"""Return reconfigure entry id."""
|
||||
if self.source != SOURCE_RECONFIGURE:
|
||||
raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}")
|
||||
return self.context["entry_id"] # type: ignore[no-any-return]
|
||||
return self.context["entry_id"]
|
||||
|
||||
@callback
|
||||
def _get_reconfigure_entry(self) -> ConfigEntry:
|
||||
|
@ -2805,7 +2829,9 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
raise UnknownEntry
|
||||
|
||||
|
||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||
class OptionsFlowManager(
|
||||
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
|
||||
):
|
||||
"""Flow to set options for a configuration entry."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -2822,7 +2848,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
self,
|
||||
handler_key: str,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: ConfigFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> OptionsFlow:
|
||||
"""Create an options flow for a config entry.
|
||||
|
@ -2835,7 +2861,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||
result: ConfigFlowResult,
|
||||
) -> ConfigFlowResult:
|
||||
"""Finish an options flow and update options for configuration entry.
|
||||
|
@ -2860,7 +2886,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
return result
|
||||
|
||||
async def _async_setup_preview(
|
||||
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
|
||||
self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
|
||||
) -> None:
|
||||
"""Set up preview for an option flow handler."""
|
||||
entry = self._async_get_config_entry(flow.handler)
|
||||
|
|
|
@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
|
|||
}
|
||||
|
||||
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
|
||||
_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
|
||||
_FlowResultT = TypeVar(
|
||||
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
|
||||
)
|
||||
_HandlerT = TypeVar("_HandlerT", default=str)
|
||||
|
||||
|
||||
|
@ -139,10 +142,17 @@ class AbortFlow(FlowError):
|
|||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
|
||||
class FlowContext(TypedDict, total=False):
|
||||
"""Typed context dict."""
|
||||
|
||||
show_advanced_options: bool
|
||||
source: str
|
||||
|
||||
|
||||
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
|
||||
"""Typed result dict."""
|
||||
|
||||
context: dict[str, Any]
|
||||
context: _FlowContextT
|
||||
data_schema: vol.Schema | None
|
||||
data: Mapping[str, Any]
|
||||
description_placeholders: Mapping[str, str | None] | None
|
||||
|
@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
|
|||
schema_errors[path_part_str] = error.error_message
|
||||
|
||||
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._preview: set[_HandlerT] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
|
||||
self._progress: dict[
|
||||
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
] = {}
|
||||
self._handler_progress_index: defaultdict[
|
||||
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
_HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||
] = defaultdict(set)
|
||||
self._init_data_process_index: defaultdict[
|
||||
type, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||
] = defaultdict(set)
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
self,
|
||||
handler_key: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: _FlowContextT | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> FlowHandler[_FlowResultT, _HandlerT]:
|
||||
) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
|
@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
self,
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
result: _FlowResultT,
|
||||
) -> _FlowResultT:
|
||||
"""Finish a data entry flow.
|
||||
|
||||
|
@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
"""
|
||||
|
||||
async def async_post_init(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
self,
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
result: _FlowResultT,
|
||||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
|
@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
@callback
|
||||
def _async_progress_by_handler(
|
||||
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
|
||||
) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
|
||||
"""Return the flows in progress by handler.
|
||||
|
||||
If match_context is specified, only return flows with a context that
|
||||
|
@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
self,
|
||||
handler: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: _FlowContextT | None = None,
|
||||
data: Any = None,
|
||||
) -> _FlowResultT:
|
||||
"""Start a data entry flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
context = cast(_FlowContextT, {})
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise UnknownFlow("Flow was not created")
|
||||
|
@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
|
||||
@callback
|
||||
def _async_add_flow_progress(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
if flow.init_data is not None:
|
||||
|
@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
|
||||
@callback
|
||||
def _async_remove_flow_from_index(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
if flow.init_data is not None:
|
||||
|
@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: FlowHandler[_FlowResultT, _HandlerT],
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
) -> _FlowResultT:
|
||||
|
@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
return result
|
||||
|
||||
def _raise_if_step_does_not_exist(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
|
||||
) -> None:
|
||||
"""Raise if the step does not exist."""
|
||||
method = f"async_step_{step_id}"
|
||||
|
@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
)
|
||||
|
||||
async def _async_setup_preview(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Set up preview for a flow handler."""
|
||||
if flow.handler not in self._preview:
|
||||
|
@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
self,
|
||||
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
|
||||
flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
|
||||
include_uninitialized: bool,
|
||||
) -> list[_FlowResultT]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
|
@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
]
|
||||
|
||||
|
||||
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
|||
hass: HomeAssistant = None # type: ignore[assignment]
|
||||
handler: _HandlerT = None # type: ignore[assignment]
|
||||
# Ensure the attribute has a subscriptable, but immutable, default value.
|
||||
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
|
||||
context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
|
||||
|
||||
# Set by _async_create_flow callback
|
||||
init_step = "init"
|
||||
|
@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
|||
@property
|
||||
def source(self) -> str | None:
|
||||
"""Source that initialized the flow."""
|
||||
return self.context.get("source", None) # type: ignore[no-any-return]
|
||||
return self.context.get("source", None) # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
def show_advanced_options(self) -> bool:
|
||||
"""If we should show advanced options."""
|
||||
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return]
|
||||
return self.context.get("show_advanced_options", False) # type: ignore[return-value]
|
||||
|
||||
def add_suggested_values_to_schema(
|
||||
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None
|
||||
|
|
|
@ -18,7 +18,7 @@ from . import config_validation as cv
|
|||
|
||||
_FlowManagerT = TypeVar(
|
||||
"_FlowManagerT",
|
||||
bound=data_entry_flow.FlowManager[Any],
|
||||
bound=data_entry_flow.FlowManager[Any, Any],
|
||||
default=data_entry_flow.FlowManager,
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.util.async_ import gather_with_limited_concurrency
|
|||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.config_entries import ConfigFlowResult
|
||||
from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult
|
||||
|
||||
FLOW_INIT_LIMIT = 20
|
||||
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
|
||||
|
@ -42,7 +42,7 @@ class DiscoveryKey:
|
|||
def async_create_flow(
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
context: dict[str, Any],
|
||||
context: ConfigFlowContext,
|
||||
data: Any,
|
||||
*,
|
||||
discovery_key: DiscoveryKey | None = None,
|
||||
|
@ -70,7 +70,7 @@ def async_create_flow(
|
|||
|
||||
@callback
|
||||
def _async_init_flow(
|
||||
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
|
||||
hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any
|
||||
) -> Coroutine[None, None, ConfigFlowResult] | None:
|
||||
"""Create a discovery flow."""
|
||||
# Avoid spawning flows that have the same initial discovery data
|
||||
|
@ -98,7 +98,7 @@ class PendingFlowKey(NamedTuple):
|
|||
class PendingFlowValue(NamedTuple):
|
||||
"""Value for pending flows."""
|
||||
|
||||
context: dict[str, Any]
|
||||
context: ConfigFlowContext
|
||||
data: Any
|
||||
|
||||
|
||||
|
@ -137,7 +137,7 @@ class FlowDispatcher:
|
|||
await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)
|
||||
|
||||
@callback
|
||||
def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None:
|
||||
def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None:
|
||||
"""Create and add or queue a flow."""
|
||||
key = PendingFlowKey(domain, context["source"])
|
||||
values = PendingFlowValue(context, data)
|
||||
|
|
Loading…
Reference in New Issue