diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index e006cd8062e..b024c3f0128 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -136,22 +136,18 @@ def websocket_update_entity(hass, connection, msg): changes = {} - for key in ("area_id", "device_class", "disabled_by", "hidden_by", "icon", "name"): + for key in ( + "area_id", + "device_class", + "disabled_by", + "hidden_by", + "icon", + "name", + "new_entity_id", + ): if key in msg: changes[key] = msg[key] - if "new_entity_id" in msg and msg["new_entity_id"] != entity_id: - changes["new_entity_id"] = msg["new_entity_id"] - if hass.states.get(msg["new_entity_id"]) is not None: - connection.send_message( - websocket_api.error_message( - msg["id"], - "invalid_info", - "Entity with this ID is already registered", - ) - ) - return - if "disabled_by" in msg and msg["disabled_by"] is None: # Don't allow enabling an entity of a disabled device if entity_entry.device_id: diff --git a/homeassistant/components/rfxtrx/config_flow.py b/homeassistant/components/rfxtrx/config_flow.py index 7b5fdc08261..2aa3bd20b8c 100644 --- a/homeassistant/components/rfxtrx/config_flow.py +++ b/homeassistant/components/rfxtrx/config_flow.py @@ -1,6 +1,7 @@ """Config flow for RFXCOM RFXtrx integration.""" from __future__ import annotations +import asyncio import copy import itertools import os @@ -23,12 +24,13 @@ from homeassistant.const import ( CONF_PORT, CONF_TYPE, ) -from homeassistant.core import callback +from homeassistant.core import State, callback from homeassistant.helpers import ( config_validation as cv, device_registry as dr, entity_registry as er, ) +from homeassistant.helpers.event import async_track_state_change from . import ( DOMAIN, @@ -343,9 +345,35 @@ class OptionsFlow(config_entries.OptionsFlow): if new_entity_id is not None: entity_migration_map[new_entity_id] = entry + @callback + def _handle_state_change( + entity_id: str, old_state: State | None, new_state: State | None + ) -> None: + # Wait for entities to finish cleanup + if new_state is None and entity_id in pending_entities: + pending_entities.remove(entity_id) + if not pending_entities: + wait_for_entities.set() + + # Create a set with entities to be removed which are currently in the state + # machine + pending_entities = { + entry.entity_id + for entry in entity_migration_map.values() + if not self.hass.states.async_available(entry.entity_id) + } + wait_for_entities = asyncio.Event() + remove_track_state_changes = async_track_state_change( + self.hass, pending_entities, _handle_state_change + ) + for entry in entity_migration_map.values(): entity_registry.async_remove(entry.entity_id) + # Wait for entities to finish cleanup + await wait_for_entities.wait() + remove_track_state_changes() + for entity_id, entry in entity_migration_map.items(): entity_registry.async_update_entity( entity_id, diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index f40c6347af7..e58dde19127 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -335,6 +335,25 @@ class EntityRegistry: """Check if an entity_id is currently registered.""" return self.entities.get_entity_id((domain, platform, unique_id)) + def _entity_id_available( + self, entity_id: str, known_object_ids: Iterable[str] | None + ) -> bool: + """Return True if the entity_id is available. + + An entity_id is available if: + - It's not registered + - It's not known by the entity component adding the entity + - It's not in the state machine + """ + if known_object_ids is None: + known_object_ids = {} + + return ( + entity_id not in self.entities + and entity_id not in known_object_ids + and self.hass.states.async_available(entity_id) + ) + @callback def async_generate_entity_id( self, @@ -352,15 +371,11 @@ class EntityRegistry: raise MaxLengthExceeded(domain, "domain", MAX_LENGTH_STATE_DOMAIN) test_string = preferred_string[:MAX_LENGTH_STATE_ENTITY_ID] - if not known_object_ids: + if known_object_ids is None: known_object_ids = {} tries = 1 - while ( - test_string in self.entities - or test_string in known_object_ids - or not self.hass.states.async_available(test_string) - ): + while not self._entity_id_available(test_string, known_object_ids): tries += 1 len_suffix = len(str(tries)) + 1 test_string = ( @@ -630,7 +645,7 @@ class EntityRegistry: old_values[attr_name] = getattr(old, attr_name) if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id: - if self.async_is_registered(new_entity_id): + if not self._entity_id_available(new_entity_id, None): raise ValueError("Entity with this ID is already registered") if not valid_entity_id(new_entity_id): diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 64579c25766..5538950260c 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -6,7 +6,7 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE -from homeassistant.core import CoreState, callback +from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.exceptions import MaxLengthExceeded from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.entity import EntityCategory @@ -597,6 +597,56 @@ async def test_update_entity_unique_id_conflict(registry): assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id +async def test_update_entity_entity_id(registry): + """Test entity's entity_id is updated.""" + entry = registry.async_get_or_create("light", "hue", "5678") + assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + + new_entity_id = "light.blah" + assert new_entity_id != entry.entity_id + with patch.object(registry, "async_schedule_save") as mock_schedule_save: + updated_entry = registry.async_update_entity( + entry.entity_id, new_entity_id=new_entity_id + ) + assert updated_entry != entry + assert updated_entry.entity_id == new_entity_id + assert mock_schedule_save.call_count == 1 + + assert registry.async_get(entry.entity_id) is None + assert registry.async_get(new_entity_id) is not None + + +async def test_update_entity_entity_id_entity_id(hass: HomeAssistant, registry): + """Test update raises when entity_id already in use.""" + entry = registry.async_get_or_create("light", "hue", "5678") + entry2 = registry.async_get_or_create("light", "hue", "1234") + state_entity_id = "light.blah" + hass.states.async_set(state_entity_id, "on") + assert entry.entity_id != state_entity_id + assert entry2.entity_id != state_entity_id + + # Try updating to a registered entity_id + with patch.object( + registry, "async_schedule_save" + ) as mock_schedule_save, pytest.raises(ValueError): + registry.async_update_entity(entry.entity_id, new_entity_id=entry2.entity_id) + assert mock_schedule_save.call_count == 0 + assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + assert registry.async_get(entry.entity_id) is entry + assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id + assert registry.async_get(entry2.entity_id) is entry2 + + # Try updating to an entity_id which is in the state machine + with patch.object( + registry, "async_schedule_save" + ) as mock_schedule_save, pytest.raises(ValueError): + registry.async_update_entity(entry.entity_id, new_entity_id=state_entity_id) + assert mock_schedule_save.call_count == 0 + assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + assert registry.async_get(entry.entity_id) is entry + assert registry.async_get(state_entity_id) is None + + async def test_update_entity(registry): """Test updating entity.""" mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")