diff --git a/homeassistant/components/tplink/__init__.py b/homeassistant/components/tplink/__init__.py index 6d300f68aa0..83cfc733716 100644 --- a/homeassistant/components/tplink/__init__.py +++ b/homeassistant/components/tplink/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from collections.abc import Iterable from datetime import timedelta import logging from typing import Any @@ -282,6 +283,28 @@ def mac_alias(mac: str) -> str: return mac.replace(":", "")[-4:].upper() +def _mac_connection_or_none(device: dr.DeviceEntry) -> str | None: + return next( + ( + conn + for type_, conn in device.connections + if type_ == dr.CONNECTION_NETWORK_MAC + ), + None, + ) + + +def _device_id_is_mac_or_none(mac: str, device_ids: Iterable[str]) -> str | None: + # Previously only iot devices had child devices and iot devices use + # the upper and lcase MAC addresses as device_id so match on case + # insensitive mac address as the parent device. + upper_mac = mac.upper() + return next( + (device_id for device_id in device_ids if device_id.upper() == upper_mac), + None, + ) + + async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Migrate old entry.""" version = config_entry.version @@ -298,49 +321,48 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> # always be linked into one device. dev_reg = dr.async_get(hass) for device in dr.async_entries_for_config_entry(dev_reg, config_entry.entry_id): - new_identifiers: set[tuple[str, str]] | None = None - if len(device.identifiers) > 1 and ( - mac := next( - iter( - [ - conn[1] - for conn in device.connections - if conn[0] == dr.CONNECTION_NETWORK_MAC - ] - ), - None, + original_identifiers = device.identifiers + # Get only the tplink identifier, could be tapo or other integrations. + tplink_identifiers = [ + ident[1] for ident in original_identifiers if ident[0] == DOMAIN + ] + # Nothing to fix if there's only one identifier. mac connection + # should never be none but if it is there's no problem. + if len(tplink_identifiers) <= 1 or not ( + mac := _mac_connection_or_none(device) + ): + continue + if not ( + tplink_parent_device_id := _device_id_is_mac_or_none( + mac, tplink_identifiers ) ): - for identifier in device.identifiers: - # Previously only iot devices that use the MAC address as - # device_id had child devices so check for mac as the - # parent device. - if identifier[0] == DOMAIN and identifier[1].upper() == mac.upper(): - new_identifiers = {identifier} - break - if new_identifiers: - dev_reg.async_update_device( - device.id, new_identifiers=new_identifiers - ) - _LOGGER.debug( - "Replaced identifiers for device %s (%s): %s with: %s", - device.name, - device.model, - device.identifiers, - new_identifiers, - ) - else: - # No match on mac so raise an error. - _LOGGER.error( - "Unable to replace identifiers for device %s (%s): %s", - device.name, - device.model, - device.identifiers, - ) + # No match on mac so raise an error. + _LOGGER.error( + "Unable to replace identifiers for device %s (%s): %s", + device.name, + device.model, + device.identifiers, + ) + continue + # Retain any identifiers for other domains + new_identifiers = { + ident for ident in device.identifiers if ident[0] != DOMAIN + } + new_identifiers.add((DOMAIN, tplink_parent_device_id)) + dev_reg.async_update_device(device.id, new_identifiers=new_identifiers) + _LOGGER.debug( + "Replaced identifiers for device %s (%s): %s with: %s", + device.name, + device.model, + original_identifiers, + new_identifiers, + ) minor_version = 3 hass.config_entries.async_update_entry(config_entry, minor_version=3) - _LOGGER.debug("Migration to version %s.%s successful", version, minor_version) + + _LOGGER.debug("Migration to version %s.%s complete", version, minor_version) if version == 1 and minor_version == 3: # credentials_hash stored in the device_config should be moved to data. diff --git a/tests/components/tplink/__init__.py b/tests/components/tplink/__init__.py index b3092d62904..d12858017cc 100644 --- a/tests/components/tplink/__init__.py +++ b/tests/components/tplink/__init__.py @@ -49,6 +49,7 @@ ALIAS = "My Bulb" MODEL = "HS100" MAC_ADDRESS = "aa:bb:cc:dd:ee:ff" DEVICE_ID = "123456789ABCDEFGH" +DEVICE_ID_MAC = "AA:BB:CC:DD:EE:FF" DHCP_FORMATTED_MAC_ADDRESS = MAC_ADDRESS.replace(":", "") MAC_ADDRESS2 = "11:22:33:44:55:66" DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}" diff --git a/tests/components/tplink/test_init.py b/tests/components/tplink/test_init.py index bfb7e02b63d..c5c5e2ce6db 100644 --- a/tests/components/tplink/test_init.py +++ b/tests/components/tplink/test_init.py @@ -36,6 +36,8 @@ from . import ( CREATE_ENTRY_DATA_AUTH, CREATE_ENTRY_DATA_LEGACY, DEVICE_CONFIG_AUTH, + DEVICE_ID, + DEVICE_ID_MAC, IP_ADDRESS, MAC_ADDRESS, _mocked_device, @@ -404,19 +406,48 @@ async def test_feature_no_category( @pytest.mark.parametrize( - ("identifier_base", "expected_message", "expected_count"), + ("device_id", "id_count", "domains", "expected_message"), [ - pytest.param("C0:06:C3:42:54:2B", "Replaced", 1, id="success"), - pytest.param("123456789", "Unable to replace", 3, id="failure"), + pytest.param(DEVICE_ID_MAC, 1, [DOMAIN], None, id="mac-id-no-children"), + pytest.param(DEVICE_ID_MAC, 3, [DOMAIN], "Replaced", id="mac-id-children"), + pytest.param( + DEVICE_ID_MAC, + 1, + [DOMAIN, "other"], + None, + id="mac-id-no-children-other-domain", + ), + pytest.param( + DEVICE_ID_MAC, + 3, + [DOMAIN, "other"], + "Replaced", + id="mac-id-children-other-domain", + ), + pytest.param(DEVICE_ID, 1, [DOMAIN], None, id="not-mac-id-no-children"), + pytest.param( + DEVICE_ID, 3, [DOMAIN], "Unable to replace", id="not-mac-children" + ), + pytest.param( + DEVICE_ID, 1, [DOMAIN, "other"], None, id="not-mac-no-children-other-domain" + ), + pytest.param( + DEVICE_ID, + 3, + [DOMAIN, "other"], + "Unable to replace", + id="not-mac-children-other-domain", + ), ], ) async def test_unlink_devices( hass: HomeAssistant, device_registry: dr.DeviceRegistry, caplog: pytest.LogCaptureFixture, - identifier_base, + device_id, + id_count, + domains, expected_message, - expected_count, ) -> None: """Test for unlinking child device ids.""" entry = MockConfigEntry( @@ -429,43 +460,54 @@ async def test_unlink_devices( ) entry.add_to_hass(hass) - # Setup initial device registry, with linkages - mac = "C0:06:C3:42:54:2B" - identifiers = [ - (DOMAIN, identifier_base), - (DOMAIN, f"{identifier_base}_0001"), - (DOMAIN, f"{identifier_base}_0002"), + # Generate list of test identifiers + test_identifiers = [ + (domain, f"{device_id}{"" if i == 0 else f"_000{i}"}") + for i in range(id_count) + for domain in domains ] + update_msg_fragment = "identifiers for device dummy (hs300):" + update_msg = f"{expected_message} {update_msg_fragment}" if expected_message else "" + + # Expected identifiers should include all other domains or all the newer non-mac device ids + # or just the parent mac device id + expected_identifiers = [ + (domain, device_id) + for domain, device_id in test_identifiers + if domain != DOMAIN + or device_id.startswith(DEVICE_ID) + or device_id == DEVICE_ID_MAC + ] + device_registry.async_get_or_create( config_entry_id="123456", connections={ - (dr.CONNECTION_NETWORK_MAC, mac.lower()), + (dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS), }, - identifiers=set(identifiers), + identifiers=set(test_identifiers), model="hs300", name="dummy", ) device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) assert device_entries[0].connections == { - (dr.CONNECTION_NETWORK_MAC, mac.lower()), + (dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS), } - assert device_entries[0].identifiers == set(identifiers) + assert device_entries[0].identifiers == set(test_identifiers) await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) - assert device_entries[0].connections == {(dr.CONNECTION_NETWORK_MAC, mac.lower())} - # If expected count is 1 will be the first identifier only - expected_identifiers = identifiers[:expected_count] + assert device_entries[0].connections == {(dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS)} + assert device_entries[0].identifiers == set(expected_identifiers) assert entry.version == 1 assert entry.minor_version == 4 - msg = f"{expected_message} identifiers for device dummy (hs300): {set(identifiers)}" - assert msg in caplog.text + assert update_msg in caplog.text + assert "Migration to version 1.3 complete" in caplog.text async def test_move_credentials_hash(