Use ConfigFlow.has_matching_flow to deduplicate samsungtv flows (#127235)

pull/127592/head
Erik Montnemery 2024-10-01 17:56:38 +02:00 committed by GitHub
parent 1c11229510
commit 4060705d87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 6 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Mapping
from functools import partial
import socket
from typing import Any
from typing import Any, Self
from urllib.parse import urlparse
import getmac
@ -425,10 +425,12 @@ class SamsungTVConfigFlow(ConfigFlow, domain=DOMAIN):
@callback
def _async_abort_if_host_already_in_progress(self) -> None:
self.context[CONF_HOST] = self._host
for progress in self._async_in_progress():
if progress.get("context", {}).get(CONF_HOST) == self._host:
raise AbortFlow("already_in_progress")
if self.hass.config_entries.flow.async_has_matching_flow(self):
raise AbortFlow("already_in_progress")
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return other_flow._host == self._host # noqa: SLF001
@callback
def _abort_if_manufacturer_is_not_samsung(self) -> None:

View File

@ -22,6 +22,7 @@ from websockets.exceptions import (
from homeassistant import config_entries
from homeassistant.components import dhcp, ssdp, zeroconf
from homeassistant.components.samsungtv.config_flow import SamsungTVConfigFlow
from homeassistant.components.samsungtv.const import (
CONF_MANUFACTURER,
CONF_SESSION_ID,
@ -56,7 +57,7 @@ from homeassistant.const import (
CONF_TOKEN,
)
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.data_entry_flow import BaseServiceInfo, FlowResultType
from homeassistant.setup import async_setup_component
from .const import (
@ -982,6 +983,78 @@ async def test_dhcp_wired(hass: HomeAssistant, rest_api: Mock) -> None:
assert result["result"].unique_id == "be9554b9-c9fb-41f4-8920-22da015376a4"
@pytest.mark.usefixtures("remotews", "rest_api_non_ssl_only", "remoteencws_failing")
@pytest.mark.parametrize(
("source1", "data1", "source2", "data2", "is_matching_result"),
[
(
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
True,
),
(
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
False,
),
(
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
False,
),
(
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
True,
),
],
)
async def test_dhcp_zeroconf_already_in_progress(
hass: HomeAssistant,
source1: str,
data1: BaseServiceInfo,
source2: str,
data2: BaseServiceInfo,
is_matching_result: bool,
) -> None:
"""Test starting a flow from dhcp or zeroconf when already in progress."""
# confirm to add the entry
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": source1}, data=data1
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm"
real_is_matching = SamsungTVConfigFlow.is_matching
return_values = []
def is_matching(self, other_flow) -> bool:
return_values.append(real_is_matching(self, other_flow))
return return_values[-1]
with patch.object(
SamsungTVConfigFlow, "is_matching", wraps=is_matching, autospec=True
):
# confirm to add the entry
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": source2}, data=data2
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == RESULT_ALREADY_IN_PROGRESS
# Ensure the is_matching method returned the expected value
assert return_values == [is_matching_result]
@pytest.mark.usefixtures("remotews", "rest_api", "remoteencws_failing")
async def test_zeroconf(hass: HomeAssistant) -> None:
"""Test starting a flow from zeroconf."""