From f8fc90bc07ccf6b0b493669226a8c471687fc514 Mon Sep 17 00:00:00 2001
From: puddly <32534428+puddly@users.noreply.github.com>
Date: Wed, 31 Aug 2022 12:41:41 -0400
Subject: [PATCH] Add ZHA config flow single instance checks for zeroconf and
 hardware (#77612)

---
 homeassistant/components/zha/config_flow.py |  64 +++++-----
 tests/components/zha/test_config_flow.py    | 134 +++++++++++++++-----
 2 files changed, 136 insertions(+), 62 deletions(-)

diff --git a/homeassistant/components/zha/config_flow.py b/homeassistant/components/zha/config_flow.py
index 9fc17c25f5b..ce2080e4a13 100644
--- a/homeassistant/components/zha/config_flow.py
+++ b/homeassistant/components/zha/config_flow.py
@@ -551,6 +551,36 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
 
         return await self.async_step_choose_serial_port(user_input)
 
+    async def async_step_confirm(
+        self, user_input: dict[str, Any] | None = None
+    ) -> FlowResult:
+        """Confirm a discovery."""
+        self._set_confirm_only()
+
+        # Don't permit discovery if ZHA is already set up
+        if self._async_current_entries():
+            return self.async_abort(reason="single_instance_allowed")
+
+        # Without confirmation, discovery can automatically progress into parts of the
+        # config flow logic that interacts with hardware!
+        if user_input is not None or not onboarding.async_is_onboarded(self.hass):
+            # Probe the radio type if we don't have one yet
+            if self._radio_type is None and not await self._detect_radio_type():
+                # This path probably will not happen now that we have
+                # more precise USB matching unless there is a problem
+                # with the device
+                return self.async_abort(reason="usb_probe_failed")
+
+            if self._device_settings is None:
+                return await self.async_step_manual_port_config()
+
+            return await self.async_step_choose_formation_strategy()
+
+        return self.async_show_form(
+            step_id="confirm",
+            description_placeholders={CONF_NAME: self._title},
+        )
+
     async def async_step_usb(self, discovery_info: usb.UsbServiceInfo) -> FlowResult:
         """Handle usb discovery."""
         vid = discovery_info.vid
@@ -570,9 +600,6 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
                     },
                 }
             )
-        # Check if already configured
-        if self._async_current_entries():
-            return self.async_abort(reason="single_instance_allowed")
 
         # If they already have a discovery for deconz we ignore the usb discovery as
         # they probably want to use it there instead
@@ -591,32 +618,14 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
             vid,
             pid,
         )
-        self._set_confirm_only()
         self.context["title_placeholders"] = {CONF_NAME: self._title}
         return await self.async_step_confirm()
 
-    async def async_step_confirm(
-        self, user_input: dict[str, Any] | None = None
-    ) -> FlowResult:
-        """Confirm a discovery."""
-        if user_input is not None or not onboarding.async_is_onboarded(self.hass):
-            if not await self._detect_radio_type():
-                # This path probably will not happen now that we have
-                # more precise USB matching unless there is a problem
-                # with the device
-                return self.async_abort(reason="usb_probe_failed")
-
-            return await self.async_step_choose_formation_strategy()
-
-        return self.async_show_form(
-            step_id="confirm",
-            description_placeholders={CONF_NAME: self._title},
-        )
-
     async def async_step_zeroconf(
         self, discovery_info: zeroconf.ZeroconfServiceInfo
     ) -> FlowResult:
         """Handle zeroconf discovery."""
+
         # Hostname is format: livingroom.local.
         local_name = discovery_info.hostname[:-1]
         radio_type = discovery_info.properties.get("radio_type") or local_name
@@ -638,10 +647,6 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
                 }
             )
 
-        # Check if already configured
-        if self._async_current_entries():
-            return self.async_abort(reason="single_instance_allowed")
-
         self.context["title_placeholders"] = {CONF_NAME: node_name}
         self._title = device_path
         self._device_path = device_path
@@ -653,15 +658,12 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
         else:
             self._radio_type = RadioType.znp
 
-        return await self.async_step_manual_port_config()
+        return await self.async_step_confirm()
 
     async def async_step_hardware(
         self, data: dict[str, Any] | None = None
     ) -> FlowResult:
         """Handle hardware flow."""
-        if self._async_current_entries():
-            return self.async_abort(reason="single_instance_allowed")
-
         if not data:
             return self.async_abort(reason="invalid_hardware_data")
         if data.get("radio_type") != "efr32":
@@ -691,7 +693,7 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
         self._device_path = device_settings[CONF_DEVICE_PATH]
         self._device_settings = device_settings
 
-        return await self.async_step_choose_formation_strategy()
+        return await self.async_step_confirm()
 
 
 class ZhaOptionsFlowHandler(BaseZhaFlow, config_entries.OptionsFlow):
diff --git a/tests/components/zha/test_config_flow.py b/tests/components/zha/test_config_flow.py
index 8a6496dbc5f..d65732a6ab8 100644
--- a/tests/components/zha/test_config_flow.py
+++ b/tests/components/zha/test_config_flow.py
@@ -107,22 +107,31 @@ async def test_zeroconf_discovery_znp(hass):
     flow = await hass.config_entries.flow.async_init(
         DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
     )
+    assert flow["step_id"] == "confirm"
+
+    # Confirm discovery
     result1 = await hass.config_entries.flow.async_configure(
         flow["flow_id"], user_input={}
     )
+    assert result1["step_id"] == "manual_port_config"
 
-    assert result1["type"] == FlowResultType.MENU
-    assert result1["step_id"] == "choose_formation_strategy"
-
+    # Confirm port settings
     result2 = await hass.config_entries.flow.async_configure(
-        result1["flow_id"],
+        result1["flow_id"], user_input={}
+    )
+
+    assert result2["type"] == FlowResultType.MENU
+    assert result2["step_id"] == "choose_formation_strategy"
+
+    result3 = await hass.config_entries.flow.async_configure(
+        result2["flow_id"],
         user_input={"next_step_id": config_flow.FORMATION_REUSE_SETTINGS},
     )
     await hass.async_block_till_done()
 
-    assert result2["type"] == FlowResultType.CREATE_ENTRY
-    assert result2["title"] == "socket://192.168.1.200:6638"
-    assert result2["data"] == {
+    assert result3["type"] == FlowResultType.CREATE_ENTRY
+    assert result3["title"] == "socket://192.168.1.200:6638"
+    assert result3["data"] == {
         CONF_DEVICE: {
             CONF_BAUDRATE: 115200,
             CONF_FLOWCONTROL: None,
@@ -148,22 +157,31 @@ async def test_zigate_via_zeroconf(setup_entry_mock, hass):
     flow = await hass.config_entries.flow.async_init(
         DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
     )
+    assert flow["step_id"] == "confirm"
+
+    # Confirm discovery
     result1 = await hass.config_entries.flow.async_configure(
         flow["flow_id"], user_input={}
     )
+    assert result1["step_id"] == "manual_port_config"
 
-    assert result1["type"] == FlowResultType.MENU
-    assert result1["step_id"] == "choose_formation_strategy"
-
+    # Confirm port settings
     result2 = await hass.config_entries.flow.async_configure(
-        result1["flow_id"],
+        result1["flow_id"], user_input={}
+    )
+
+    assert result2["type"] == FlowResultType.MENU
+    assert result2["step_id"] == "choose_formation_strategy"
+
+    result3 = await hass.config_entries.flow.async_configure(
+        result2["flow_id"],
         user_input={"next_step_id": config_flow.FORMATION_REUSE_SETTINGS},
     )
     await hass.async_block_till_done()
 
-    assert result2["type"] == FlowResultType.CREATE_ENTRY
-    assert result2["title"] == "socket://192.168.1.200:1234"
-    assert result2["data"] == {
+    assert result3["type"] == FlowResultType.CREATE_ENTRY
+    assert result3["title"] == "socket://192.168.1.200:1234"
+    assert result3["data"] == {
         CONF_DEVICE: {
             CONF_DEVICE_PATH: "socket://192.168.1.200:1234",
         },
@@ -187,22 +205,31 @@ async def test_efr32_via_zeroconf(hass):
     flow = await hass.config_entries.flow.async_init(
         DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
     )
+    assert flow["step_id"] == "confirm"
+
+    # Confirm discovery
     result1 = await hass.config_entries.flow.async_configure(
         flow["flow_id"], user_input={}
     )
+    assert result1["step_id"] == "manual_port_config"
 
-    assert result1["type"] == FlowResultType.MENU
-    assert result1["step_id"] == "choose_formation_strategy"
-
+    # Confirm port settings
     result2 = await hass.config_entries.flow.async_configure(
-        result1["flow_id"],
+        result1["flow_id"], user_input={}
+    )
+
+    assert result2["type"] == FlowResultType.MENU
+    assert result2["step_id"] == "choose_formation_strategy"
+
+    result3 = await hass.config_entries.flow.async_configure(
+        result2["flow_id"],
         user_input={"next_step_id": config_flow.FORMATION_REUSE_SETTINGS},
     )
     await hass.async_block_till_done()
 
-    assert result2["type"] == FlowResultType.CREATE_ENTRY
-    assert result2["title"] == "socket://192.168.1.200:6638"
-    assert result2["data"] == {
+    assert result3["type"] == FlowResultType.CREATE_ENTRY
+    assert result3["title"] == "socket://192.168.1.200:6638"
+    assert result3["data"] == {
         CONF_DEVICE: {
             CONF_DEVICE_PATH: "socket://192.168.1.200:6638",
             CONF_BAUDRATE: 115200,
@@ -282,6 +309,37 @@ async def test_discovery_via_zeroconf_ip_change_ignored(hass):
     }
 
 
+async def test_discovery_confirm_final_abort_if_entries(hass):
+    """Test discovery aborts if ZHA was set up after the confirmation dialog is shown."""
+    service_info = zeroconf.ZeroconfServiceInfo(
+        host="192.168.1.200",
+        addresses=["192.168.1.200"],
+        hostname="tube._tube_zb_gw._tcp.local.",
+        name="tube",
+        port=6053,
+        properties={"name": "tube_123456"},
+        type="mock_type",
+    )
+    flow = await hass.config_entries.flow.async_init(
+        DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
+    )
+    assert flow["step_id"] == "confirm"
+
+    # ZHA was somehow set up while we were in the config flow
+    with patch(
+        "homeassistant.config_entries.ConfigFlow._async_current_entries",
+        return_value=[MagicMock()],
+    ):
+        # Confirm discovery
+        result = await hass.config_entries.flow.async_configure(
+            flow["flow_id"], user_input={}
+        )
+
+    # Config will fail
+    assert result["type"] == FlowResultType.ABORT
+    assert result["reason"] == "single_instance_allowed"
+
+
 @patch(f"zigpy_znp.{PROBE_FUNCTION_PATH}", AsyncMock(return_value=True))
 async def test_discovery_via_usb(hass):
     """Test usb flow -- radio detected."""
@@ -293,15 +351,16 @@ async def test_discovery_via_usb(hass):
         description="zigbee radio",
         manufacturer="test",
     )
-    result = await hass.config_entries.flow.async_init(
+    result1 = await hass.config_entries.flow.async_init(
         DOMAIN, context={"source": SOURCE_USB}, data=discovery_info
     )
     await hass.async_block_till_done()
-    assert result["type"] == FlowResultType.FORM
-    assert result["step_id"] == "confirm"
+
+    assert result1["type"] == FlowResultType.FORM
+    assert result1["step_id"] == "confirm"
 
     result2 = await hass.config_entries.flow.async_configure(
-        result["flow_id"], user_input={}
+        result1["flow_id"], user_input={}
     )
     await hass.async_block_till_done()
 
@@ -878,17 +937,30 @@ async def test_hardware(onboarded, hass):
             DOMAIN, context={"source": "hardware"}, data=data
         )
 
-    assert result1["type"] == FlowResultType.MENU
-    assert result1["step_id"] == "choose_formation_strategy"
+    if onboarded:
+        # Confirm discovery
+        assert result1["type"] == FlowResultType.FORM
+        assert result1["step_id"] == "confirm"
 
-    result2 = await hass.config_entries.flow.async_configure(
-        result1["flow_id"],
+        result2 = await hass.config_entries.flow.async_configure(
+            result1["flow_id"],
+            user_input={},
+        )
+    else:
+        # No need to confirm
+        result2 = result1
+
+    assert result2["type"] == FlowResultType.MENU
+    assert result2["step_id"] == "choose_formation_strategy"
+
+    result3 = await hass.config_entries.flow.async_configure(
+        result2["flow_id"],
         user_input={"next_step_id": config_flow.FORMATION_REUSE_SETTINGS},
     )
     await hass.async_block_till_done()
 
-    assert result2["title"] == "Yellow"
-    assert result2["data"] == {
+    assert result3["title"] == "Yellow"
+    assert result3["data"] == {
         CONF_DEVICE: {
             CONF_BAUDRATE: 115200,
             CONF_FLOWCONTROL: "hardware",