Fix rainbird duplicate devices (#104528)

* Repair duplicate devices added to the rainbird integration

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Update tests/components/rainbird/test_init.py

* Remove use of config_entry.async_setup

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/104601/head
Allen Porter 2023-11-27 07:43:03 -08:00 committed by GitHub
parent 74d7d02833
commit 664aca2c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 233 additions and 17 deletions

View File

@ -10,10 +10,9 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_MAC, CONF_PASSWORD, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.entity_registry import async_entries_for_config_entry
from .const import CONF_SERIAL_NUMBER
from .coordinator import RainbirdData
@ -55,6 +54,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
format_mac(mac_address),
str(entry.data[CONF_SERIAL_NUMBER]),
)
_async_fix_device_id(
hass,
dr.async_get(hass),
entry.entry_id,
format_mac(mac_address),
str(entry.data[CONF_SERIAL_NUMBER]),
)
try:
model_info = await controller.get_model_and_version()
@ -124,7 +130,7 @@ def _async_fix_entity_unique_id(
serial_number: str,
) -> None:
"""Migrate existing entity if current one can't be found and an old one exists."""
entity_entries = async_entries_for_config_entry(entity_registry, config_entry_id)
entity_entries = er.async_entries_for_config_entry(entity_registry, config_entry_id)
for entity_entry in entity_entries:
unique_id = str(entity_entry.unique_id)
if unique_id.startswith(mac_address):
@ -137,6 +143,70 @@ def _async_fix_entity_unique_id(
)
def _async_device_entry_to_keep(
old_entry: dr.DeviceEntry, new_entry: dr.DeviceEntry
) -> dr.DeviceEntry:
"""Determine which device entry to keep when there are duplicates.
As we transitioned to new unique ids, we did not update existing device entries
and as a result there are devices with both the old and new unique id format. We
have to pick which one to keep, and preferably this can repair things if the
user previously renamed devices.
"""
# Prefer the new device if the user already gave it a name or area. Otherwise,
# do the same for the old entry. If no entries have been modified then keep the new one.
if new_entry.disabled_by is None and (
new_entry.area_id is not None or new_entry.name_by_user is not None
):
return new_entry
if old_entry.disabled_by is None and (
old_entry.area_id is not None or old_entry.name_by_user is not None
):
return old_entry
return new_entry if new_entry.disabled_by is None else old_entry
def _async_fix_device_id(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
config_entry_id: str,
mac_address: str,
serial_number: str,
) -> None:
"""Migrate existing device identifiers to the new format.
This will rename any device ids that are prefixed with the serial number to be prefixed
with the mac address. This also cleans up from a bug that allowed devices to exist
in both the old and new format.
"""
device_entries = dr.async_entries_for_config_entry(device_registry, config_entry_id)
device_entry_map = {}
migrations = {}
for device_entry in device_entries:
unique_id = next(iter(device_entry.identifiers))[1]
device_entry_map[unique_id] = device_entry
if (suffix := unique_id.removeprefix(str(serial_number))) != unique_id:
migrations[unique_id] = f"{mac_address}{suffix}"
for unique_id, new_unique_id in migrations.items():
old_entry = device_entry_map[unique_id]
if (new_entry := device_entry_map.get(new_unique_id)) is not None:
# Device entries exist for both the old and new format and one must be removed
entry_to_keep = _async_device_entry_to_keep(old_entry, new_entry)
if entry_to_keep == new_entry:
_LOGGER.debug("Removing device entry %s", unique_id)
device_registry.async_remove_device(old_entry.id)
continue
# Remove new entry and update old entry to new id below
_LOGGER.debug("Removing device entry %s", new_unique_id)
device_registry.async_remove_device(new_entry.id)
_LOGGER.debug("Updating device id from %s to %s", unique_id, new_unique_id)
device_registry.async_update_device(
old_entry.id, new_identifiers={(DOMAIN, new_unique_id)}
)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from http import HTTPStatus
from typing import Any
import pytest
@ -10,7 +11,7 @@ from homeassistant.components.rainbird.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_MAC
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from .conftest import (
CONFIG_ENTRY_DATA,
@ -35,7 +36,7 @@ async def test_init_success(
) -> None:
"""Test successful setup and unload."""
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.LOADED
await hass.config_entries.async_unload(config_entry.entry_id)
@ -86,7 +87,7 @@ async def test_communication_failure(
config_entry_state: list[ConfigEntryState],
) -> None:
"""Test unable to talk to device on startup, which fails setup."""
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == config_entry_state
@ -115,7 +116,7 @@ async def test_fix_unique_id(
assert entries[0].unique_id is None
assert entries[0].data.get(CONF_MAC) is None
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.LOADED
# Verify config entry now has a unique id
@ -167,7 +168,7 @@ async def test_fix_unique_id_failure(
responses.insert(0, initial_response)
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
# Config entry is loaded, but not updated
assert config_entry.state == ConfigEntryState.LOADED
assert config_entry.unique_id is None
@ -202,14 +203,10 @@ async def test_fix_unique_id_duplicate(
responses.append(mock_json_response(WIFI_PARAMS_RESPONSE))
responses.extend(responses_copy)
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.LOADED
assert config_entry.unique_id == MAC_ADDRESS_UNIQUE_ID
await other_entry.async_setup(hass)
# Config entry unique id could not be updated since it already exists
assert other_entry.state == ConfigEntryState.SETUP_ERROR
assert "Unable to fix missing unique id (already exists)" in caplog.text
await hass.async_block_till_done()
@ -221,34 +218,51 @@ async def test_fix_unique_id_duplicate(
"config_entry_unique_id",
"serial_number",
"entity_unique_id",
"device_identifier",
"expected_unique_id",
"expected_device_identifier",
),
[
(SERIAL_NUMBER, SERIAL_NUMBER, SERIAL_NUMBER, MAC_ADDRESS_UNIQUE_ID),
(
SERIAL_NUMBER,
SERIAL_NUMBER,
SERIAL_NUMBER,
str(SERIAL_NUMBER),
MAC_ADDRESS_UNIQUE_ID,
MAC_ADDRESS_UNIQUE_ID,
),
(
SERIAL_NUMBER,
SERIAL_NUMBER,
f"{SERIAL_NUMBER}-rain-delay",
f"{SERIAL_NUMBER}-1",
f"{MAC_ADDRESS_UNIQUE_ID}-rain-delay",
f"{MAC_ADDRESS_UNIQUE_ID}-1",
),
("0", 0, "0", MAC_ADDRESS_UNIQUE_ID),
("0", 0, "0", "0", MAC_ADDRESS_UNIQUE_ID, MAC_ADDRESS_UNIQUE_ID),
(
"0",
0,
"0-rain-delay",
"0-1",
f"{MAC_ADDRESS_UNIQUE_ID}-rain-delay",
f"{MAC_ADDRESS_UNIQUE_ID}-1",
),
(
MAC_ADDRESS_UNIQUE_ID,
SERIAL_NUMBER,
MAC_ADDRESS_UNIQUE_ID,
MAC_ADDRESS_UNIQUE_ID,
MAC_ADDRESS_UNIQUE_ID,
MAC_ADDRESS_UNIQUE_ID,
),
(
MAC_ADDRESS_UNIQUE_ID,
SERIAL_NUMBER,
f"{MAC_ADDRESS_UNIQUE_ID}-rain-delay",
f"{MAC_ADDRESS_UNIQUE_ID}-1",
f"{MAC_ADDRESS_UNIQUE_ID}-rain-delay",
f"{MAC_ADDRESS_UNIQUE_ID}-1",
),
],
ids=(
@ -264,18 +278,150 @@ async def test_fix_entity_unique_ids(
hass: HomeAssistant,
config_entry: MockConfigEntry,
entity_unique_id: str,
device_identifier: str,
expected_unique_id: str,
expected_device_identifier: str,
entity_registry: er.EntityRegistry,
device_registry: dr.DeviceRegistry,
) -> None:
"""Test fixing entity unique ids from old unique id formats."""
entity_registry = er.async_get(hass)
entity_entry = entity_registry.async_get_or_create(
DOMAIN, "number", unique_id=entity_unique_id, config_entry=config_entry
)
device_entry = device_registry.async_get_or_create(
identifiers={(DOMAIN, device_identifier)},
config_entry_id=config_entry.entry_id,
serial_number=config_entry.data["serial_number"],
)
await config_entry.async_setup(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.LOADED
entity_entry = entity_registry.async_get(entity_entry.id)
assert entity_entry
assert entity_entry.unique_id == expected_unique_id
device_entry = device_registry.async_get_device(
{(DOMAIN, expected_device_identifier)}
)
assert device_entry
assert device_entry.identifiers == {(DOMAIN, expected_device_identifier)}
@pytest.mark.parametrize(
(
"entry1_updates",
"entry2_updates",
"expected_device_name",
"expected_disabled_by",
),
[
({}, {}, None, None),
(
{
"name_by_user": "Front Sprinkler",
},
{},
"Front Sprinkler",
None,
),
(
{},
{
"name_by_user": "Front Sprinkler",
},
"Front Sprinkler",
None,
),
(
{
"name_by_user": "Sprinkler 1",
},
{
"name_by_user": "Sprinkler 2",
},
"Sprinkler 2",
None,
),
(
{
"disabled_by": dr.DeviceEntryDisabler.USER,
},
{},
None,
None,
),
(
{},
{
"disabled_by": dr.DeviceEntryDisabler.USER,
},
None,
None,
),
(
{
"disabled_by": dr.DeviceEntryDisabler.USER,
},
{
"disabled_by": dr.DeviceEntryDisabler.USER,
},
None,
dr.DeviceEntryDisabler.USER,
),
],
ids=[
"duplicates",
"prefer-old-name",
"prefer-new-name",
"both-names-prefers-new",
"old-disabled-prefer-new",
"new-disabled-prefer-old",
"both-disabled",
],
)
async def test_fix_duplicate_device_ids(
hass: HomeAssistant,
config_entry: MockConfigEntry,
device_registry: dr.DeviceRegistry,
entry1_updates: dict[str, Any],
entry2_updates: dict[str, Any],
expected_device_name: str | None,
expected_disabled_by: dr.DeviceEntryDisabler | None,
) -> None:
"""Test fixing duplicate device ids."""
entry1 = device_registry.async_get_or_create(
identifiers={(DOMAIN, str(SERIAL_NUMBER))},
config_entry_id=config_entry.entry_id,
serial_number=config_entry.data["serial_number"],
)
device_registry.async_update_device(entry1.id, **entry1_updates)
entry2 = device_registry.async_get_or_create(
identifiers={(DOMAIN, MAC_ADDRESS_UNIQUE_ID)},
config_entry_id=config_entry.entry_id,
serial_number=config_entry.data["serial_number"],
)
device_registry.async_update_device(entry2.id, **entry2_updates)
device_entries = dr.async_entries_for_config_entry(
device_registry, config_entry.entry_id
)
assert len(device_entries) == 2
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.LOADED
# Only the device with the new format exists
device_entries = dr.async_entries_for_config_entry(
device_registry, config_entry.entry_id
)
assert len(device_entries) == 1
device_entry = device_registry.async_get_device({(DOMAIN, MAC_ADDRESS_UNIQUE_ID)})
assert device_entry
assert device_entry.identifiers == {(DOMAIN, MAC_ADDRESS_UNIQUE_ID)}
assert device_entry.name_by_user == expected_device_name
assert device_entry.disabled_by == expected_disabled_by