diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index fcb9b4ddcb8..388db62ebae 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -47,12 +47,12 @@ class DeletedDeviceEntry: identifiers: Set[Tuple[str, str]] = attr.ib() id: str = attr.ib() - def to_device_entry(self): + def to_device_entry(self, config_entry_id, connections, identifiers): """Create DeviceEntry from DeletedDeviceEntry.""" return DeviceEntry( - config_entries=self.config_entries, - connections=self.connections, - identifiers=self.identifiers, + config_entries={config_entry_id}, + connections=self.connections & connections, + identifiers=self.identifiers & identifiers, id=self.id, is_new=True, ) @@ -236,7 +236,9 @@ class DeviceRegistry: device = DeviceEntry(is_new=True) else: self._remove_device(deleted_device) - device = deleted_device.to_device_entry() + device = deleted_device.to_device_entry( + config_entry_id, connections, identifiers + ) self._add_device(device) if default_manufacturer is not _UNDEF and device.manufacturer is None: @@ -338,7 +340,7 @@ class DeviceRegistry: config_entries = config_entries - {remove_config_entry_id} - if config_entries is not old.config_entries: + if config_entries != old.config_entries: changes["config_entries"] = config_entries for attr_name, value in ( diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 7c9e8a6e262..85ff693f261 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -458,7 +458,35 @@ async def test_loading_saving_data(hass, registry): registry.async_remove_device(orig_light2.id) - assert len(registry.devices) == 2 + orig_light3 = registry.async_get_or_create( + config_entry_id="789", + connections={(device_registry.CONNECTION_NETWORK_MAC, "34:56:AB:CD:EF:12")}, + identifiers={("hue", "abc")}, + manufacturer="manufacturer", + model="light", + ) + + registry.async_get_or_create( + config_entry_id="abc", + connections={(device_registry.CONNECTION_NETWORK_MAC, "34:56:AB:CD:EF:12")}, + identifiers={("abc", "123")}, + manufacturer="manufacturer", + model="light", + ) + + registry.async_remove_device(orig_light3.id) + + orig_light4 = registry.async_get_or_create( + config_entry_id="789", + connections={(device_registry.CONNECTION_NETWORK_MAC, "34:56:AB:CD:EF:12")}, + identifiers={("hue", "abc")}, + manufacturer="manufacturer", + model="light", + ) + + assert orig_light4.id == orig_light3.id + + assert len(registry.devices) == 3 assert len(registry.deleted_devices) == 1 orig_via = registry.async_update_device( @@ -476,9 +504,11 @@ async def test_loading_saving_data(hass, registry): new_via = registry2.async_get_device({("hue", "0123")}, set()) new_light = registry2.async_get_device({("hue", "456")}, set()) + new_light4 = registry2.async_get_device({("hue", "abc")}, set()) assert orig_via == new_via assert orig_light == new_light + assert orig_light4 == new_light4 async def test_no_unnecessary_changes(registry): @@ -841,6 +871,104 @@ async def test_restore_simple_device(hass, registry, update_events): assert update_events[3]["device_id"] == entry3.id +async def test_restore_shared_device(hass, registry, update_events): + """Make sure device id is stable for shared devices.""" + entry = registry.async_get_or_create( + config_entry_id="123", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("entry_123", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + assert len(registry.devices) == 1 + assert len(registry.deleted_devices) == 0 + + registry.async_get_or_create( + config_entry_id="234", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("entry_234", "2345")}, + manufacturer="manufacturer", + model="model", + ) + + assert len(registry.devices) == 1 + assert len(registry.deleted_devices) == 0 + + registry.async_remove_device(entry.id) + + assert len(registry.devices) == 0 + assert len(registry.deleted_devices) == 1 + + entry2 = registry.async_get_or_create( + config_entry_id="123", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("entry_123", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + assert entry.id == entry2.id + assert len(registry.devices) == 1 + assert len(registry.deleted_devices) == 0 + + assert isinstance(entry2.config_entries, set) + assert isinstance(entry2.connections, set) + assert isinstance(entry2.identifiers, set) + + registry.async_remove_device(entry.id) + + entry3 = registry.async_get_or_create( + config_entry_id="234", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("entry_234", "2345")}, + manufacturer="manufacturer", + model="model", + ) + + assert entry.id == entry3.id + assert len(registry.devices) == 1 + assert len(registry.deleted_devices) == 0 + + assert isinstance(entry3.config_entries, set) + assert isinstance(entry3.connections, set) + assert isinstance(entry3.identifiers, set) + + entry4 = registry.async_get_or_create( + config_entry_id="123", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("entry_123", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + assert entry.id == entry4.id + assert len(registry.devices) == 1 + assert len(registry.deleted_devices) == 0 + + assert isinstance(entry4.config_entries, set) + assert isinstance(entry4.connections, set) + assert isinstance(entry4.identifiers, set) + + await hass.async_block_till_done() + + assert len(update_events) == 7 + assert update_events[0]["action"] == "create" + assert update_events[0]["device_id"] == entry.id + assert update_events[1]["action"] == "update" + assert update_events[1]["device_id"] == entry.id + assert update_events[2]["action"] == "remove" + assert update_events[2]["device_id"] == entry.id + assert update_events[3]["action"] == "create" + assert update_events[3]["device_id"] == entry.id + assert update_events[4]["action"] == "remove" + assert update_events[4]["device_id"] == entry.id + assert update_events[5]["action"] == "create" + assert update_events[5]["device_id"] == entry.id + assert update_events[1]["action"] == "update" + assert update_events[1]["device_id"] == entry.id + + async def test_get_or_create_empty_then_set_default_values(hass, registry): """Test creating an entry, then setting default name, model, manufacturer.""" entry = registry.async_get_or_create(