Rework FlowManager to use inheritance (#30133)
* Pull async_finish_flow/async_create_flow out of ConfigEntries * Towards refactoring * mypy fixes * Mark Flow manager with abc.* annotations * Flake8 fixes * Mypy fixes * Blacken data_entry_flow * Blacken longer signatures caused by mypy changes * test fixes * Test fixes * Fix typo * Avoid protected member lint (W0212) in config_entries * More protected member fixes * Missing awaitpull/29828/head
parent
0a4f3ec1ec
commit
fdfedd086b
|
@ -67,6 +67,69 @@ async def auth_manager_from_config(
|
|||
return manager
|
||||
|
||||
|
||||
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||
"""Manage authentication flows."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, auth_manager: "AuthManager"):
|
||||
"""Init auth manager flows."""
|
||||
super().__init__(hass)
|
||||
self.auth_manager = auth_manager
|
||||
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: Any,
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||
if not auth_provider:
|
||||
raise KeyError(f"Unknown auth provider {handler_key}")
|
||||
return await auth_provider.async_login_flow(context)
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Return a user as result of login flow."""
|
||||
flow = cast(LoginFlow, flow)
|
||||
|
||||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# we got final result
|
||||
if isinstance(result["data"], models.User):
|
||||
result["result"] = result["data"]
|
||||
return result
|
||||
|
||||
auth_provider = self.auth_manager.get_auth_provider(*result["handler"])
|
||||
if not auth_provider:
|
||||
raise KeyError(f"Unknown auth provider {result['handler']}")
|
||||
|
||||
credentials = await auth_provider.async_get_or_create_credentials(
|
||||
result["data"]
|
||||
)
|
||||
|
||||
if flow.context.get("credential_only"):
|
||||
result["result"] = credentials
|
||||
return result
|
||||
|
||||
# multi-factor module cannot enabled for new credential
|
||||
# which has not linked to a user yet
|
||||
if auth_provider.support_mfa and not credentials.is_new:
|
||||
user = await self.auth_manager.async_get_user_by_credentials(credentials)
|
||||
if user is not None:
|
||||
modules = await self.auth_manager.async_get_enabled_mfa(user)
|
||||
|
||||
if modules:
|
||||
flow.user = user
|
||||
flow.available_mfa_modules = modules
|
||||
return await flow.async_step_select_mfa_module()
|
||||
|
||||
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
|
||||
return result
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Manage the authentication for Home Assistant."""
|
||||
|
||||
|
@ -82,9 +145,7 @@ class AuthManager:
|
|||
self._store = store
|
||||
self._providers = providers
|
||||
self._mfa_modules = mfa_modules
|
||||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow, self._async_finish_login_flow
|
||||
)
|
||||
self.login_flow = AuthManagerFlowManager(hass, self)
|
||||
|
||||
@property
|
||||
def auth_providers(self) -> List[AuthProvider]:
|
||||
|
@ -417,50 +478,6 @@ class AuthManager:
|
|||
|
||||
return refresh_token
|
||||
|
||||
async def _async_create_login_flow(
|
||||
self, handler: _ProviderKey, *, context: Optional[Dict], data: Optional[Any]
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
return await auth_provider.async_login_flow(context)
|
||||
|
||||
async def _async_finish_login_flow(
|
||||
self, flow: LoginFlow, result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Return a user as result of login flow."""
|
||||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# we got final result
|
||||
if isinstance(result["data"], models.User):
|
||||
result["result"] = result["data"]
|
||||
return result
|
||||
|
||||
auth_provider = self._providers[result["handler"]]
|
||||
credentials = await auth_provider.async_get_or_create_credentials(
|
||||
result["data"]
|
||||
)
|
||||
|
||||
if flow.context.get("credential_only"):
|
||||
result["result"] = credentials
|
||||
return result
|
||||
|
||||
# multi-factor module cannot enabled for new credential
|
||||
# which has not linked to a user yet
|
||||
if auth_provider.support_mfa and not credentials.is_new:
|
||||
user = await self.async_get_user_by_credentials(credentials)
|
||||
if user is not None:
|
||||
modules = await self.async_get_enabled_mfa(user)
|
||||
|
||||
if modules:
|
||||
flow.user = user
|
||||
flow.available_mfa_modules = modules
|
||||
return await flow.async_step_select_mfa_module()
|
||||
|
||||
result["result"] = await self.async_get_or_create_user(credentials)
|
||||
return result
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(
|
||||
self, credentials: models.Credentials
|
||||
|
|
|
@ -28,25 +28,27 @@ DATA_SETUP_FLOW_MGR = "auth_mfa_setup_flow_manager"
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup(hass):
|
||||
"""Init mfa setup flow manager."""
|
||||
class MfaFlowManager(data_entry_flow.FlowManager):
|
||||
"""Manage multi factor authentication flows."""
|
||||
|
||||
async def _async_create_setup_flow(handler, context, data):
|
||||
async def async_create_flow(self, handler_key, *, context, data):
|
||||
"""Create a setup flow. handler is a mfa module."""
|
||||
mfa_module = hass.auth.get_auth_mfa_module(handler)
|
||||
mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
|
||||
if mfa_module is None:
|
||||
raise ValueError(f"Mfa module {handler} is not found")
|
||||
raise ValueError(f"Mfa module {handler_key} is not found")
|
||||
|
||||
user_id = data.pop("user_id")
|
||||
return await mfa_module.async_setup_flow(user_id)
|
||||
|
||||
async def _async_finish_setup_flow(flow, flow_result):
|
||||
_LOGGER.debug("flow_result: %s", flow_result)
|
||||
return flow_result
|
||||
async def async_finish_flow(self, flow, result):
|
||||
"""Complete an mfs setup flow."""
|
||||
_LOGGER.debug("flow_result: %s", result)
|
||||
return result
|
||||
|
||||
hass.data[DATA_SETUP_FLOW_MGR] = data_entry_flow.FlowManager(
|
||||
hass, _async_create_setup_flow, _async_finish_setup_flow
|
||||
)
|
||||
|
||||
async def async_setup(hass):
|
||||
"""Init mfa setup flow manager."""
|
||||
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
|
||||
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_SETUP_MFA, websocket_setup_mfa, SCHEMA_WS_SETUP_MFA
|
||||
|
|
|
@ -23,12 +23,8 @@ async def async_setup(hass):
|
|||
hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow))
|
||||
hass.http.register_view(ConfigManagerAvailableFlowView)
|
||||
|
||||
hass.http.register_view(
|
||||
OptionManagerFlowIndexView(hass.config_entries.options.flow)
|
||||
)
|
||||
hass.http.register_view(
|
||||
OptionManagerFlowResourceView(hass.config_entries.options.flow)
|
||||
)
|
||||
hass.http.register_view(OptionManagerFlowIndexView(hass.config_entries.options))
|
||||
hass.http.register_view(OptionManagerFlowResourceView(hass.config_entries.options))
|
||||
|
||||
hass.components.websocket_api.async_register_command(config_entries_progress)
|
||||
hass.components.websocket_api.async_register_command(system_options_list)
|
||||
|
|
|
@ -399,6 +399,137 @@ class ConfigEntry:
|
|||
}
|
||||
|
||||
|
||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
||||
"""Manage all the config entry flows that are in progress."""
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, config_entries: "ConfigEntries", hass_config: dict
|
||||
):
|
||||
"""Initialize the config entry flow manager."""
|
||||
super().__init__(hass)
|
||||
self.config_entries = config_entries
|
||||
self._hass_config = hass_config
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Finish a config flow and add an entry."""
|
||||
flow = cast(ConfigFlow, flow)
|
||||
|
||||
# Remove notification if no other discovery config entries in progress
|
||||
if not any(
|
||||
ent["context"]["source"] in DISCOVERY_SOURCES
|
||||
for ent in self.hass.config_entries.flow.async_progress()
|
||||
if ent["flow_id"] != flow.flow_id
|
||||
):
|
||||
self.hass.components.persistent_notification.async_dismiss(
|
||||
DISCOVERY_NOTIFICATION_ID
|
||||
)
|
||||
|
||||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# Check if config entry exists with unique ID. Unload it.
|
||||
existing_entry = None
|
||||
|
||||
if flow.unique_id is not None:
|
||||
# Abort all flows in progress with same unique ID.
|
||||
for progress_flow in self.async_progress():
|
||||
if (
|
||||
progress_flow["handler"] == flow.handler
|
||||
and progress_flow["flow_id"] != flow.flow_id
|
||||
and progress_flow["context"].get("unique_id") == flow.unique_id
|
||||
):
|
||||
self.async_abort(progress_flow["flow_id"])
|
||||
|
||||
# Find existing entry.
|
||||
for check_entry in self.config_entries.async_entries(result["handler"]):
|
||||
if check_entry.unique_id == flow.unique_id:
|
||||
existing_entry = check_entry
|
||||
break
|
||||
|
||||
# Unload the entry before setting up the new one.
|
||||
# We will remove it only after the other one is set up,
|
||||
# so that device customizations are not getting lost.
|
||||
if (
|
||||
existing_entry is not None
|
||||
and existing_entry.state not in UNRECOVERABLE_STATES
|
||||
):
|
||||
await self.config_entries.async_unload(existing_entry.entry_id)
|
||||
|
||||
entry = ConfigEntry(
|
||||
version=result["version"],
|
||||
domain=result["handler"],
|
||||
title=result["title"],
|
||||
data=result["data"],
|
||||
options={},
|
||||
system_options={},
|
||||
source=flow.context["source"],
|
||||
connection_class=flow.CONNECTION_CLASS,
|
||||
unique_id=flow.unique_id,
|
||||
)
|
||||
|
||||
await self.config_entries.async_add(entry)
|
||||
|
||||
if existing_entry is not None:
|
||||
await self.config_entries.async_remove(existing_entry.entry_id)
|
||||
|
||||
result["result"] = entry
|
||||
return result
|
||||
|
||||
async def async_create_flow(
|
||||
self, handler_key: Any, *, context: Optional[Dict] = None, data: Any = None
|
||||
) -> "ConfigFlow":
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
"""
|
||||
try:
|
||||
integration = await loader.async_get_integration(self.hass, handler_key)
|
||||
except loader.IntegrationNotFound:
|
||||
_LOGGER.error("Cannot find integration %s", handler_key)
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(self.hass, self._hass_config, integration)
|
||||
|
||||
try:
|
||||
integration.get_platform("config_flow")
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
"Error occurred loading config flow for integration %s: %s",
|
||||
handler_key,
|
||||
err,
|
||||
)
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
handler = HANDLERS.get(handler_key)
|
||||
|
||||
if handler is None:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
if not context or "source" not in context:
|
||||
raise KeyError("Context not set or doesn't have a source set")
|
||||
|
||||
source = context["source"]
|
||||
|
||||
# Create notification.
|
||||
if source in DISCOVERY_SOURCES:
|
||||
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
|
||||
self.hass.components.persistent_notification.async_create(
|
||||
title="New devices discovered",
|
||||
message=(
|
||||
"We have discovered new devices on your network. "
|
||||
"[Check it out](/config/integrations)"
|
||||
),
|
||||
notification_id=DISCOVERY_NOTIFICATION_ID,
|
||||
)
|
||||
|
||||
flow = cast(ConfigFlow, handler())
|
||||
flow.init_step = source
|
||||
return flow
|
||||
|
||||
|
||||
class ConfigEntries:
|
||||
"""Manage the configuration entries.
|
||||
|
||||
|
@ -408,9 +539,7 @@ class ConfigEntries:
|
|||
def __init__(self, hass: HomeAssistant, hass_config: dict) -> None:
|
||||
"""Initialize the entry manager."""
|
||||
self.hass = hass
|
||||
self.flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_flow, self._async_finish_flow
|
||||
)
|
||||
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
|
||||
self.options = OptionsFlowManager(hass)
|
||||
self._hass_config = hass_config
|
||||
self._entries: List[ConfigEntry] = []
|
||||
|
@ -445,6 +574,12 @@ class ConfigEntries:
|
|||
return list(self._entries)
|
||||
return [entry for entry in self._entries if entry.domain == domain]
|
||||
|
||||
async def async_add(self, entry: ConfigEntry) -> None:
|
||||
"""Add and setup an entry."""
|
||||
self._entries.append(entry)
|
||||
await self.async_setup(entry.entry_id)
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_remove(self, entry_id: str) -> Dict[str, Any]:
|
||||
"""Remove an entry."""
|
||||
entry = self.async_get_entry(entry_id)
|
||||
|
@ -630,123 +765,6 @@ class ConfigEntries:
|
|||
|
||||
return await entry.async_unload(self.hass, integration=integration)
|
||||
|
||||
async def _async_finish_flow(
|
||||
self, flow: "ConfigFlow", result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Finish a config flow and add an entry."""
|
||||
# Remove notification if no other discovery config entries in progress
|
||||
if not any(
|
||||
ent["context"]["source"] in DISCOVERY_SOURCES
|
||||
for ent in self.hass.config_entries.flow.async_progress()
|
||||
if ent["flow_id"] != flow.flow_id
|
||||
):
|
||||
self.hass.components.persistent_notification.async_dismiss(
|
||||
DISCOVERY_NOTIFICATION_ID
|
||||
)
|
||||
|
||||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# Check if config entry exists with unique ID. Unload it.
|
||||
existing_entry = None
|
||||
|
||||
if flow.unique_id is not None:
|
||||
# Abort all flows in progress with same unique ID.
|
||||
for progress_flow in self.flow.async_progress():
|
||||
if (
|
||||
progress_flow["handler"] == flow.handler
|
||||
and progress_flow["flow_id"] != flow.flow_id
|
||||
and progress_flow["context"].get("unique_id") == flow.unique_id
|
||||
):
|
||||
self.flow.async_abort(progress_flow["flow_id"])
|
||||
|
||||
# Find existing entry.
|
||||
for check_entry in self.async_entries(result["handler"]):
|
||||
if check_entry.unique_id == flow.unique_id:
|
||||
existing_entry = check_entry
|
||||
break
|
||||
|
||||
# Unload the entry before setting up the new one.
|
||||
# We will remove it only after the other one is set up,
|
||||
# so that device customizations are not getting lost.
|
||||
if (
|
||||
existing_entry is not None
|
||||
and existing_entry.state not in UNRECOVERABLE_STATES
|
||||
):
|
||||
await self.async_unload(existing_entry.entry_id)
|
||||
|
||||
entry = ConfigEntry(
|
||||
version=result["version"],
|
||||
domain=result["handler"],
|
||||
title=result["title"],
|
||||
data=result["data"],
|
||||
options={},
|
||||
system_options={},
|
||||
source=flow.context["source"],
|
||||
connection_class=flow.CONNECTION_CLASS,
|
||||
unique_id=flow.unique_id,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
|
||||
await self.async_setup(entry.entry_id)
|
||||
|
||||
if existing_entry is not None:
|
||||
await self.async_remove(existing_entry.entry_id)
|
||||
|
||||
self._async_schedule_save()
|
||||
|
||||
result["result"] = entry
|
||||
return result
|
||||
|
||||
async def _async_create_flow(
|
||||
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
|
||||
) -> "ConfigFlow":
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
"""
|
||||
try:
|
||||
integration = await loader.async_get_integration(self.hass, handler_key)
|
||||
except loader.IntegrationNotFound:
|
||||
_LOGGER.error("Cannot find integration %s", handler_key)
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(self.hass, self._hass_config, integration)
|
||||
|
||||
try:
|
||||
integration.get_platform("config_flow")
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
"Error occurred loading config flow for integration %s: %s",
|
||||
handler_key,
|
||||
err,
|
||||
)
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
handler = HANDLERS.get(handler_key)
|
||||
|
||||
if handler is None:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
source = context["source"]
|
||||
|
||||
# Create notification.
|
||||
if source in DISCOVERY_SOURCES:
|
||||
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
|
||||
self.hass.components.persistent_notification.async_create(
|
||||
title="New devices discovered",
|
||||
message=(
|
||||
"We have discovered new devices on your network. "
|
||||
"[Check it out](/config/integrations)"
|
||||
),
|
||||
notification_id=DISCOVERY_NOTIFICATION_ID,
|
||||
)
|
||||
|
||||
flow = cast(ConfigFlow, handler())
|
||||
flow.init_step = source
|
||||
return flow
|
||||
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Save the entity registry to a file."""
|
||||
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
@ -854,26 +872,23 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
|||
return self.async_abort(reason="not_implemented")
|
||||
|
||||
|
||||
class OptionsFlowManager:
|
||||
class OptionsFlowManager(data_entry_flow.FlowManager):
|
||||
"""Flow to set options for a configuration entry."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the options manager."""
|
||||
self.hass = hass
|
||||
self.flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_flow, self._async_finish_flow
|
||||
)
|
||||
|
||||
async def _async_create_flow(
|
||||
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
|
||||
) -> Optional["OptionsFlow"]:
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: Any,
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> "OptionsFlow":
|
||||
"""Create an options flow for a config entry.
|
||||
|
||||
Entry_id and flow.handler is the same thing to map entry with flow.
|
||||
"""
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
entry = self.hass.config_entries.async_get_entry(handler_key)
|
||||
if entry is None:
|
||||
return None
|
||||
raise UnknownEntry(handler_key)
|
||||
|
||||
if entry.domain not in HANDLERS:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
@ -881,16 +896,18 @@ class OptionsFlowManager:
|
|||
flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
|
||||
return flow
|
||||
|
||||
async def _async_finish_flow(
|
||||
self, flow: "OptionsFlow", result: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Finish an options flow and update options for configuration entry.
|
||||
|
||||
Flow.handler and entry_id is the same thing to map flow with entry.
|
||||
"""
|
||||
flow = cast(OptionsFlow, flow)
|
||||
|
||||
entry = self.hass.config_entries.async_get_entry(flow.handler)
|
||||
if entry is None:
|
||||
return None
|
||||
raise UnknownEntry(flow.handler)
|
||||
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
||||
|
||||
result["result"] = True
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Classes to help gather user submissions."""
|
||||
import abc
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
import uuid
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -46,20 +47,34 @@ class AbortFlow(FlowError):
|
|||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowManager:
|
||||
class FlowManager(abc.ABC):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
async_create_flow: Callable,
|
||||
async_finish_flow: Callable,
|
||||
) -> None:
|
||||
def __init__(self, hass: HomeAssistant,) -> None:
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._progress: Dict[str, Any] = {}
|
||||
self._async_create_flow = async_create_flow
|
||||
self._async_finish_flow = async_finish_flow
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: Any,
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> "FlowHandler":
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: "FlowHandler", result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Finish a config flow and add an entry."""
|
||||
pass
|
||||
|
||||
@callback
|
||||
def async_progress(self) -> List[Dict]:
|
||||
|
@ -75,7 +90,9 @@ class FlowManager:
|
|||
"""Start a configuration flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
flow = await self._async_create_flow(handler, context=context, data=data)
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise UnknownFlow("Flow was not created")
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
flow.flow_id = uuid.uuid4().hex
|
||||
|
@ -168,7 +185,7 @@ class FlowManager:
|
|||
return result
|
||||
|
||||
# We pass a copy of the result because we're mutating our version
|
||||
result = await self._async_finish_flow(flow, dict(result))
|
||||
result = await self.async_finish_flow(flow, dict(result))
|
||||
|
||||
# _async_finish_flow may change result type, check it again
|
||||
if result["type"] == RESULT_TYPE_FORM:
|
||||
|
|
|
@ -436,7 +436,7 @@ async def test_option_flow(hass):
|
|||
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
|
||||
hass.config_entries._entries.append(entry)
|
||||
|
||||
flow = await hass.config_entries.options._async_create_flow(
|
||||
flow = await hass.config_entries.options.async_create_flow(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
|
||||
|
|
|
@ -182,13 +182,13 @@ async def test_options_form(hass):
|
|||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.options.flow.async_init(
|
||||
result = await hass.config_entries.options.async_init(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.flow.async_configure(
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={CONF_FLOOR_TEMP: True, CONF_PRECISION: PRECISION_HALVES},
|
||||
)
|
||||
|
@ -197,11 +197,11 @@ async def test_options_form(hass):
|
|||
assert result["data"][CONF_PRECISION] == PRECISION_HALVES
|
||||
assert result["data"][CONF_FLOOR_TEMP] is True
|
||||
|
||||
result = await hass.config_entries.options.flow.async_init(
|
||||
result = await hass.config_entries.options.async_init(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
|
||||
result = await hass.config_entries.options.flow.async_configure(
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={CONF_PRECISION: 0}
|
||||
)
|
||||
|
||||
|
|
|
@ -462,13 +462,13 @@ async def test_option_flow(hass):
|
|||
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=DEFAULT_OPTIONS)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.options.flow.async_init(
|
||||
result = await hass.config_entries.options.async_init(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "plex_mp_settings"
|
||||
|
||||
result = await hass.config_entries.options.flow.async_configure(
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
config_flow.CONF_USE_EPISODE_ART: True,
|
||||
|
|
|
@ -131,12 +131,12 @@ async def test_option_flow(hass):
|
|||
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.options.flow.async_init(entry.entry_id)
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.flow.async_configure(
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 350}
|
||||
)
|
||||
assert result["type"] == "create_entry"
|
||||
|
@ -148,12 +148,12 @@ async def test_option_flow_input_floor(hass):
|
|||
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.options.flow.async_init(entry.entry_id)
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.flow.async_configure(
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 1}
|
||||
)
|
||||
assert result["type"] == "create_entry"
|
||||
|
|
|
@ -231,7 +231,7 @@ async def test_option_flow(hass):
|
|||
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
|
||||
hass.config_entries._entries.append(entry)
|
||||
|
||||
flow = await hass.config_entries.options._async_create_flow(
|
||||
flow = await hass.config_entries.options.async_create_flow(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
|
||||
|
|
|
@ -692,13 +692,13 @@ async def test_entry_options(hass, manager):
|
|||
return OptionsFlowHandler()
|
||||
|
||||
config_entries.HANDLERS["test"] = TestFlow()
|
||||
flow = await manager.options._async_create_flow(
|
||||
flow = await manager.options.async_create_flow(
|
||||
entry.entry_id, context={"source": "test"}, data=None
|
||||
)
|
||||
|
||||
flow.handler = entry.entry_id # Used to keep reference to config entry
|
||||
|
||||
await manager.options._async_finish_flow(flow, {"data": {"second": True}})
|
||||
await manager.options.async_finish_flow(flow, {"data": {"second": True}})
|
||||
|
||||
assert entry.data == {"first": True}
|
||||
|
||||
|
|
|
@ -14,27 +14,32 @@ def manager():
|
|||
handlers = Registry()
|
||||
entries = []
|
||||
|
||||
async def async_create_flow(handler_name, *, context, data):
|
||||
handler = handlers.get(handler_name)
|
||||
class FlowManager(data_entry_flow.FlowManager):
|
||||
"""Test flow manager."""
|
||||
|
||||
if handler is None:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
async def async_create_flow(self, handler_key, *, context, data):
|
||||
"""Test create flow."""
|
||||
handler = handlers.get(handler_key)
|
||||
|
||||
flow = handler()
|
||||
flow.init_step = context.get("init_step", "init")
|
||||
flow.source = context.get("source")
|
||||
return flow
|
||||
if handler is None:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
async def async_add_entry(flow, result):
|
||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
result["source"] = flow.context.get("source")
|
||||
entries.append(result)
|
||||
return result
|
||||
flow = handler()
|
||||
flow.init_step = context.get("init_step", "init")
|
||||
flow.source = context.get("source")
|
||||
return flow
|
||||
|
||||
manager = data_entry_flow.FlowManager(None, async_create_flow, async_add_entry)
|
||||
manager.mock_created_entries = entries
|
||||
manager.mock_reg_handler = handlers.register
|
||||
return manager
|
||||
async def async_finish_flow(self, flow, result):
|
||||
"""Test finish flow."""
|
||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
result["source"] = flow.context.get("source")
|
||||
entries.append(result)
|
||||
return result
|
||||
|
||||
mgr = FlowManager(None)
|
||||
mgr.mock_created_entries = entries
|
||||
mgr.mock_reg_handler = handlers.register
|
||||
return mgr
|
||||
|
||||
|
||||
async def test_configure_reuses_handler_instance(manager):
|
||||
|
@ -194,22 +199,23 @@ async def test_finish_callback_change_result_type(hass):
|
|||
step_id="init", data_schema=vol.Schema({"count": int})
|
||||
)
|
||||
|
||||
async def async_create_flow(handler_name, *, context, data):
|
||||
"""Create a test flow."""
|
||||
return TestFlow()
|
||||
class FlowManager(data_entry_flow.FlowManager):
|
||||
async def async_create_flow(self, handler_name, *, context, data):
|
||||
"""Create a test flow."""
|
||||
return TestFlow()
|
||||
|
||||
async def async_finish_flow(flow, result):
|
||||
"""Redirect to init form if count <= 1."""
|
||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
if result["data"] is None or result["data"].get("count", 0) <= 1:
|
||||
return flow.async_show_form(
|
||||
step_id="init", data_schema=vol.Schema({"count": int})
|
||||
)
|
||||
else:
|
||||
result["result"] = result["data"]["count"]
|
||||
return result
|
||||
async def async_finish_flow(self, flow, result):
|
||||
"""Redirect to init form if count <= 1."""
|
||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
if result["data"] is None or result["data"].get("count", 0) <= 1:
|
||||
return flow.async_show_form(
|
||||
step_id="init", data_schema=vol.Schema({"count": int})
|
||||
)
|
||||
else:
|
||||
result["result"] = result["data"]["count"]
|
||||
return result
|
||||
|
||||
manager = data_entry_flow.FlowManager(hass, async_create_flow, async_finish_flow)
|
||||
manager = FlowManager(hass)
|
||||
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
|
Loading…
Reference in New Issue