Improve unique_id collision checks in entity_platform (#78132)

pull/78169/head
Erik Montnemery 2022-09-09 14:35:23 +02:00 committed by Paulus Schoutsen
parent 125afb39f0
commit 2b961fd327
2 changed files with 83 additions and 36 deletions

View File

@ -454,6 +454,22 @@ class EntityPlatform:
self.scan_interval,
)
def _entity_id_already_exists(self, entity_id: str) -> tuple[bool, bool]:
"""Check if an entity_id already exists.
Returns a tuple [already_exists, restored]
"""
already_exists = entity_id in self.entities
restored = False
if not already_exists and not self.hass.states.async_available(entity_id):
existing = self.hass.states.get(entity_id)
if existing is not None and ATTR_RESTORED in existing.attributes:
restored = True
else:
already_exists = True
return (already_exists, restored)
async def _async_add_entity( # noqa: C901
self,
entity: Entity,
@ -480,12 +496,31 @@ class EntityPlatform:
entity.add_to_platform_abort()
return
requested_entity_id = None
suggested_object_id: str | None = None
generate_new_entity_id = False
# Get entity_id from unique ID registration
if entity.unique_id is not None:
registered_entity_id = entity_registry.async_get_entity_id(
self.domain, self.platform_name, entity.unique_id
)
if registered_entity_id:
already_exists, _ = self._entity_id_already_exists(registered_entity_id)
if already_exists:
# If there's a collision, the entry belongs to another entity
entity.registry_entry = None
msg = (
f"Platform {self.platform_name} does not generate unique IDs. "
)
if entity.entity_id:
msg += f"ID {entity.unique_id} is already used by {registered_entity_id} - ignoring {entity.entity_id}"
else:
msg += f"ID {entity.unique_id} already exists - ignoring {registered_entity_id}"
self.logger.error(msg)
entity.add_to_platform_abort()
return
if self.config_entry is not None:
config_entry_id: str | None = self.config_entry.entry_id
else:
@ -541,7 +576,6 @@ class EntityPlatform:
pass
if entity.entity_id is not None:
requested_entity_id = entity.entity_id
suggested_object_id = split_entity_id(entity.entity_id)[1]
else:
if device and entity.has_entity_name: # type: ignore[unreachable]
@ -592,16 +626,6 @@ class EntityPlatform:
entity.registry_entry = entry
entity.entity_id = entry.entity_id
if entry.disabled:
self.logger.debug(
"Not adding entity %s because it's disabled",
entry.name
or entity.name
or f'"{self.platform_name} {entity.unique_id}"',
)
entity.add_to_platform_abort()
return
# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID
elif entity.entity_id is not None and entity_registry.async_is_registered(
@ -628,28 +652,22 @@ class EntityPlatform:
entity.add_to_platform_abort()
raise HomeAssistantError(f"Invalid entity ID: {entity.entity_id}")
already_exists = entity.entity_id in self.entities
restored = False
if not already_exists and not self.hass.states.async_available(
entity.entity_id
):
existing = self.hass.states.get(entity.entity_id)
if existing is not None and ATTR_RESTORED in existing.attributes:
restored = True
else:
already_exists = True
already_exists, restored = self._entity_id_already_exists(entity.entity_id)
if already_exists:
if entity.unique_id is not None:
msg = f"Platform {self.platform_name} does not generate unique IDs. "
if requested_entity_id:
msg += f"ID {entity.unique_id} is already used by {entity.entity_id} - ignoring {requested_entity_id}"
else:
msg += f"ID {entity.unique_id} already exists - ignoring {entity.entity_id}"
else:
msg = f"Entity id already exists - ignoring: {entity.entity_id}"
self.logger.error(msg)
self.logger.error(
f"Entity id already exists - ignoring: {entity.entity_id}"
)
entity.add_to_platform_abort()
return
if entity.registry_entry and entity.registry_entry.disabled:
self.logger.debug(
"Not adding entity %s because it's disabled",
entry.name
or entity.name
or f'"{self.platform_name} {entity.unique_id}"',
)
entity.add_to_platform_abort()
return

View File

@ -438,13 +438,15 @@ async def test_async_remove_with_platform_update_finishes(hass):
async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog):
"""Test for not adding duplicate entities."""
"""Test for not adding duplicate entities.
Also test that the entity registry is not updated for duplicates.
"""
caplog.set_level(logging.ERROR)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_add_entities(
[MockEntity(name="test1", unique_id="not_very_unique")]
)
ent1 = MockEntity(name="test1", unique_id="not_very_unique")
await component.async_add_entities([ent1])
assert len(hass.states.async_entity_ids()) == 1
assert not caplog.text
@ -466,6 +468,11 @@ async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog):
assert ent2.platform is None
assert len(hass.states.async_entity_ids()) == 1
registry = er.async_get(hass)
# test the entity name was not updated
entry = registry.async_get_or_create(DOMAIN, DOMAIN, "not_very_unique")
assert entry.original_name == "test1"
async def test_using_prescribed_entity_id(hass):
"""Test for using predefined entity ID."""
@ -577,6 +584,28 @@ async def test_registry_respect_entity_disabled(hass):
assert hass.states.async_entity_ids() == []
async def test_unique_id_conflict_has_priority_over_disabled_entity(hass, caplog):
"""Test that an entity that is not unique has priority over a disabled entity."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(
name="test1", unique_id="not_very_unique", enabled_by_default=False
)
entity2 = MockEntity(
name="test2", unique_id="not_very_unique", enabled_by_default=False
)
await component.async_add_entities([entity1])
await component.async_add_entities([entity2])
assert len(hass.states.async_entity_ids()) == 1
assert "Platform test_domain does not generate unique IDs." in caplog.text
assert entity1.registry_entry is not None
assert entity2.registry_entry is None
registry = er.async_get(hass)
# test the entity name was not updated
entry = registry.async_get_or_create(DOMAIN, DOMAIN, "not_very_unique")
assert entry.original_name == "test1"
async def test_entity_registry_updates_name(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(