Make FlowResult a generic type (#111952)
parent
008e025d5c
commit
82efb3d35b
|
@ -19,13 +19,13 @@ from homeassistant.core import (
|
|||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import auth_store, jwt_wrapper, models
|
||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .models import AuthFlowResult
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
|
||||
EVENT_USER_ADDED = "user_added"
|
||||
|
@ -88,9 +88,13 @@ async def auth_manager_from_config(
|
|||
return manager
|
||||
|
||||
|
||||
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||
class AuthManagerFlowManager(
|
||||
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
|
||||
):
|
||||
"""Manage authentication flows."""
|
||||
|
||||
_flow_result = AuthFlowResult
|
||||
|
||||
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
|
||||
"""Init auth manager flows."""
|
||||
super().__init__(hass)
|
||||
|
@ -98,11 +102,11 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: str,
|
||||
handler_key: tuple[str, str],
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
) -> LoginFlow:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||
if not auth_provider:
|
||||
|
@ -110,8 +114,10 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
return await auth_provider.async_login_flow(context)
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: FlowResult
|
||||
) -> FlowResult:
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
|
||||
result: AuthFlowResult,
|
||||
) -> AuthFlowResult:
|
||||
"""Return a user as result of login flow."""
|
||||
flow = cast(LoginFlow, flow)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from attr import Attribute
|
|||
from attr.setters import validate
|
||||
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import permissions as perm_mdl
|
||||
|
@ -26,6 +27,8 @@ TOKEN_TYPE_NORMAL = "normal"
|
|||
TOKEN_TYPE_SYSTEM = "system"
|
||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
||||
|
||||
AuthFlowResult = FlowResult[tuple[str, str]]
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Group:
|
||||
|
|
|
@ -13,14 +13,13 @@ from voluptuous.humanize import humanize_error
|
|||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from ..auth_store import AuthStore
|
||||
from ..const import MFA_SESSION_EXPIRATION
|
||||
from ..models import Credentials, RefreshToken, User, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = "auth_prov_reqs_processed"
|
||||
|
@ -181,9 +180,11 @@ async def load_auth_provider_module(
|
|||
return module
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
_flow_result = AuthFlowResult
|
||||
|
||||
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
@ -197,7 +198,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the first step of login flow.
|
||||
|
||||
Return self.async_show_form(step_id='init') if user_input is None.
|
||||
|
@ -207,7 +208,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
async def async_step_select_mfa_module(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of select mfa module."""
|
||||
errors = {}
|
||||
|
||||
|
@ -232,7 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of mfa validation."""
|
||||
assert self.credential
|
||||
assert self.user
|
||||
|
@ -282,6 +283,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_finish(self, flow_result: Any) -> FlowResult:
|
||||
async def async_finish(self, flow_result: Any) -> AuthFlowResult:
|
||||
"""Handle the pass of login flow."""
|
||||
return self.async_create_entry(data=flow_result)
|
||||
|
|
|
@ -10,10 +10,9 @@ from typing import Any, cast
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_COMMAND
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from ..models import Credentials, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
CONF_ARGS = "args"
|
||||
|
@ -138,7 +137,7 @@ class CommandLineLoginFlow(LoginFlow):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
|
|
@ -12,11 +12,10 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.storage import Store
|
||||
|
||||
from ..models import Credentials, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
|
@ -321,7 +320,7 @@ class HassLoginFlow(LoginFlow):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
|
|
@ -8,10 +8,9 @@ from typing import Any, cast
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from ..models import Credentials, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
USER_SCHEMA = vol.Schema(
|
||||
|
@ -98,7 +97,7 @@ class ExampleLoginFlow(LoginFlow):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
errors = None
|
||||
|
||||
|
|
|
@ -11,12 +11,11 @@ from typing import Any, cast
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import async_get_hass, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
|
||||
|
||||
from ..models import Credentials, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
AUTH_PROVIDER_TYPE = "legacy_api_password"
|
||||
|
@ -101,7 +100,7 @@ class LegacyLoginFlow(LoginFlow):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
|
|
@ -19,13 +19,12 @@ from typing import Any, cast
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
|
||||
from .. import InvalidAuthError
|
||||
from ..models import Credentials, RefreshToken, UserMeta
|
||||
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
|
||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||
|
||||
IPAddress = IPv4Address | IPv6Address
|
||||
|
@ -226,7 +225,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
|||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
try:
|
||||
cast(
|
||||
|
|
|
@ -79,7 +79,7 @@ import voluptuous_serialize
|
|||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.auth.models import AuthFlowResult, Credentials
|
||||
from homeassistant.components import onboarding
|
||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||
from homeassistant.components.http.ban import (
|
||||
|
@ -197,8 +197,8 @@ class AuthProvidersView(HomeAssistantView):
|
|||
|
||||
|
||||
def _prepare_result_json(
|
||||
result: data_entry_flow.FlowResult,
|
||||
) -> data_entry_flow.FlowResult:
|
||||
result: AuthFlowResult,
|
||||
) -> AuthFlowResult:
|
||||
"""Convert result to JSON."""
|
||||
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
data = result.copy()
|
||||
|
@ -237,7 +237,7 @@ class LoginFlowBaseView(HomeAssistantView):
|
|||
self,
|
||||
request: web.Request,
|
||||
client_id: str,
|
||||
result: data_entry_flow.FlowResult,
|
||||
result: AuthFlowResult,
|
||||
) -> web.Response:
|
||||
"""Convert the flow result to a response."""
|
||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
|
@ -297,7 +297,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
vol.Schema(
|
||||
{
|
||||
vol.Required("client_id"): str,
|
||||
vol.Required("handler"): vol.Any(str, list),
|
||||
vol.Required("handler"): vol.All(
|
||||
[vol.Any(str, None)], vol.Length(2, 2), vol.Coerce(tuple)
|
||||
),
|
||||
vol.Required("redirect_uri"): str,
|
||||
vol.Optional("type", default="authorize"): str,
|
||||
}
|
||||
|
@ -312,15 +314,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
if not indieauth.verify_client_id(client_id):
|
||||
return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST)
|
||||
|
||||
handler: tuple[str, ...] | str
|
||||
if isinstance(data["handler"], list):
|
||||
handler = tuple(data["handler"])
|
||||
else:
|
||||
handler = data["handler"]
|
||||
handler: tuple[str, str] = tuple(data["handler"])
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler, # type: ignore[arg-type]
|
||||
handler,
|
||||
context={
|
||||
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
|
||||
"credential_only": data.get("type") == "link_user",
|
||||
|
|
|
@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
|
||||
def flow_manager(self) -> FlowManager[ConfigFlowResult, str]:
|
||||
"""Return the flow manager of the flow."""
|
||||
|
||||
async def async_step_install_addon(
|
||||
|
|
|
@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception):
|
|||
"""Error to indicate that a flow has been cancelled."""
|
||||
|
||||
|
||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
|
||||
"""Manage all the config entry flows that are in progress."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -1171,7 +1171,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
|
||||
result: ConfigFlowResult,
|
||||
) -> ConfigFlowResult:
|
||||
"""Finish a config flow and add an entry."""
|
||||
|
@ -1293,7 +1293,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_post_init(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
|
||||
result: ConfigFlowResult,
|
||||
) -> None:
|
||||
"""After a flow is initialised trigger new flow notifications."""
|
||||
|
@ -1940,7 +1940,7 @@ def _async_abort_entries_match(
|
|||
raise data_entry_flow.AbortFlow("already_configured")
|
||||
|
||||
|
||||
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
|
||||
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult, str]):
|
||||
"""Base class for config and option flows."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -2292,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
return self.async_abort(reason=reason)
|
||||
|
||||
|
||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
|
||||
"""Flow to set options for a configuration entry."""
|
||||
|
||||
_flow_result = ConfigFlowResult
|
||||
|
@ -2322,7 +2322,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
|
||||
result: ConfigFlowResult,
|
||||
) -> ConfigFlowResult:
|
||||
"""Finish an options flow and update options for configuration entry.
|
||||
|
@ -2344,7 +2344,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
return result
|
||||
|
||||
async def _async_setup_preview(
|
||||
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
|
||||
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult, str]
|
||||
) -> None:
|
||||
"""Set up preview for an option flow handler."""
|
||||
entry = self._async_get_config_entry(flow.handler)
|
||||
|
|
|
@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = {
|
|||
}
|
||||
|
||||
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
|
||||
_HandlerT = TypeVar("_HandlerT", default=str)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -138,7 +139,7 @@ class AbortFlow(FlowError):
|
|||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowResult(TypedDict, total=False):
|
||||
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
|
||||
"""Typed result dict."""
|
||||
|
||||
context: dict[str, Any]
|
||||
|
@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False):
|
|||
errors: dict[str, str] | None
|
||||
extra: str
|
||||
flow_id: Required[str]
|
||||
handler: Required[str]
|
||||
handler: Required[_HandlerT]
|
||||
last_step: bool | None
|
||||
menu_options: list[str] | dict[str, str]
|
||||
options: Mapping[str, Any]
|
||||
|
@ -189,7 +190,7 @@ def _map_error_to_schema_errors(
|
|||
schema_errors[path_part_str] = error.error_message
|
||||
|
||||
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
) -> None:
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._preview: set[str] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
|
||||
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
|
||||
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
|
||||
self._preview: set[_HandlerT] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
|
||||
self._handler_progress_index: dict[
|
||||
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
] = {}
|
||||
self._init_data_process_index: dict[
|
||||
type, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
] = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: str,
|
||||
handler_key: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> FlowHandler[_FlowResultT]:
|
||||
) -> FlowHandler[_FlowResultT, _HandlerT]:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
|
@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
) -> _FlowResultT:
|
||||
"""Finish a data entry flow."""
|
||||
|
||||
async def async_post_init(
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
@callback
|
||||
def async_has_matching_flow(
|
||||
self, handler: str, match_context: dict[str, Any], data: Any
|
||||
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
|
||||
) -> bool:
|
||||
"""Check if an existing matching flow is in progress.
|
||||
|
||||
|
@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
@callback
|
||||
def async_progress_by_handler(
|
||||
self,
|
||||
handler: str,
|
||||
handler: _HandlerT,
|
||||
include_uninitialized: bool = False,
|
||||
match_context: dict[str, Any] | None = None,
|
||||
) -> list[_FlowResultT]:
|
||||
|
@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
@callback
|
||||
def _async_progress_by_handler(
|
||||
self, handler: str, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT]]:
|
||||
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
|
||||
"""Return the flows in progress by handler.
|
||||
|
||||
If match_context is specified, only return flows with a context that
|
||||
|
@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
]
|
||||
|
||||
async def async_init(
|
||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
self,
|
||||
handler: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: Any = None,
|
||||
) -> _FlowResultT:
|
||||
"""Start a data entry flow."""
|
||||
if context is None:
|
||||
|
@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
self._async_remove_flow_progress(flow_id)
|
||||
|
||||
@callback
|
||||
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
def _async_add_flow_progress(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
|
@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
def _async_remove_flow_from_index(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
|
@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: FlowHandler[_FlowResultT],
|
||||
flow: FlowHandler[_FlowResultT, _HandlerT],
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
) -> _FlowResultT:
|
||||
|
@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
return result
|
||||
|
||||
def _raise_if_step_does_not_exist(
|
||||
self, flow: FlowHandler[_FlowResultT], step_id: str
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
|
||||
) -> None:
|
||||
"""Raise if the step does not exist."""
|
||||
method = f"async_step_{step_id}"
|
||||
|
@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
|
||||
)
|
||||
|
||||
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
async def _async_setup_preview(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Set up preview for a flow handler."""
|
||||
if flow.handler not in self._preview:
|
||||
self._preview.add(flow.handler)
|
||||
|
@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
|
||||
self,
|
||||
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
|
||||
include_uninitialized: bool,
|
||||
) -> list[_FlowResultT]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
results = []
|
||||
|
@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
return results
|
||||
|
||||
|
||||
class FlowHandler(Generic[_FlowResultT]):
|
||||
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]):
|
|||
# and removes the need for constant None checks or asserts.
|
||||
flow_id: str = None # type: ignore[assignment]
|
||||
hass: HomeAssistant = None # type: ignore[assignment]
|
||||
handler: str = None # type: ignore[assignment]
|
||||
handler: _HandlerT = None # type: ignore[assignment]
|
||||
# Ensure the attribute has a subscriptable, but immutable, default value.
|
||||
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from . import config_validation as cv
|
|||
|
||||
_FlowManagerT = TypeVar(
|
||||
"_FlowManagerT",
|
||||
bound=data_entry_flow.FlowManager[Any],
|
||||
bound="data_entry_flow.FlowManager[Any]",
|
||||
default=data_entry_flow.FlowManager,
|
||||
)
|
||||
|
||||
|
@ -61,7 +61,7 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
|
|||
@RequestDataValidator(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required("handler"): vol.Any(str, list),
|
||||
vol.Required("handler"): str,
|
||||
vol.Optional("show_advanced_options", default=False): cv.boolean,
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
|
@ -79,14 +79,9 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
|
|||
self, request: web.Request, data: dict[str, Any]
|
||||
) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
if isinstance(data["handler"], list):
|
||||
handler = tuple(data["handler"])
|
||||
else:
|
||||
handler = data["handler"]
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler, # type: ignore[arg-type]
|
||||
data["handler"],
|
||||
context=self.get_context(data),
|
||||
)
|
||||
except data_entry_flow.UnknownHandler:
|
||||
|
|
Loading…
Reference in New Issue