Make FlowResult a generic type (#111952)

pull/112619/head
Erik Montnemery 2024-03-07 12:41:14 +01:00 committed by GitHub
parent 008e025d5c
commit 82efb3d35b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 95 additions and 80 deletions

View File

@ -19,13 +19,13 @@ from homeassistant.core import (
HomeAssistant,
callback,
)
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
EVENT_USER_ADDED = "user_added"
@ -88,9 +88,13 @@ async def auth_manager_from_config(
return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager):
class AuthManagerFlowManager(
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
):
"""Manage authentication flows."""
_flow_result = AuthFlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows."""
super().__init__(hass)
@ -98,11 +102,11 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
async def async_create_flow(
self,
handler_key: str,
handler_key: tuple[str, str],
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> data_entry_flow.FlowHandler:
) -> LoginFlow:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider:
@ -110,8 +114,10 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: FlowResult
) -> FlowResult:
self,
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
result: AuthFlowResult,
) -> AuthFlowResult:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)

View File

@ -11,6 +11,7 @@ from attr import Attribute
from attr.setters import validate
from homeassistant.const import __version__
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl
@ -26,6 +27,8 @@ TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
AuthFlowResult = FlowResult[tuple[str, str]]
@attr.s(slots=True)
class Group:

View File

@ -13,14 +13,13 @@ from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION
from ..models import Credentials, RefreshToken, User, UserMeta
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed"
@ -181,9 +180,11 @@ async def load_auth_provider_module(
return module
class LoginFlow(data_entry_flow.FlowHandler):
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
"""Handler for the login flow."""
_flow_result = AuthFlowResult
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
@ -197,7 +198,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input is None.
@ -207,7 +208,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_select_mfa_module(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of select mfa module."""
errors = {}
@ -232,7 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_mfa(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of mfa validation."""
assert self.credential
assert self.user
@ -282,6 +283,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors=errors,
)
async def async_finish(self, flow_result: Any) -> FlowResult:
async def async_finish(self, flow_result: Any) -> AuthFlowResult:
"""Handle the pass of login flow."""
return self.async_create_entry(data=flow_result)

View File

@ -10,10 +10,9 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.const import CONF_COMMAND
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from ..models import Credentials, UserMeta
from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
CONF_ARGS = "args"
@ -138,7 +137,7 @@ class CommandLineLoginFlow(LoginFlow):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of the form."""
errors = {}

View File

@ -12,11 +12,10 @@ import voluptuous as vol
from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.storage import Store
from ..models import Credentials, UserMeta
from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
STORAGE_VERSION = 1
@ -321,7 +320,7 @@ class HassLoginFlow(LoginFlow):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of the form."""
errors = {}

View File

@ -8,10 +8,9 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from ..models import Credentials, UserMeta
from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
USER_SCHEMA = vol.Schema(
@ -98,7 +97,7 @@ class ExampleLoginFlow(LoginFlow):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of the form."""
errors = None

View File

@ -11,12 +11,11 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.core import async_get_hass, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from ..models import Credentials, UserMeta
from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
AUTH_PROVIDER_TYPE = "legacy_api_password"
@ -101,7 +100,7 @@ class LegacyLoginFlow(LoginFlow):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of the form."""
errors = {}

View File

@ -19,13 +19,12 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import is_cloud_connection
from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, UserMeta
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
IPAddress = IPv4Address | IPv6Address
@ -226,7 +225,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
) -> AuthFlowResult:
"""Handle the step of the form."""
try:
cast(

View File

@ -79,7 +79,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import Credentials
from homeassistant.auth.models import AuthFlowResult, Credentials
from homeassistant.components import onboarding
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
from homeassistant.components.http.ban import (
@ -197,8 +197,8 @@ class AuthProvidersView(HomeAssistantView):
def _prepare_result_json(
result: data_entry_flow.FlowResult,
) -> data_entry_flow.FlowResult:
result: AuthFlowResult,
) -> AuthFlowResult:
"""Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy()
@ -237,7 +237,7 @@ class LoginFlowBaseView(HomeAssistantView):
self,
request: web.Request,
client_id: str,
result: data_entry_flow.FlowResult,
result: AuthFlowResult,
) -> web.Response:
"""Convert the flow result to a response."""
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
@ -297,7 +297,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
vol.Schema(
{
vol.Required("client_id"): str,
vol.Required("handler"): vol.Any(str, list),
vol.Required("handler"): vol.All(
[vol.Any(str, None)], vol.Length(2, 2), vol.Coerce(tuple)
),
vol.Required("redirect_uri"): str,
vol.Optional("type", default="authorize"): str,
}
@ -312,15 +314,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
if not indieauth.verify_client_id(client_id):
return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST)
handler: tuple[str, ...] | str
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
else:
handler = data["handler"]
handler: tuple[str, str] = tuple(data["handler"])
try:
result = await self._flow_mgr.async_init(
handler, # type: ignore[arg-type]
handler,
context={
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user",

View File

@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property
@abstractmethod
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
def flow_manager(self) -> FlowManager[ConfigFlowResult, str]:
"""Return the flow manager of the flow."""
async def async_step_install_addon(

View File

@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
"""Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
@ -1171,7 +1171,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish a config flow and add an entry."""
@ -1293,7 +1293,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_post_init(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult,
) -> None:
"""After a flow is initialised trigger new flow notifications."""
@ -1940,7 +1940,7 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult, str]):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
@ -2292,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
"""Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
@ -2322,7 +2322,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry.
@ -2344,7 +2344,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result
async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult, str]
) -> None:
"""Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler)

View File

@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
_HandlerT = TypeVar("_HandlerT", default=str)
@dataclass(slots=True)
@ -138,7 +139,7 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders
class FlowResult(TypedDict, total=False):
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
"""Typed result dict."""
context: dict[str, Any]
@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False):
errors: dict[str, str] | None
extra: str
flow_id: Required[str]
handler: Required[str]
handler: Required[_HandlerT]
last_step: bool | None
menu_options: list[str] | dict[str, str]
options: Mapping[str, Any]
@ -189,7 +190,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT]):
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[str] = set()
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
self._handler_progress_index: dict[
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
self._init_data_process_index: dict[
type, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
@abc.abstractmethod
async def async_create_flow(
self,
handler_key: str,
handler_key: _HandlerT,
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT]:
) -> FlowHandler[_FlowResultT, _HandlerT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> _FlowResultT:
"""Finish a data entry flow."""
async def async_post_init(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously."""
@callback
def async_has_matching_flow(
self, handler: str, match_context: dict[str, Any], data: Any
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching flow is in progress.
@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def async_progress_by_handler(
self,
handler: str,
handler: _HandlerT,
include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None,
) -> list[_FlowResultT]:
@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT]]:
self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
]
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
self,
handler: _HandlerT,
*,
context: dict[str, Any] | None = None,
data: Any = None,
) -> _FlowResultT:
"""Start a data entry flow."""
if context is None:
@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step(
self,
flow: FlowHandler[_FlowResultT],
flow: FlowHandler[_FlowResultT, _HandlerT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return result
def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT], step_id: str
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
self._preview.add(flow.handler)
@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
include_uninitialized: bool,
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return results
class FlowHandler(Generic[_FlowResultT]):
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
"""Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]):
# and removes the need for constant None checks or asserts.
flow_id: str = None # type: ignore[assignment]
hass: HomeAssistant = None # type: ignore[assignment]
handler: str = None # type: ignore[assignment]
handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]

View File

@ -17,7 +17,7 @@ from . import config_validation as cv
_FlowManagerT = TypeVar(
"_FlowManagerT",
bound=data_entry_flow.FlowManager[Any],
bound="data_entry_flow.FlowManager[Any]",
default=data_entry_flow.FlowManager,
)
@ -61,7 +61,7 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
@RequestDataValidator(
vol.Schema(
{
vol.Required("handler"): vol.Any(str, list),
vol.Required("handler"): str,
vol.Optional("show_advanced_options", default=False): cv.boolean,
},
extra=vol.ALLOW_EXTRA,
@ -79,14 +79,9 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
self, request: web.Request, data: dict[str, Any]
) -> web.Response:
"""Handle a POST request."""
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
else:
handler = data["handler"]
try:
result = await self._flow_mgr.async_init(
handler, # type: ignore[arg-type]
data["handler"],
context=self.get_context(data),
)
except data_entry_flow.UnknownHandler: