Fix DiscoveryFlowHandler when discovery_function returns bool (#133563)
Co-authored-by: J. Nick Koston <nick@koston.org>pull/136092/head
parent
0027d907a4
commit
c687a6f669
|
@ -67,9 +67,11 @@ class DiscoveryFlowHandler[_R: Awaitable[bool] | bool](config_entries.ConfigFlow
|
|||
in_progress = self._async_in_progress()
|
||||
|
||||
if not (has_devices := bool(in_progress)):
|
||||
has_devices = await cast(
|
||||
"asyncio.Future[bool]", self._discovery_function(self.hass)
|
||||
)
|
||||
discovery_result = self._discovery_function(self.hass)
|
||||
if isinstance(discovery_result, bool):
|
||||
has_devices = discovery_result
|
||||
else:
|
||||
has_devices = await cast("asyncio.Future[bool]", discovery_result)
|
||||
|
||||
if not has_devices:
|
||||
return self.async_abort(reason="no_devices_found")
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for the Config Entry Flow helper."""
|
||||
|
||||
from collections.abc import Generator
|
||||
import asyncio
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -13,22 +15,44 @@ from homeassistant.helpers import config_entry_flow
|
|||
from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _make_discovery_flow_conf(
|
||||
has_discovered_devices: Callable[[], asyncio.Future[bool] | bool],
|
||||
) -> Generator[None]:
|
||||
with patch.dict(config_entries.HANDLERS):
|
||||
config_entry_flow.register_discovery_flow(
|
||||
"test", "Test", has_discovered_devices
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]:
|
||||
"""Register a handler."""
|
||||
def async_discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]:
|
||||
"""Register a handler with an async discovery function."""
|
||||
handler_conf = {"discovered": False}
|
||||
|
||||
async def has_discovered_devices(hass: HomeAssistant) -> bool:
|
||||
"""Mock if we have discovered devices."""
|
||||
return handler_conf["discovered"]
|
||||
|
||||
with patch.dict(config_entries.HANDLERS):
|
||||
config_entry_flow.register_discovery_flow(
|
||||
"test", "Test", has_discovered_devices
|
||||
)
|
||||
with _make_discovery_flow_conf(has_discovered_devices):
|
||||
yield handler_conf
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]:
|
||||
"""Register a handler with a async friendly callback function."""
|
||||
handler_conf = {"discovered": False}
|
||||
|
||||
def has_discovered_devices(hass: HomeAssistant) -> bool:
|
||||
"""Mock if we have discovered devices."""
|
||||
return handler_conf["discovered"]
|
||||
|
||||
with _make_discovery_flow_conf(has_discovered_devices):
|
||||
yield handler_conf
|
||||
handler_conf = {"discovered": False}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def webhook_flow_conf(hass: HomeAssistant) -> Generator[None]:
|
||||
"""Register a handler."""
|
||||
|
@ -95,6 +119,33 @@ async def test_user_has_confirmation(
|
|||
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_user_has_confirmation_async_discovery_flow(
|
||||
hass: HomeAssistant, async_discovery_flow_conf: dict[str, bool]
|
||||
) -> None:
|
||||
"""Test user requires confirmation to setup with an async has_discovered_devices."""
|
||||
async_discovery_flow_conf["discovered"] = True
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"test", context={"source": config_entries.SOURCE_USER}, data={}
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert result["step_id"] == "confirm"
|
||||
|
||||
progress = hass.config_entries.flow.async_progress()
|
||||
assert len(progress) == 1
|
||||
assert progress[0]["flow_id"] == result["flow_id"]
|
||||
assert progress[0]["context"] == {
|
||||
"confirm_only": True,
|
||||
"source": config_entries.SOURCE_USER,
|
||||
"unique_id": "test",
|
||||
}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"source",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue