From c687a6f66910f243ed5a16446ce0508ee020e53c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Vit=C3=A9?= Date: Thu, 16 Jan 2025 23:31:16 +0100 Subject: [PATCH] Fix DiscoveryFlowHandler when discovery_function returns bool (#133563) Co-authored-by: J. Nick Koston --- homeassistant/helpers/config_entry_flow.py | 8 ++- tests/helpers/test_config_entry_flow.py | 65 +++++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index b047e1aef81..60f2cd6e1a1 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -67,9 +67,11 @@ class DiscoveryFlowHandler[_R: Awaitable[bool] | bool](config_entries.ConfigFlow in_progress = self._async_in_progress() if not (has_devices := bool(in_progress)): - has_devices = await cast( - "asyncio.Future[bool]", self._discovery_function(self.hass) - ) + discovery_result = self._discovery_function(self.hass) + if isinstance(discovery_result, bool): + has_devices = discovery_result + else: + has_devices = await cast("asyncio.Future[bool]", discovery_result) if not has_devices: return self.async_abort(reason="no_devices_found") diff --git a/tests/helpers/test_config_entry_flow.py b/tests/helpers/test_config_entry_flow.py index 13e28bb8840..172aa393538 100644 --- a/tests/helpers/test_config_entry_flow.py +++ b/tests/helpers/test_config_entry_flow.py @@ -1,6 +1,8 @@ """Tests for the Config Entry Flow helper.""" -from collections.abc import Generator +import asyncio +from collections.abc import Callable, Generator +from contextlib import contextmanager from unittest.mock import Mock, PropertyMock, patch import pytest @@ -13,22 +15,44 @@ from homeassistant.helpers import config_entry_flow from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform +@contextmanager +def _make_discovery_flow_conf( + has_discovered_devices: Callable[[], asyncio.Future[bool] | bool], +) -> Generator[None]: + with patch.dict(config_entries.HANDLERS): + config_entry_flow.register_discovery_flow( + "test", "Test", has_discovered_devices + ) + yield + + @pytest.fixture -def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]: - """Register a handler.""" +def async_discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]: + """Register a handler with an async discovery function.""" handler_conf = {"discovered": False} async def has_discovered_devices(hass: HomeAssistant) -> bool: """Mock if we have discovered devices.""" return handler_conf["discovered"] - with patch.dict(config_entries.HANDLERS): - config_entry_flow.register_discovery_flow( - "test", "Test", has_discovered_devices - ) + with _make_discovery_flow_conf(has_discovered_devices): yield handler_conf +@pytest.fixture +def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]: + """Register a handler with a async friendly callback function.""" + handler_conf = {"discovered": False} + + def has_discovered_devices(hass: HomeAssistant) -> bool: + """Mock if we have discovered devices.""" + return handler_conf["discovered"] + + with _make_discovery_flow_conf(has_discovered_devices): + yield handler_conf + handler_conf = {"discovered": False} + + @pytest.fixture def webhook_flow_conf(hass: HomeAssistant) -> Generator[None]: """Register a handler.""" @@ -95,6 +119,33 @@ async def test_user_has_confirmation( assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY +async def test_user_has_confirmation_async_discovery_flow( + hass: HomeAssistant, async_discovery_flow_conf: dict[str, bool] +) -> None: + """Test user requires confirmation to setup with an async has_discovered_devices.""" + async_discovery_flow_conf["discovered"] = True + mock_platform(hass, "test.config_flow", None) + + result = await hass.config_entries.flow.async_init( + "test", context={"source": config_entries.SOURCE_USER}, data={} + ) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "confirm" + + progress = hass.config_entries.flow.async_progress() + assert len(progress) == 1 + assert progress[0]["flow_id"] == result["flow_id"] + assert progress[0]["context"] == { + "confirm_only": True, + "source": config_entries.SOURCE_USER, + "unique_id": "test", + } + + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + + @pytest.mark.parametrize( "source", [