diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index e8f1dea0639..3b0cb67f6a2 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -1899,11 +1899,25 @@ def _async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) - @callback def cleanup_restored_states_filter(event_data: Mapping[str, Any]) -> bool: """Clean up restored states filter.""" - return bool(event_data["action"] == "remove") + return (event_data["action"] == "remove") or ( + event_data["action"] == "update" + and "old_entity_id" in event_data + and event_data["entity_id"] != event_data["old_entity_id"] + ) @callback def cleanup_restored_states(event: Event[EventEntityRegistryUpdatedData]) -> None: """Clean up restored states.""" + if event.data["action"] == "update": + old_entity_id = event.data["old_entity_id"] + old_state = hass.states.get(old_entity_id) + if old_state is None or not old_state.attributes.get(ATTR_RESTORED): + return + hass.states.async_remove(old_entity_id, context=event.context) + if entry := registry.async_get(event.data["entity_id"]): + entry.write_unavailable_state(hass) + return + state = hass.states.get(event.data["entity_id"]) if state is None or not state.attributes.get(ATTR_RESTORED): diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 421f52bca73..593e1ea9703 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -1462,9 +1462,56 @@ async def test_update_entity_unique_id_conflict( ) -async def test_update_entity_entity_id(entity_registry: er.EntityRegistry) -> None: - """Test entity's entity_id is updated.""" +async def test_update_entity_entity_id( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test entity's entity_id is updated for entity with a restored state.""" + hass.set_state(CoreState.not_running) + + mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1") + mock_config.add_to_hass(hass) + entry = entity_registry.async_get_or_create( + "light", "hue", "5678", config_entry=mock_config + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {}) + await hass.async_block_till_done() + assert ( + entity_registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id + ) + state = hass.states.get(entry.entity_id) + assert state is not None + assert state.state == "unavailable" + assert state.attributes == {"restored": True, "supported_features": 0} + + new_entity_id = "light.blah" + assert new_entity_id != entry.entity_id + with patch.object(entity_registry, "async_schedule_save") as mock_schedule_save: + updated_entry = entity_registry.async_update_entity( + entry.entity_id, new_entity_id=new_entity_id + ) + assert updated_entry != entry + assert updated_entry.entity_id == new_entity_id + assert mock_schedule_save.call_count == 1 + + assert entity_registry.async_get(entry.entity_id) is None + assert entity_registry.async_get(new_entity_id) is not None + + # The restored state should be removed + old_state = hass.states.get(entry.entity_id) + assert old_state is None + + # The new entity should have an unavailable initial state + new_state = hass.states.get(new_entity_id) + assert new_state is not None + assert new_state.state == "unavailable" + + +async def test_update_entity_entity_id_without_state( + entity_registry: er.EntityRegistry, +) -> None: + """Test entity's entity_id is updated for entity without a state.""" entry = entity_registry.async_get_or_create("light", "hue", "5678") + assert ( entity_registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id )