Improve check of new_entity_id in entity_registry.async_update_entity (#78276)

* Improve check of new_entity_id in entity_registry.async_update_entity

* Fix race in rfxtrx config flow

* Make sure Event is created on time

* Rename poorly named variable

* Fix typing

* Correct typing of _handle_state_change
pull/79211/head
Erik Montnemery 2022-09-28 08:43:58 +02:00 committed by GitHub
parent 69bf77be12
commit c38b1e7727
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 22 deletions

View File

@ -136,22 +136,18 @@ def websocket_update_entity(hass, connection, msg):
changes = {}
for key in ("area_id", "device_class", "disabled_by", "hidden_by", "icon", "name"):
for key in (
"area_id",
"device_class",
"disabled_by",
"hidden_by",
"icon",
"name",
"new_entity_id",
):
if key in msg:
changes[key] = msg[key]
if "new_entity_id" in msg and msg["new_entity_id"] != entity_id:
changes["new_entity_id"] = msg["new_entity_id"]
if hass.states.get(msg["new_entity_id"]) is not None:
connection.send_message(
websocket_api.error_message(
msg["id"],
"invalid_info",
"Entity with this ID is already registered",
)
)
return
if "disabled_by" in msg and msg["disabled_by"] is None:
# Don't allow enabling an entity of a disabled device
if entity_entry.device_id:

View File

@ -1,6 +1,7 @@
"""Config flow for RFXCOM RFXtrx integration."""
from __future__ import annotations
import asyncio
import copy
import itertools
import os
@ -23,12 +24,13 @@ from homeassistant.const import (
CONF_PORT,
CONF_TYPE,
)
from homeassistant.core import callback
from homeassistant.core import State, callback
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.event import async_track_state_change
from . import (
DOMAIN,
@ -343,9 +345,35 @@ class OptionsFlow(config_entries.OptionsFlow):
if new_entity_id is not None:
entity_migration_map[new_entity_id] = entry
@callback
def _handle_state_change(
entity_id: str, old_state: State | None, new_state: State | None
) -> None:
# Wait for entities to finish cleanup
if new_state is None and entity_id in pending_entities:
pending_entities.remove(entity_id)
if not pending_entities:
wait_for_entities.set()
# Create a set with entities to be removed which are currently in the state
# machine
pending_entities = {
entry.entity_id
for entry in entity_migration_map.values()
if not self.hass.states.async_available(entry.entity_id)
}
wait_for_entities = asyncio.Event()
remove_track_state_changes = async_track_state_change(
self.hass, pending_entities, _handle_state_change
)
for entry in entity_migration_map.values():
entity_registry.async_remove(entry.entity_id)
# Wait for entities to finish cleanup
await wait_for_entities.wait()
remove_track_state_changes()
for entity_id, entry in entity_migration_map.items():
entity_registry.async_update_entity(
entity_id,

View File

@ -335,6 +335,25 @@ class EntityRegistry:
"""Check if an entity_id is currently registered."""
return self.entities.get_entity_id((domain, platform, unique_id))
def _entity_id_available(
self, entity_id: str, known_object_ids: Iterable[str] | None
) -> bool:
"""Return True if the entity_id is available.
An entity_id is available if:
- It's not registered
- It's not known by the entity component adding the entity
- It's not in the state machine
"""
if known_object_ids is None:
known_object_ids = {}
return (
entity_id not in self.entities
and entity_id not in known_object_ids
and self.hass.states.async_available(entity_id)
)
@callback
def async_generate_entity_id(
self,
@ -352,15 +371,11 @@ class EntityRegistry:
raise MaxLengthExceeded(domain, "domain", MAX_LENGTH_STATE_DOMAIN)
test_string = preferred_string[:MAX_LENGTH_STATE_ENTITY_ID]
if not known_object_ids:
if known_object_ids is None:
known_object_ids = {}
tries = 1
while (
test_string in self.entities
or test_string in known_object_ids
or not self.hass.states.async_available(test_string)
):
while not self._entity_id_available(test_string, known_object_ids):
tries += 1
len_suffix = len(str(tries)) + 1
test_string = (
@ -630,7 +645,7 @@ class EntityRegistry:
old_values[attr_name] = getattr(old, attr_name)
if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id):
if not self._entity_id_available(new_entity_id, None):
raise ValueError("Entity with this ID is already registered")
if not valid_entity_id(new_entity_id):

View File

@ -6,7 +6,7 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
from homeassistant.core import CoreState, callback
from homeassistant.core import CoreState, HomeAssistant, callback
from homeassistant.exceptions import MaxLengthExceeded
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.entity import EntityCategory
@ -597,6 +597,56 @@ async def test_update_entity_unique_id_conflict(registry):
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
async def test_update_entity_entity_id(registry):
"""Test entity's entity_id is updated."""
entry = registry.async_get_or_create("light", "hue", "5678")
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
new_entity_id = "light.blah"
assert new_entity_id != entry.entity_id
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
updated_entry = registry.async_update_entity(
entry.entity_id, new_entity_id=new_entity_id
)
assert updated_entry != entry
assert updated_entry.entity_id == new_entity_id
assert mock_schedule_save.call_count == 1
assert registry.async_get(entry.entity_id) is None
assert registry.async_get(new_entity_id) is not None
async def test_update_entity_entity_id_entity_id(hass: HomeAssistant, registry):
"""Test update raises when entity_id already in use."""
entry = registry.async_get_or_create("light", "hue", "5678")
entry2 = registry.async_get_or_create("light", "hue", "1234")
state_entity_id = "light.blah"
hass.states.async_set(state_entity_id, "on")
assert entry.entity_id != state_entity_id
assert entry2.entity_id != state_entity_id
# Try updating to a registered entity_id
with patch.object(
registry, "async_schedule_save"
) as mock_schedule_save, pytest.raises(ValueError):
registry.async_update_entity(entry.entity_id, new_entity_id=entry2.entity_id)
assert mock_schedule_save.call_count == 0
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
assert registry.async_get(entry.entity_id) is entry
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
assert registry.async_get(entry2.entity_id) is entry2
# Try updating to an entity_id which is in the state machine
with patch.object(
registry, "async_schedule_save"
) as mock_schedule_save, pytest.raises(ValueError):
registry.async_update_entity(entry.entity_id, new_entity_id=state_entity_id)
assert mock_schedule_save.call_count == 0
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
assert registry.async_get(entry.entity_id) is entry
assert registry.async_get(state_entity_id) is None
async def test_update_entity(registry):
"""Test updating entity."""
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")