Index in-progress flows to avoid linear search (#58146)

Co-authored-by: Steven Looman <steven.looman@gmail.com>
pull/57926/head
J. Nick Koston 2021-10-22 07:19:49 -10:00 committed by GitHub
parent fa56be7cc0
commit 3b7dce8b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 190 additions and 64 deletions

View File

@ -231,14 +231,9 @@ class LoginFlowResourceView(HomeAssistantView):
try:
# do not allow change ip during login flow
for flow in self._flow_mgr.async_progress():
if flow["flow_id"] == flow_id and flow["context"][
"ip_address"
] != ip_address(request.remote):
return self.json_message(
"IP address changed", HTTPStatus.BAD_REQUEST
)
flow = self._flow_mgr.async_get(flow_id)
if flow["context"]["ip_address"] != ip_address(request.remote):
return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST)
result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow:
return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND)

View File

@ -131,7 +131,7 @@ class PointFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.debug(
"Should close all flows below %s",
self.hass.config_entries.flow.async_progress(),
self._async_in_progress(),
)
# Remove notification if no other discovery config entries in progress

View File

@ -73,8 +73,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Remove the entry which will invoke the callback to delete the app.
hass.async_create_task(hass.config_entries.async_remove(entry.entry_id))
# only create new flow if there isn't a pending one for SmartThings.
flows = hass.config_entries.flow.async_progress()
if not [flow for flow in flows if flow["handler"] == DOMAIN]:
if not hass.config_entries.flow.async_progress_by_handler(DOMAIN):
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}
@ -181,8 +180,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if remove_entry:
hass.async_create_task(hass.config_entries.async_remove(entry.entry_id))
# only create new flow if there isn't a pending one for SmartThings.
flows = hass.config_entries.flow.async_progress()
if not [flow for flow in flows if flow["handler"] == DOMAIN]:
if not hass.config_entries.flow.async_progress_by_handler(DOMAIN):
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}

View File

@ -406,8 +406,8 @@ async def _continue_flow(
flow = next(
(
flow
for flow in hass.config_entries.flow.async_progress()
if flow["handler"] == DOMAIN and flow["context"]["unique_id"] == unique_id
for flow in hass.config_entries.flow.async_progress_by_handler(DOMAIN)
if flow["context"]["unique_id"] == unique_id
),
None,
)

View File

@ -745,7 +745,9 @@ class DataManager:
flow = next(
iter(
flow
for flow in self._hass.config_entries.flow.async_progress()
for flow in self._hass.config_entries.flow.async_progress_by_handler(
const.DOMAIN
)
if flow.context == context
),
None,

View File

@ -120,9 +120,8 @@ class ZhaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
# If they already have a discovery for deconz
# we ignore the usb discovery as they probably
# want to use it there instead
for flow in self.hass.config_entries.flow.async_progress():
if flow["handler"] == DECONZ_DOMAIN:
return self.async_abort(reason="not_zha_device")
if self.hass.config_entries.flow.async_progress_by_handler(DECONZ_DOMAIN):
return self.async_abort(reason="not_zha_device")
for entry in self.hass.config_entries.async_entries(DECONZ_DOMAIN):
if entry.source != config_entries.SOURCE_IGNORE:
return self.async_abort(reason="not_zha_device")

View File

@ -586,7 +586,7 @@ class ConfigEntry:
"unique_id": self.unique_id,
}
for flow in hass.config_entries.flow.async_progress():
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain):
if flow["context"] == flow_context:
return
@ -618,6 +618,14 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self.config_entries = config_entries
self._hass_config = hass_config
@callback
def _async_has_other_discovery_flows(self, flow_id: str) -> bool:
"""Check if there are any other discovery flows in progress."""
return any(
flow.context["source"] in DISCOVERY_SOURCES and flow.flow_id != flow_id
for flow in self._progress.values()
)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
@ -625,11 +633,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
flow = cast(ConfigFlow, flow)
# Remove notification if no other discovery config entries in progress
if not any(
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
if not self._async_has_other_discovery_flows(flow.flow_id):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID
)
@ -642,15 +646,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
# Abort all flows in progress with same unique ID
# or the default discovery ID
for progress_flow in self.async_progress():
for progress_flow in self.async_progress_by_handler(flow.handler):
progress_unique_id = progress_flow["context"].get("unique_id")
if (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and (
(flow.unique_id and progress_unique_id == flow.unique_id)
or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID
)
if progress_flow["flow_id"] != flow.flow_id and (
(flow.unique_id and progress_unique_id == flow.unique_id)
or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID
):
self.async_abort(progress_flow["flow_id"])
@ -837,7 +837,9 @@ class ConfigEntries:
# If the configuration entry is removed during reauth, it should
# abort any reauth flow that is active for the removed entry.
for progress_flow in self.hass.config_entries.flow.async_progress():
for progress_flow in self.hass.config_entries.flow.async_progress_by_handler(
entry.domain
):
context = progress_flow.get("context")
if (
context
@ -1265,10 +1267,10 @@ class ConfigFlow(data_entry_flow.FlowHandler):
"""Return other in progress flows for current domain."""
return [
flw
for flw in self.hass.config_entries.flow.async_progress(
include_uninitialized=include_uninitialized
for flw in self.hass.config_entries.flow.async_progress_by_handler(
self.handler, include_uninitialized=include_uninitialized
)
if flw["handler"] == self.handler and flw["flow_id"] != self.flow_id
if flw["flow_id"] != self.flow_id
]
async def async_step_ignore(
@ -1329,7 +1331,9 @@ class ConfigFlow(data_entry_flow.FlowHandler):
# Remove reauth notification if no reauth flows are in progress
if self.source == SOURCE_REAUTH and not any(
ent["context"]["source"] == SOURCE_REAUTH
for ent in self.hass.config_entries.flow.async_progress()
for ent in self.hass.config_entries.flow.async_progress_by_handler(
self.handler
)
if ent["flow_id"] != self.flow_id
):
self.hass.components.persistent_notification.async_dismiss(

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import asyncio
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, TypedDict
import uuid
@ -78,6 +78,23 @@ class FlowResult(TypedDict, total=False):
options: Mapping[str, Any]
@callback
def _async_flow_handler_to_flow_result(
flows: Iterable[FlowHandler], include_uninitialized: bool
) -> list[FlowResult]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
return [
{
"flow_id": flow.flow_id,
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in flows
if include_uninitialized or flow.cur_step is not None
]
class FlowManager(abc.ABC):
"""Manage all the flows that are in progress."""
@ -89,7 +106,8 @@ class FlowManager(abc.ABC):
self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, Any] = {}
self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
@ -127,24 +145,39 @@ class FlowManager(abc.ABC):
"""Check if an existing matching flow is in progress with the same handler, context, and data."""
return any(
flow
for flow in self._progress.values()
if flow.handler == handler
and flow.context["source"] == context["source"]
and flow.init_data == data
for flow in self._async_progress_by_handler(handler)
if flow.context["source"] == context["source"] and flow.init_data == data
)
@callback
def async_get(self, flow_id: str) -> FlowResult | None:
"""Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
return _async_flow_handler_to_flow_result([flow], False)[0]
@callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
"""Return the flows in progress."""
"""Return the flows in progress as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._progress.values(), include_uninitialized
)
@callback
def async_progress_by_handler(
self, handler: str, include_uninitialized: bool = False
) -> list[FlowResult]:
"""Return the flows in progress by handler as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._async_progress_by_handler(handler), include_uninitialized
)
@callback
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
"""Return the flows in progress by handler."""
return [
{
"flow_id": flow.flow_id,
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in self._progress.values()
if include_uninitialized or flow.cur_step is not None
self._progress[flow_id]
for flow_id in self._handler_progress_index.get(handler, {})
]
async def async_init(
@ -187,7 +220,7 @@ class FlowManager(abc.ABC):
flow.flow_id = uuid.uuid4().hex
flow.context = context
flow.init_data = data
self._progress[flow.flow_id] = flow
self._async_add_flow_progress(flow)
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
@ -205,6 +238,7 @@ class FlowManager(abc.ABC):
raise UnknownFlow
cur_step = flow.cur_step
assert cur_step is not None
if cur_step.get("data_schema") is not None and user_input is not None:
user_input = cur_step["data_schema"](user_input)
@ -245,8 +279,24 @@ class FlowManager(abc.ABC):
@callback
def async_abort(self, flow_id: str) -> None:
"""Abort a flow."""
if self._progress.pop(flow_id, None) is None:
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
"""Add a flow to in progress."""
self._progress[flow.flow_id] = flow
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
@callback
def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress."""
flow = self._progress.pop(flow_id, None)
if flow is None:
raise UnknownFlow
handler = flow.handler
self._handler_progress_index[handler].remove(flow.flow_id)
if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler]
async def _async_handle_step(
self,
@ -259,7 +309,7 @@ class FlowManager(abc.ABC):
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._progress.pop(flow.flow_id)
self._async_remove_flow_progress(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep(
@ -310,7 +360,7 @@ class FlowManager(abc.ABC):
return result
# Abort and Success results both finish the flow
self._progress.pop(flow.flow_id)
self._async_remove_flow_progress(flow.flow_id)
return result
@ -319,7 +369,7 @@ class FlowHandler:
"""Handle the configuration flow of a component."""
# Set by flow manager
cur_step: dict[str, str] | None = None
cur_step: dict[str, Any] | None = None
# While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts.

View File

@ -114,3 +114,43 @@ async def test_login_exist_user(hass, aiohttp_client):
step = await resp.json()
assert step["type"] == "create_entry"
assert len(step["result"]) > 1
async def test_login_exist_user_ip_changes(hass, aiohttp_client):
"""Test logging in and the ip address changes results in an rejection."""
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
cred = await hass.auth.auth_providers[0].async_get_or_create_credentials(
{"username": "test-user"}
)
await hass.auth.async_get_or_create_user(cred)
resp = await client.post(
"/auth/login_flow",
json={
"client_id": CLIENT_ID,
"handler": ["insecure_example", None],
"redirect_uri": CLIENT_REDIRECT_URI,
},
)
assert resp.status == 200
step = await resp.json()
#
# Here we modify the ip_address in the context to make sure
# when ip address changes in the middle of the login flow we prevent logins.
#
# This method was chosen because it seemed less likely to break
# vs patching aiohttp internals to fake the ip address
#
for flow_id, flow in hass.auth.login_flow._progress.items():
assert flow_id == step["flow_id"]
flow.context["ip_address"] = "10.2.3.1"
resp = await client.post(
f"/auth/login_flow/{step['flow_id']}",
json={"client_id": CLIENT_ID, "username": "test-user", "password": "test-pass"},
)
assert resp.status == 400
response = await resp.json()
assert response == {"message": "IP address changed"}

View File

@ -349,7 +349,7 @@ async def test_remove_entry_cancels_reauth(hass, manager):
await entry.async_setup(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 1
assert flows[0]["context"]["entry_id"] == entry.entry_id
assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH
@ -357,7 +357,7 @@ async def test_remove_entry_cancels_reauth(hass, manager):
await manager.async_remove(entry.entry_id)
flows = hass.config_entries.flow.async_progress()
flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 0
@ -2100,11 +2100,11 @@ async def test_unignore_step_form(hass, manager):
# Right after removal there shouldn't be an entry or active flows
assert len(hass.config_entries.async_entries("comp")) == 0
assert len(hass.config_entries.flow.async_progress()) == 0
assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
# But after a 'tick' the unignore step has run and we can see an active flow again.
await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress()) == 1
assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 1
# and still not config entries
assert len(hass.config_entries.async_entries("comp")) == 0
@ -2144,7 +2144,7 @@ async def test_unignore_create_entry(hass, manager):
await manager.async_remove(entry.entry_id)
# Right after removal there shouldn't be an entry or flow
assert len(hass.config_entries.flow.async_progress()) == 0
assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
assert len(hass.config_entries.async_entries("comp")) == 0
# But after a 'tick' the unignore step has run and we can see a config entry.
@ -2155,7 +2155,7 @@ async def test_unignore_create_entry(hass, manager):
assert entry.title == "yo"
# And still no active flow
assert len(hass.config_entries.flow.async_progress()) == 0
assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
async def test_unignore_default_impl(hass, manager):

View File

@ -271,6 +271,8 @@ async def test_external_step(hass, manager):
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Mimic external step
# Called by integrations: `hass.config_entries.flow.async_configure(…)`
@ -327,6 +329,8 @@ async def test_show_progress(hass, manager):
assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Mimic task one done and moving to task two
# Called by integrations: `hass.config_entries.flow.async_configure(…)`
@ -400,6 +404,13 @@ async def test_init_unknown_flow(manager):
await manager.async_init("test")
async def test_async_get_unknown_flow(manager):
"""Test that UnknownFlow is raised when async_get is called with a flow_id that does not exist."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_get("does_not_exist")
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
):
@ -424,6 +435,8 @@ async def test_async_has_matching_flow(
assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
assert (
manager.async_has_matching_flow(
@ -449,3 +462,28 @@ async def test_async_has_matching_flow(
)
is False
)
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(manager):
"""Test that moving to an unknown step raises and removes the flow from in progress."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
with pytest.raises(data_entry_flow.UnknownStep):
await manager.async_init("test", context={"init_step": "does_not_exist"})
assert manager.async_progress() == []
async def test_configure_raises_unknown_flow_if_not_in_progress(manager):
"""Test configure raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_configure("wrong_flow_id")
async def test_abort_raises_unknown_flow_if_not_in_progress(manager):
"""Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id")