Improve type hints in data_entry_flow tests (#119877)

pull/119893/head
epenet 2024-06-18 13:25:28 +02:00 committed by GitHub
parent 6b27e9a745
commit 041746a50b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 67 additions and 57 deletions

View File

@ -19,42 +19,42 @@ from .common import (
)
class MockFlowManager(data_entry_flow.FlowManager):
"""Test flow manager."""
def __init__(self) -> None:
"""Initialize the flow manager."""
super().__init__(None)
self._handlers = Registry()
self.mock_reg_handler = self._handlers.register
self.mock_created_entries = []
async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow."""
handler = self._handlers.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
flow = handler()
flow.init_step = context.get("init_step", "init")
return flow
async def async_finish_flow(self, flow, result):
"""Test finish flow."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
result["source"] = flow.context.get("source")
self.mock_created_entries.append(result)
return result
@pytest.fixture
def manager():
def manager() -> MockFlowManager:
"""Return a flow manager."""
handlers = Registry()
entries = []
class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager."""
async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow."""
handler = handlers.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
flow = handler()
flow.init_step = context.get("init_step", "init")
return flow
async def async_finish_flow(self, flow, result):
"""Test finish flow."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
result["source"] = flow.context.get("source")
entries.append(result)
return result
mgr = FlowManager(None)
# pylint: disable-next=attribute-defined-outside-init
mgr.mock_created_entries = entries
# pylint: disable-next=attribute-defined-outside-init
mgr.mock_reg_handler = handlers.register
return mgr
return MockFlowManager()
async def test_configure_reuses_handler_instance(manager) -> None:
async def test_configure_reuses_handler_instance(manager: MockFlowManager) -> None:
"""Test that we reuse instances."""
@manager.mock_reg_handler("test")
@ -82,7 +82,7 @@ async def test_configure_reuses_handler_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
async def test_configure_two_steps(manager: MockFlowManager) -> None:
"""Test that we reuse instances."""
@manager.mock_reg_handler("test")
@ -117,7 +117,7 @@ async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None
assert result["data"] == ["INIT-DATA", "SECOND-DATA"]
async def test_show_form(manager) -> None:
async def test_show_form(manager: MockFlowManager) -> None:
"""Test that we can show a form."""
schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str})
@ -136,7 +136,7 @@ async def test_show_form(manager) -> None:
assert form["errors"] == {"username": "Should be unique."}
async def test_abort_removes_instance(manager) -> None:
async def test_abort_removes_instance(manager: MockFlowManager) -> None:
"""Test that abort removes the flow from progress."""
@manager.mock_reg_handler("test")
@ -158,7 +158,7 @@ async def test_abort_removes_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_remove(manager) -> None:
async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
"""Test abort calling the async_remove FlowHandler method."""
@manager.mock_reg_handler("test")
@ -177,7 +177,7 @@ async def test_abort_calls_async_remove(manager) -> None:
async def test_abort_calls_async_remove_with_exception(
manager, caplog: pytest.LogCaptureFixture
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
) -> None:
"""Test abort calling the async_remove FlowHandler method, with an exception."""
@ -199,7 +199,7 @@ async def test_abort_calls_async_remove_with_exception(
assert len(manager.mock_created_entries) == 0
async def test_create_saves_data(manager) -> None:
async def test_create_saves_data(manager: MockFlowManager) -> None:
"""Test creating a config entry."""
@manager.mock_reg_handler("test")
@ -220,7 +220,7 @@ async def test_create_saves_data(manager) -> None:
assert entry["source"] is None
async def test_discovery_init_flow(manager) -> None:
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
"""Test a flow initialized by discovery."""
@manager.mock_reg_handler("test")
@ -290,7 +290,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None:
assert result["result"] == 2
async def test_external_step(hass: HomeAssistant, manager) -> None:
async def test_external_step(hass: HomeAssistant, manager: MockFlowManager) -> None:
"""Test external step logic."""
manager.hass = hass
@ -340,7 +340,7 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
assert result["title"] == "Hello"
async def test_show_progress(hass: HomeAssistant, manager) -> None:
async def test_show_progress(hass: HomeAssistant, manager: MockFlowManager) -> None:
"""Test show progress logic."""
manager.hass = hass
events = []
@ -443,7 +443,9 @@ async def test_show_progress(hass: HomeAssistant, manager) -> None:
assert result["title"] == "Hello"
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
async def test_show_progress_error(
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test show progress logic."""
manager.hass = hass
events = []
@ -506,7 +508,9 @@ async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
assert result["reason"] == "error"
async def test_show_progress_hidden_from_frontend(hass: HomeAssistant, manager) -> None:
async def test_show_progress_hidden_from_frontend(
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test show progress done is not sent to frontend."""
manager.hass = hass
async_show_progress_done_called = False
@ -557,7 +561,7 @@ async def test_show_progress_hidden_from_frontend(hass: HomeAssistant, manager)
async def test_show_progress_legacy(
hass: HomeAssistant, manager, caplog: pytest.LogCaptureFixture
hass: HomeAssistant, manager: MockFlowManager, caplog: pytest.LogCaptureFixture
) -> None:
"""Test show progress logic.
@ -659,7 +663,7 @@ async def test_show_progress_legacy(
async def test_show_progress_fires_only_when_changed(
hass: HomeAssistant, manager
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test show progress change logic."""
manager.hass = hass
@ -745,7 +749,7 @@ async def test_show_progress_fires_only_when_changed(
) # change (description placeholder)
async def test_abort_flow_exception(manager) -> None:
async def test_abort_flow_exception(manager: MockFlowManager) -> None:
"""Test that the AbortFlow exception works."""
@manager.mock_reg_handler("test")
@ -759,7 +763,7 @@ async def test_abort_flow_exception(manager) -> None:
assert form["description_placeholders"] == {"placeholder": "yo"}
async def test_init_unknown_flow(manager) -> None:
async def test_init_unknown_flow(manager: MockFlowManager) -> None:
"""Test that UnknownFlow is raised when async_create_flow returns None."""
with (
@ -769,7 +773,7 @@ async def test_init_unknown_flow(manager) -> None:
await manager.async_init("test")
async def test_async_get_unknown_flow(manager) -> None:
async def test_async_get_unknown_flow(manager: MockFlowManager) -> None:
"""Test that UnknownFlow is raised when async_get is called with a flow_id that does not exist."""
with pytest.raises(data_entry_flow.UnknownFlow):
@ -777,7 +781,7 @@ async def test_async_get_unknown_flow(manager) -> None:
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test we can check for matching flows."""
manager.hass = hass
@ -854,7 +858,7 @@ async def test_async_has_matching_flow(
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
manager,
manager: MockFlowManager,
) -> None:
"""Test that moving to an unknown step raises and removes the flow from in progress."""
@ -880,7 +884,7 @@ async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
],
)
async def test_next_step_unknown_step_raises_and_removes_from_in_progress(
manager, result_type: str, params: dict[str, str]
manager: MockFlowManager, result_type: str, params: dict[str, str]
) -> None:
"""Test that moving to an unknown step raises and removes the flow from in progress."""
@ -897,13 +901,17 @@ async def test_next_step_unknown_step_raises_and_removes_from_in_progress(
assert manager.async_progress() == []
async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None:
async def test_configure_raises_unknown_flow_if_not_in_progress(
manager: MockFlowManager,
) -> None:
"""Test configure raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_configure("wrong_flow_id")
async def test_abort_raises_unknown_flow_if_not_in_progress(manager) -> None:
async def test_abort_raises_unknown_flow_if_not_in_progress(
manager: MockFlowManager,
) -> None:
"""Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id")
@ -913,7 +921,11 @@ async def test_abort_raises_unknown_flow_if_not_in_progress(manager) -> None:
"menu_options",
[["target1", "target2"], {"target1": "Target 1", "target2": "Target 2"}],
)
async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
async def test_show_menu(
hass: HomeAssistant,
manager: MockFlowManager,
menu_options: list[str] | dict[str, str],
) -> None:
"""Test show menu."""
manager.hass = hass
@ -952,9 +964,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
assert result["step_id"] == "target1"
async def test_find_flows_by_init_data_type(
manager: data_entry_flow.FlowManager,
) -> None:
async def test_find_flows_by_init_data_type(manager: MockFlowManager) -> None:
"""Test we can find flows by init data type."""
@dataclasses.dataclass