Cancel discovery flows that are initializing at shutdown (#49241)
parent
a529a12745
commit
dafc7a072c
|
@ -792,6 +792,7 @@ class ConfigEntries:
|
|||
await asyncio.gather(
|
||||
*[entry.async_shutdown() for entry in self._entries.values()]
|
||||
)
|
||||
await self.flow.async_shutdown()
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize config entry config."""
|
||||
|
|
|
@ -61,6 +61,7 @@ class FlowManager(abc.ABC):
|
|||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._initializing: dict[str, list[asyncio.Future]] = {}
|
||||
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
self._progress: dict[str, Any] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
|
@ -118,21 +119,13 @@ class FlowManager(abc.ABC):
|
|||
init_done: asyncio.Future = asyncio.Future()
|
||||
self._initializing.setdefault(handler, []).append(init_done)
|
||||
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
self._initializing[handler].remove(init_done)
|
||||
raise UnknownFlow("Flow was not created")
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
flow.flow_id = uuid.uuid4().hex
|
||||
flow.context = context
|
||||
self._progress[flow.flow_id] = flow
|
||||
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
|
||||
self._initialize_tasks.setdefault(handler, []).append(task)
|
||||
|
||||
try:
|
||||
result = await self._async_handle_step(
|
||||
flow, flow.init_step, data, init_done
|
||||
)
|
||||
flow, result = await task
|
||||
finally:
|
||||
self._initialize_tasks[handler].remove(task)
|
||||
self._initializing[handler].remove(init_done)
|
||||
|
||||
if result["type"] != RESULT_TYPE_ABORT:
|
||||
|
@ -140,6 +133,31 @@ class FlowManager(abc.ABC):
|
|||
|
||||
return result
|
||||
|
||||
async def _async_init(
|
||||
self,
|
||||
init_done: asyncio.Future,
|
||||
handler: str,
|
||||
context: dict,
|
||||
data: Any,
|
||||
) -> tuple[FlowHandler, Any]:
|
||||
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise UnknownFlow("Flow was not created")
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
flow.flow_id = uuid.uuid4().hex
|
||||
flow.context = context
|
||||
self._progress[flow.flow_id] = flow
|
||||
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
|
||||
return flow, result
|
||||
|
||||
async def async_shutdown(self) -> None:
|
||||
"""Cancel any initializing flows."""
|
||||
for task_list in self._initialize_tasks.values():
|
||||
for task in task_list:
|
||||
task.cancel()
|
||||
|
||||
async def async_configure(
|
||||
self, flow_id: str, user_input: dict | None = None
|
||||
) -> Any:
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
"""Test the flow classes."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -367,3 +370,28 @@ async def test_abort_flow_exception(manager):
|
|||
assert form["type"] == "abort"
|
||||
assert form["reason"] == "mock-reason"
|
||||
assert form["description_placeholders"] == {"placeholder": "yo"}
|
||||
|
||||
|
||||
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
|
||||
"""Test that initializing flows are canceled on shutdown."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
task = asyncio.create_task(manager.async_init("test"))
|
||||
await hass.async_block_till_done()
|
||||
await manager.async_shutdown()
|
||||
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
async def test_init_unknown_flow(manager):
|
||||
"""Test that UnknownFlow is raised when async_create_flow returns None."""
|
||||
|
||||
with pytest.raises(data_entry_flow.UnknownFlow), patch.object(
|
||||
manager, "async_create_flow", return_value=None
|
||||
):
|
||||
await manager.async_init("test")
|
||||
|
|
Loading…
Reference in New Issue