diff --git a/homeassistant/components/samsungtv/config_flow.py b/homeassistant/components/samsungtv/config_flow.py index e89c5e59b0e..9d2ecefd442 100644 --- a/homeassistant/components/samsungtv/config_flow.py +++ b/homeassistant/components/samsungtv/config_flow.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections.abc import Mapping from functools import partial import socket -from typing import Any +from typing import Any, Self from urllib.parse import urlparse import getmac @@ -425,10 +425,12 @@ class SamsungTVConfigFlow(ConfigFlow, domain=DOMAIN): @callback def _async_abort_if_host_already_in_progress(self) -> None: - self.context[CONF_HOST] = self._host - for progress in self._async_in_progress(): - if progress.get("context", {}).get(CONF_HOST) == self._host: - raise AbortFlow("already_in_progress") + if self.hass.config_entries.flow.async_has_matching_flow(self): + raise AbortFlow("already_in_progress") + + def is_matching(self, other_flow: Self) -> bool: + """Return True if other_flow is matching this flow.""" + return other_flow._host == self._host # noqa: SLF001 @callback def _abort_if_manufacturer_is_not_samsung(self) -> None: diff --git a/tests/components/samsungtv/test_config_flow.py b/tests/components/samsungtv/test_config_flow.py index 43d8c81d000..7e707376b6f 100644 --- a/tests/components/samsungtv/test_config_flow.py +++ b/tests/components/samsungtv/test_config_flow.py @@ -22,6 +22,7 @@ from websockets.exceptions import ( from homeassistant import config_entries from homeassistant.components import dhcp, ssdp, zeroconf +from homeassistant.components.samsungtv.config_flow import SamsungTVConfigFlow from homeassistant.components.samsungtv.const import ( CONF_MANUFACTURER, CONF_SESSION_ID, @@ -56,7 +57,7 @@ from homeassistant.const import ( CONF_TOKEN, ) from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResultType +from homeassistant.data_entry_flow import BaseServiceInfo, FlowResultType from homeassistant.setup import async_setup_component from .const import ( @@ -982,6 +983,78 @@ async def test_dhcp_wired(hass: HomeAssistant, rest_api: Mock) -> None: assert result["result"].unique_id == "be9554b9-c9fb-41f4-8920-22da015376a4" +@pytest.mark.usefixtures("remotews", "rest_api_non_ssl_only", "remoteencws_failing") +@pytest.mark.parametrize( + ("source1", "data1", "source2", "data2", "is_matching_result"), + [ + ( + config_entries.SOURCE_DHCP, + MOCK_DHCP_DATA, + config_entries.SOURCE_DHCP, + MOCK_DHCP_DATA, + True, + ), + ( + config_entries.SOURCE_DHCP, + MOCK_DHCP_DATA, + config_entries.SOURCE_ZEROCONF, + MOCK_ZEROCONF_DATA, + False, + ), + ( + config_entries.SOURCE_ZEROCONF, + MOCK_ZEROCONF_DATA, + config_entries.SOURCE_DHCP, + MOCK_DHCP_DATA, + False, + ), + ( + config_entries.SOURCE_ZEROCONF, + MOCK_ZEROCONF_DATA, + config_entries.SOURCE_ZEROCONF, + MOCK_ZEROCONF_DATA, + True, + ), + ], +) +async def test_dhcp_zeroconf_already_in_progress( + hass: HomeAssistant, + source1: str, + data1: BaseServiceInfo, + source2: str, + data2: BaseServiceInfo, + is_matching_result: bool, +) -> None: + """Test starting a flow from dhcp or zeroconf when already in progress.""" + # confirm to add the entry + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": source1}, data=data1 + ) + await hass.async_block_till_done() + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "confirm" + + real_is_matching = SamsungTVConfigFlow.is_matching + return_values = [] + + def is_matching(self, other_flow) -> bool: + return_values.append(real_is_matching(self, other_flow)) + return return_values[-1] + + with patch.object( + SamsungTVConfigFlow, "is_matching", wraps=is_matching, autospec=True + ): + # confirm to add the entry + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": source2}, data=data2 + ) + await hass.async_block_till_done() + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == RESULT_ALREADY_IN_PROGRESS + # Ensure the is_matching method returned the expected value + assert return_values == [is_matching_result] + + @pytest.mark.usefixtures("remotews", "rest_api", "remoteencws_failing") async def test_zeroconf(hass: HomeAssistant) -> None: """Test starting a flow from zeroconf."""