From 4313d4b26ba64ffd5f15e3e2a85d38c0b9b4e551 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston" <nick@koston.org>
Date: Mon, 25 May 2020 11:17:30 -0500
Subject: [PATCH] Ensure homekit bridge state is restored before creating
 devices (#36098)

* Ensure homekit bridge state is restored before creating devices

* Tests to ensure homekit device registry entry is stable

* remove stray continue
---
 homeassistant/components/homekit/__init__.py | 44 ++++++++++++++++++--
 tests/components/homekit/test_homekit.py     | 41 +++++++++++++++++-
 2 files changed, 81 insertions(+), 4 deletions(-)

diff --git a/homeassistant/components/homekit/__init__.py b/homeassistant/components/homekit/__init__.py
index adbf79128e3..428f8e30abf 100644
--- a/homeassistant/components/homekit/__init__.py
+++ b/homeassistant/components/homekit/__init__.py
@@ -2,6 +2,7 @@
 import asyncio
 import ipaddress
 import logging
+import os
 
 from aiohttp import web
 import voluptuous as vol
@@ -49,6 +50,7 @@ from .const import (
     ATTR_SOFTWARE_VERSION,
     ATTR_VALUE,
     BRIDGE_NAME,
+    BRIDGE_SERIAL_NUMBER,
     CONF_ADVERTISE_IP,
     CONF_AUTO_START,
     CONF_ENTITY_CONFIG,
@@ -434,6 +436,13 @@ class HomeKit:
             interface_choice=self._interface_choice,
         )
 
+        # If we do not load the mac address will be wrong
+        # as pyhap uses a random one until state is restored
+        if os.path.exists(persist_file):
+            self.driver.load()
+        else:
+            self.driver.persist()
+
         self.bridge = HomeBridge(self.hass, self.driver, self._name)
         if self._safe_mode:
             _LOGGER.debug("Safe_mode selected for %s", self._name)
@@ -540,16 +549,45 @@ class HomeKit:
     @callback
     def _async_register_bridge(self, dev_reg):
         """Register the bridge as a device so homekit_controller and exclude it from discovery."""
+        formatted_mac = device_registry.format_mac(self.driver.state.mac)
+        # Connections and identifiers are both used here.
+        #
+        # connections exists so homekit_controller can know the
+        # virtual mac address of the bridge and know to not offer
+        # it via discovery.
+        #
+        # identifiers is used as well since the virtual mac may change
+        # because it will not survive manual pairing resets (deleting state file)
+        # which we have trained users to do over the past few years
+        # because this was the way you had to fix homekit when pairing
+        # failed.
+        #
+        connection = (device_registry.CONNECTION_NETWORK_MAC, formatted_mac)
+        identifier = (DOMAIN, self._entry_id, BRIDGE_SERIAL_NUMBER)
+        self._async_purge_old_bridges(dev_reg, identifier, connection)
         dev_reg.async_get_or_create(
             config_entry_id=self._entry_id,
-            connections={
-                (device_registry.CONNECTION_NETWORK_MAC, self.driver.state.mac)
-            },
+            identifiers={identifier},
+            connections={connection},
             manufacturer=MANUFACTURER,
             name=self._name,
             model="Home Assistant HomeKit Bridge",
         )
 
+    @callback
+    def _async_purge_old_bridges(self, dev_reg, identifier, connection):
+        """Purge bridges that exist from failed pairing or manual resets."""
+        devices_to_purge = []
+        for entry in dev_reg.devices.values():
+            if self._entry_id in entry.config_entries and (
+                identifier not in entry.identifiers
+                or connection not in entry.connections
+            ):
+                devices_to_purge.append(entry.id)
+
+        for device_id in devices_to_purge:
+            dev_reg.async_remove_device(device_id)
+
     def _start(self, bridged_states):
         from . import (  # noqa: F401 pylint: disable=unused-import, import-outside-toplevel
             type_cameras,
diff --git a/tests/components/homekit/test_homekit.py b/tests/components/homekit/test_homekit.py
index c0e2ea90fba..b016997b7c9 100644
--- a/tests/components/homekit/test_homekit.py
+++ b/tests/components/homekit/test_homekit.py
@@ -19,6 +19,7 @@ from homeassistant.components.homekit.accessories import HomeBridge
 from homeassistant.components.homekit.const import (
     AID_STORAGE,
     BRIDGE_NAME,
+    BRIDGE_SERIAL_NUMBER,
     CONF_AUTO_START,
     CONF_ENTRY_INDEX,
     CONF_SAFE_MODE,
@@ -458,7 +459,7 @@ async def test_homekit_entity_filter(hass):
         assert mock_get_acc.called is False
 
 
-async def test_homekit_start(hass, hk_driver, debounce_patcher):
+async def test_homekit_start(hass, hk_driver, device_reg, debounce_patcher):
     """Test HomeKit start method."""
     entry = await async_init_integration(hass)
 
@@ -480,6 +481,15 @@ async def test_homekit_start(hass, hk_driver, debounce_patcher):
     homekit.driver = hk_driver
     homekit._filter = Mock(return_value=True)
 
+    connection = (device_registry.CONNECTION_NETWORK_MAC, "AA:BB:CC:DD:EE:FF")
+    bridge_with_wrong_mac = device_reg.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={connection},
+        manufacturer="Any",
+        name="Any",
+        model="Home Assistant HomeKit Bridge",
+    )
+
     hass.states.async_set("light.demo", "on")
     state = hass.states.async_all()[0]
 
@@ -505,6 +515,35 @@ async def test_homekit_start(hass, hk_driver, debounce_patcher):
     await hass.async_block_till_done()
     assert not hk_driver_start.called
 
+    assert device_reg.async_get(bridge_with_wrong_mac.id) is None
+
+    device = device_reg.async_get_device(
+        {(DOMAIN, entry.entry_id, BRIDGE_SERIAL_NUMBER)}, {}
+    )
+    assert device
+    formatted_mac = device_registry.format_mac(homekit.driver.state.mac)
+    assert (device_registry.CONNECTION_NETWORK_MAC, formatted_mac) in device.connections
+
+    # Start again to make sure the registry entry is kept
+    homekit.status = STATUS_READY
+    with patch(f"{PATH_HOMEKIT}.HomeKit.add_bridge_accessory") as mock_add_acc, patch(
+        f"{PATH_HOMEKIT}.show_setup_message"
+    ) as mock_setup_msg, patch(
+        "pyhap.accessory_driver.AccessoryDriver.add_accessory"
+    ) as hk_driver_add_acc, patch(
+        "pyhap.accessory_driver.AccessoryDriver.start"
+    ) as hk_driver_start:
+        await homekit.async_start()
+
+    device = device_reg.async_get_device(
+        {(DOMAIN, entry.entry_id, BRIDGE_SERIAL_NUMBER)}, {}
+    )
+    assert device
+    formatted_mac = device_registry.format_mac(homekit.driver.state.mac)
+    assert (device_registry.CONNECTION_NETWORK_MAC, formatted_mac) in device.connections
+
+    assert len(device_reg.devices) == 1
+
 
 async def test_homekit_start_with_a_broken_accessory(hass, hk_driver, debounce_patcher):
     """Test HomeKit start method."""