Rework FlowManager to use inheritance (#30133)

* Pull async_finish_flow/async_create_flow out of ConfigEntries

* Towards refactoring

* mypy fixes

* Mark Flow manager with abc.* annotations

* Flake8 fixes

* Mypy fixes

* Blacken data_entry_flow

* Blacken longer signatures caused by mypy changes

* test fixes

* Test fixes

* Fix typo

* Avoid protected member lint (W0212) in config_entries

* More protected member fixes

* Missing await
pull/29828/head
Jc2k 2020-01-03 10:52:01 +00:00 committed by Paulus Schoutsen
parent 0a4f3ec1ec
commit fdfedd086b
12 changed files with 313 additions and 258 deletions

View File

@ -67,6 +67,69 @@ async def auth_manager_from_config(
return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows."""
def __init__(self, hass: HomeAssistant, auth_manager: "AuthManager"):
"""Init auth manager flows."""
super().__init__(hass)
self.auth_manager = auth_manager
async def async_create_flow(
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider:
raise KeyError(f"Unknown auth provider {handler_key}")
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# we got final result
if isinstance(result["data"], models.User):
result["result"] = result["data"]
return result
auth_provider = self.auth_manager.get_auth_provider(*result["handler"])
if not auth_provider:
raise KeyError(f"Unknown auth provider {result['handler']}")
credentials = await auth_provider.async_get_or_create_credentials(
result["data"]
)
if flow.context.get("credential_only"):
result["result"] = credentials
return result
# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if auth_provider.support_mfa and not credentials.is_new:
user = await self.auth_manager.async_get_user_by_credentials(credentials)
if user is not None:
modules = await self.auth_manager.async_get_enabled_mfa(user)
if modules:
flow.user = user
flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module()
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
return result
class AuthManager:
"""Manage the authentication for Home Assistant."""
@ -82,9 +145,7 @@ class AuthManager:
self._store = store
self._providers = providers
self._mfa_modules = mfa_modules
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow, self._async_finish_login_flow
)
self.login_flow = AuthManagerFlowManager(hass, self)
@property
def auth_providers(self) -> List[AuthProvider]:
@ -417,50 +478,6 @@ class AuthManager:
return refresh_token
async def _async_create_login_flow(
self, handler: _ProviderKey, *, context: Optional[Dict], data: Optional[Any]
) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self._providers[handler]
return await auth_provider.async_login_flow(context)
async def _async_finish_login_flow(
self, flow: LoginFlow, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Return a user as result of login flow."""
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# we got final result
if isinstance(result["data"], models.User):
result["result"] = result["data"]
return result
auth_provider = self._providers[result["handler"]]
credentials = await auth_provider.async_get_or_create_credentials(
result["data"]
)
if flow.context.get("credential_only"):
result["result"] = credentials
return result
# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if auth_provider.support_mfa and not credentials.is_new:
user = await self.async_get_user_by_credentials(credentials)
if user is not None:
modules = await self.async_get_enabled_mfa(user)
if modules:
flow.user = user
flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module()
result["result"] = await self.async_get_or_create_user(credentials)
return result
@callback
def _async_get_auth_provider(
self, credentials: models.Credentials

View File

@ -28,25 +28,27 @@ DATA_SETUP_FLOW_MGR = "auth_mfa_setup_flow_manager"
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass):
"""Init mfa setup flow manager."""
class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows."""
async def _async_create_setup_flow(handler, context, data):
async def async_create_flow(self, handler_key, *, context, data):
"""Create a setup flow. handler is a mfa module."""
mfa_module = hass.auth.get_auth_mfa_module(handler)
mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
if mfa_module is None:
raise ValueError(f"Mfa module {handler} is not found")
raise ValueError(f"Mfa module {handler_key} is not found")
user_id = data.pop("user_id")
return await mfa_module.async_setup_flow(user_id)
async def _async_finish_setup_flow(flow, flow_result):
_LOGGER.debug("flow_result: %s", flow_result)
return flow_result
async def async_finish_flow(self, flow, result):
"""Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result)
return result
hass.data[DATA_SETUP_FLOW_MGR] = data_entry_flow.FlowManager(
hass, _async_create_setup_flow, _async_finish_setup_flow
)
async def async_setup(hass):
"""Init mfa setup flow manager."""
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
hass.components.websocket_api.async_register_command(
WS_TYPE_SETUP_MFA, websocket_setup_mfa, SCHEMA_WS_SETUP_MFA

View File

@ -23,12 +23,8 @@ async def async_setup(hass):
hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow))
hass.http.register_view(ConfigManagerAvailableFlowView)
hass.http.register_view(
OptionManagerFlowIndexView(hass.config_entries.options.flow)
)
hass.http.register_view(
OptionManagerFlowResourceView(hass.config_entries.options.flow)
)
hass.http.register_view(OptionManagerFlowIndexView(hass.config_entries.options))
hass.http.register_view(OptionManagerFlowResourceView(hass.config_entries.options))
hass.components.websocket_api.async_register_command(config_entries_progress)
hass.components.websocket_api.async_register_command(system_options_list)

View File

@ -399,6 +399,137 @@ class ConfigEntry:
}
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
"""Manage all the config entry flows that are in progress."""
def __init__(
self, hass: HomeAssistant, config_entries: "ConfigEntries", hass_config: dict
):
"""Initialize the config entry flow manager."""
super().__init__(hass)
self.config_entries = config_entries
self._hass_config = hass_config
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
# Remove notification if no other discovery config entries in progress
if not any(
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID
)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
if flow.unique_id is not None:
# Abort all flows in progress with same unique ID.
for progress_flow in self.async_progress():
if (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and progress_flow["context"].get("unique_id") == flow.unique_id
):
self.async_abort(progress_flow["flow_id"])
# Find existing entry.
for check_entry in self.config_entries.async_entries(result["handler"]):
if check_entry.unique_id == flow.unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.config_entries.async_unload(existing_entry.entry_id)
entry = ConfigEntry(
version=result["version"],
domain=result["handler"],
title=result["title"],
data=result["data"],
options={},
system_options={},
source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS,
unique_id=flow.unique_id,
)
await self.config_entries.async_add(entry)
if existing_entry is not None:
await self.config_entries.async_remove(existing_entry.entry_id)
result["result"] = entry
return result
async def async_create_flow(
self, handler_key: Any, *, context: Optional[Dict] = None, data: Any = None
) -> "ConfigFlow":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
try:
integration = await loader.async_get_integration(self.hass, handler_key)
except loader.IntegrationNotFound:
_LOGGER.error("Cannot find integration %s", handler_key)
raise data_entry_flow.UnknownHandler
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(self.hass, self._hass_config, integration)
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error occurred loading config flow for integration %s: %s",
handler_key,
err,
)
raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
source = context["source"]
# Create notification.
if source in DISCOVERY_SOURCES:
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
self.hass.components.persistent_notification.async_create(
title="New devices discovered",
message=(
"We have discovered new devices on your network. "
"[Check it out](/config/integrations)"
),
notification_id=DISCOVERY_NOTIFICATION_ID,
)
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
class ConfigEntries:
"""Manage the configuration entries.
@ -408,9 +539,7 @@ class ConfigEntries:
def __init__(self, hass: HomeAssistant, hass_config: dict) -> None:
"""Initialize the entry manager."""
self.hass = hass
self.flow = data_entry_flow.FlowManager(
hass, self._async_create_flow, self._async_finish_flow
)
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
self.options = OptionsFlowManager(hass)
self._hass_config = hass_config
self._entries: List[ConfigEntry] = []
@ -445,6 +574,12 @@ class ConfigEntries:
return list(self._entries)
return [entry for entry in self._entries if entry.domain == domain]
async def async_add(self, entry: ConfigEntry) -> None:
"""Add and setup an entry."""
self._entries.append(entry)
await self.async_setup(entry.entry_id)
self._async_schedule_save()
async def async_remove(self, entry_id: str) -> Dict[str, Any]:
"""Remove an entry."""
entry = self.async_get_entry(entry_id)
@ -630,123 +765,6 @@ class ConfigEntries:
return await entry.async_unload(self.hass, integration=integration)
async def _async_finish_flow(
self, flow: "ConfigFlow", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
# Remove notification if no other discovery config entries in progress
if not any(
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID
)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
if flow.unique_id is not None:
# Abort all flows in progress with same unique ID.
for progress_flow in self.flow.async_progress():
if (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and progress_flow["context"].get("unique_id") == flow.unique_id
):
self.flow.async_abort(progress_flow["flow_id"])
# Find existing entry.
for check_entry in self.async_entries(result["handler"]):
if check_entry.unique_id == flow.unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.async_unload(existing_entry.entry_id)
entry = ConfigEntry(
version=result["version"],
domain=result["handler"],
title=result["title"],
data=result["data"],
options={},
system_options={},
source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS,
unique_id=flow.unique_id,
)
self._entries.append(entry)
await self.async_setup(entry.entry_id)
if existing_entry is not None:
await self.async_remove(existing_entry.entry_id)
self._async_schedule_save()
result["result"] = entry
return result
async def _async_create_flow(
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> "ConfigFlow":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
try:
integration = await loader.async_get_integration(self.hass, handler_key)
except loader.IntegrationNotFound:
_LOGGER.error("Cannot find integration %s", handler_key)
raise data_entry_flow.UnknownHandler
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(self.hass, self._hass_config, integration)
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error occurred loading config flow for integration %s: %s",
handler_key,
err,
)
raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
source = context["source"]
# Create notification.
if source in DISCOVERY_SOURCES:
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
self.hass.components.persistent_notification.async_create(
title="New devices discovered",
message=(
"We have discovered new devices on your network. "
"[Check it out](/config/integrations)"
),
notification_id=DISCOVERY_NOTIFICATION_ID,
)
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
def _async_schedule_save(self) -> None:
"""Save the entity registry to a file."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@ -854,26 +872,23 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason="not_implemented")
class OptionsFlowManager:
class OptionsFlowManager(data_entry_flow.FlowManager):
"""Flow to set options for a configuration entry."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the options manager."""
self.hass = hass
self.flow = data_entry_flow.FlowManager(
hass, self._async_create_flow, self._async_finish_flow
)
async def _async_create_flow(
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> Optional["OptionsFlow"]:
async def async_create_flow(
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> "OptionsFlow":
"""Create an options flow for a config entry.
Entry_id and flow.handler is the same thing to map entry with flow.
"""
entry = self.hass.config_entries.async_get_entry(entry_id)
entry = self.hass.config_entries.async_get_entry(handler_key)
if entry is None:
return None
raise UnknownEntry(handler_key)
if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler
@ -881,16 +896,18 @@ class OptionsFlowManager:
flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
return flow
async def _async_finish_flow(
self, flow: "OptionsFlow", result: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry.
"""
flow = cast(OptionsFlow, flow)
entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None:
return None
raise UnknownEntry(flow.handler)
self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True

View File

@ -1,6 +1,7 @@
"""Classes to help gather user submissions."""
import abc
import logging
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast
import uuid
import voluptuous as vol
@ -46,20 +47,34 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders
class FlowManager:
class FlowManager(abc.ABC):
"""Manage all the flows that are in progress."""
def __init__(
self,
hass: HomeAssistant,
async_create_flow: Callable,
async_finish_flow: Callable,
) -> None:
def __init__(self, hass: HomeAssistant,) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._progress: Dict[str, Any] = {}
self._async_create_flow = async_create_flow
self._async_finish_flow = async_finish_flow
@abc.abstractmethod
async def async_create_flow(
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> "FlowHandler":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
pass
@abc.abstractmethod
async def async_finish_flow(
self, flow: "FlowHandler", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
pass
@callback
def async_progress(self) -> List[Dict]:
@ -75,7 +90,9 @@ class FlowManager:
"""Start a configuration flow."""
if context is None:
context = {}
flow = await self._async_create_flow(handler, context=context, data=data)
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
@ -168,7 +185,7 @@ class FlowManager:
return result
# We pass a copy of the result because we're mutating our version
result = await self._async_finish_flow(flow, dict(result))
result = await self.async_finish_flow(flow, dict(result))
# _async_finish_flow may change result type, check it again
if result["type"] == RESULT_TYPE_FORM:

View File

@ -436,7 +436,7 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
hass.config_entries._entries.append(entry)
flow = await hass.config_entries.options._async_create_flow(
flow = await hass.config_entries.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None
)

View File

@ -182,13 +182,13 @@ async def test_options_form(hass):
)
entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(
result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure(
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={CONF_FLOOR_TEMP: True, CONF_PRECISION: PRECISION_HALVES},
)
@ -197,11 +197,11 @@ async def test_options_form(hass):
assert result["data"][CONF_PRECISION] == PRECISION_HALVES
assert result["data"][CONF_FLOOR_TEMP] is True
result = await hass.config_entries.options.flow.async_init(
result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None
)
result = await hass.config_entries.options.flow.async_configure(
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_PRECISION: 0}
)

View File

@ -462,13 +462,13 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=DEFAULT_OPTIONS)
entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(
result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None
)
assert result["type"] == "form"
assert result["step_id"] == "plex_mp_settings"
result = await hass.config_entries.options.flow.async_configure(
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
config_flow.CONF_USE_EPISODE_ART: True,

View File

@ -131,12 +131,12 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(entry.entry_id)
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure(
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 350}
)
assert result["type"] == "create_entry"
@ -148,12 +148,12 @@ async def test_option_flow_input_floor(hass):
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(entry.entry_id)
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure(
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 1}
)
assert result["type"] == "create_entry"

View File

@ -231,7 +231,7 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
hass.config_entries._entries.append(entry)
flow = await hass.config_entries.options._async_create_flow(
flow = await hass.config_entries.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None
)

View File

@ -692,13 +692,13 @@ async def test_entry_options(hass, manager):
return OptionsFlowHandler()
config_entries.HANDLERS["test"] = TestFlow()
flow = await manager.options._async_create_flow(
flow = await manager.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None
)
flow.handler = entry.entry_id # Used to keep reference to config entry
await manager.options._async_finish_flow(flow, {"data": {"second": True}})
await manager.options.async_finish_flow(flow, {"data": {"second": True}})
assert entry.data == {"first": True}

View File

@ -14,27 +14,32 @@ def manager():
handlers = Registry()
entries = []
async def async_create_flow(handler_name, *, context, data):
handler = handlers.get(handler_name)
class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager."""
if handler is None:
raise data_entry_flow.UnknownHandler
async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow."""
handler = handlers.get(handler_key)
flow = handler()
flow.init_step = context.get("init_step", "init")
flow.source = context.get("source")
return flow
if handler is None:
raise data_entry_flow.UnknownHandler
async def async_add_entry(flow, result):
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result["source"] = flow.context.get("source")
entries.append(result)
return result
flow = handler()
flow.init_step = context.get("init_step", "init")
flow.source = context.get("source")
return flow
manager = data_entry_flow.FlowManager(None, async_create_flow, async_add_entry)
manager.mock_created_entries = entries
manager.mock_reg_handler = handlers.register
return manager
async def async_finish_flow(self, flow, result):
"""Test finish flow."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result["source"] = flow.context.get("source")
entries.append(result)
return result
mgr = FlowManager(None)
mgr.mock_created_entries = entries
mgr.mock_reg_handler = handlers.register
return mgr
async def test_configure_reuses_handler_instance(manager):
@ -194,22 +199,23 @@ async def test_finish_callback_change_result_type(hass):
step_id="init", data_schema=vol.Schema({"count": int})
)
async def async_create_flow(handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
class FlowManager(data_entry_flow.FlowManager):
async def async_create_flow(self, handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
async def async_finish_flow(flow, result):
"""Redirect to init form if count <= 1."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if result["data"] is None or result["data"].get("count", 0) <= 1:
return flow.async_show_form(
step_id="init", data_schema=vol.Schema({"count": int})
)
else:
result["result"] = result["data"]["count"]
return result
async def async_finish_flow(self, flow, result):
"""Redirect to init form if count <= 1."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if result["data"] is None or result["data"].get("count", 0) <= 1:
return flow.async_show_form(
step_id="init", data_schema=vol.Schema({"count": int})
)
else:
result["result"] = result["data"]["count"]
return result
manager = data_entry_flow.FlowManager(hass, async_create_flow, async_finish_flow)
manager = FlowManager(hass)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM