Cancel discovery flows that are initializing at shutdown (#49241)

pull/49271/head
J. Nick Koston 2021-04-15 07:13:42 -10:00 committed by GitHub
parent a529a12745
commit dafc7a072c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 12 deletions

View File

@ -792,6 +792,7 @@ class ConfigEntries:
await asyncio.gather(
*[entry.async_shutdown() for entry in self._entries.values()]
)
await self.flow.async_shutdown()
async def async_initialize(self) -> None:
"""Initialize config entry config."""

View File

@ -61,6 +61,7 @@ class FlowManager(abc.ABC):
"""Initialize the flow manager."""
self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, Any] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
@ -118,21 +119,13 @@ class FlowManager(abc.ABC):
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, []).append(init_done)
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
self._initializing[handler].remove(init_done)
raise UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.context = context
self._progress[flow.flow_id] = flow
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
try:
result = await self._async_handle_step(
flow, flow.init_step, data, init_done
)
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].remove(init_done)
if result["type"] != RESULT_TYPE_ABORT:
@ -140,6 +133,31 @@ class FlowManager(abc.ABC):
return result
async def _async_init(
self,
init_done: asyncio.Future,
handler: str,
context: dict,
data: Any,
) -> tuple[FlowHandler, Any]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.context = context
self._progress[flow.flow_id] = flow
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
async def async_shutdown(self) -> None:
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values():
for task in task_list:
task.cancel()
async def async_configure(
self, flow_id: str, user_input: dict | None = None
) -> Any:

View File

@ -1,4 +1,7 @@
"""Test the flow classes."""
import asyncio
from unittest.mock import patch
import pytest
import voluptuous as vol
@ -367,3 +370,28 @@ async def test_abort_flow_exception(manager):
assert form["type"] == "abort"
assert form["reason"] == "mock-reason"
assert form["description_placeholders"] == {"placeholder": "yo"}
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
"""Test that initializing flows are canceled on shutdown."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
await asyncio.sleep(1)
task = asyncio.create_task(manager.async_init("test"))
await hass.async_block_till_done()
await manager.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError):
await task
async def test_init_unknown_flow(manager):
"""Test that UnknownFlow is raised when async_create_flow returns None."""
with pytest.raises(data_entry_flow.UnknownFlow), patch.object(
manager, "async_create_flow", return_value=None
):
await manager.async_init("test")