Improve check of new_entity_id in entity_registry.async_update_entity (#78276)
* Improve check of new_entity_id in entity_registry.async_update_entity * Fix race in rfxtrx config flow * Make sure Event is created on time * Rename poorly named variable * Fix typing * Correct typing of _handle_state_changepull/79211/head
parent
69bf77be12
commit
c38b1e7727
|
@ -136,22 +136,18 @@ def websocket_update_entity(hass, connection, msg):
|
|||
|
||||
changes = {}
|
||||
|
||||
for key in ("area_id", "device_class", "disabled_by", "hidden_by", "icon", "name"):
|
||||
for key in (
|
||||
"area_id",
|
||||
"device_class",
|
||||
"disabled_by",
|
||||
"hidden_by",
|
||||
"icon",
|
||||
"name",
|
||||
"new_entity_id",
|
||||
):
|
||||
if key in msg:
|
||||
changes[key] = msg[key]
|
||||
|
||||
if "new_entity_id" in msg and msg["new_entity_id"] != entity_id:
|
||||
changes["new_entity_id"] = msg["new_entity_id"]
|
||||
if hass.states.get(msg["new_entity_id"]) is not None:
|
||||
connection.send_message(
|
||||
websocket_api.error_message(
|
||||
msg["id"],
|
||||
"invalid_info",
|
||||
"Entity with this ID is already registered",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if "disabled_by" in msg and msg["disabled_by"] is None:
|
||||
# Don't allow enabling an entity of a disabled device
|
||||
if entity_entry.device_id:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Config flow for RFXCOM RFXtrx integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
@ -23,12 +24,13 @@ from homeassistant.const import (
|
|||
CONF_PORT,
|
||||
CONF_TYPE,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import State, callback
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
|
||||
from . import (
|
||||
DOMAIN,
|
||||
|
@ -343,9 +345,35 @@ class OptionsFlow(config_entries.OptionsFlow):
|
|||
if new_entity_id is not None:
|
||||
entity_migration_map[new_entity_id] = entry
|
||||
|
||||
@callback
|
||||
def _handle_state_change(
|
||||
entity_id: str, old_state: State | None, new_state: State | None
|
||||
) -> None:
|
||||
# Wait for entities to finish cleanup
|
||||
if new_state is None and entity_id in pending_entities:
|
||||
pending_entities.remove(entity_id)
|
||||
if not pending_entities:
|
||||
wait_for_entities.set()
|
||||
|
||||
# Create a set with entities to be removed which are currently in the state
|
||||
# machine
|
||||
pending_entities = {
|
||||
entry.entity_id
|
||||
for entry in entity_migration_map.values()
|
||||
if not self.hass.states.async_available(entry.entity_id)
|
||||
}
|
||||
wait_for_entities = asyncio.Event()
|
||||
remove_track_state_changes = async_track_state_change(
|
||||
self.hass, pending_entities, _handle_state_change
|
||||
)
|
||||
|
||||
for entry in entity_migration_map.values():
|
||||
entity_registry.async_remove(entry.entity_id)
|
||||
|
||||
# Wait for entities to finish cleanup
|
||||
await wait_for_entities.wait()
|
||||
remove_track_state_changes()
|
||||
|
||||
for entity_id, entry in entity_migration_map.items():
|
||||
entity_registry.async_update_entity(
|
||||
entity_id,
|
||||
|
|
|
@ -335,6 +335,25 @@ class EntityRegistry:
|
|||
"""Check if an entity_id is currently registered."""
|
||||
return self.entities.get_entity_id((domain, platform, unique_id))
|
||||
|
||||
def _entity_id_available(
|
||||
self, entity_id: str, known_object_ids: Iterable[str] | None
|
||||
) -> bool:
|
||||
"""Return True if the entity_id is available.
|
||||
|
||||
An entity_id is available if:
|
||||
- It's not registered
|
||||
- It's not known by the entity component adding the entity
|
||||
- It's not in the state machine
|
||||
"""
|
||||
if known_object_ids is None:
|
||||
known_object_ids = {}
|
||||
|
||||
return (
|
||||
entity_id not in self.entities
|
||||
and entity_id not in known_object_ids
|
||||
and self.hass.states.async_available(entity_id)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_generate_entity_id(
|
||||
self,
|
||||
|
@ -352,15 +371,11 @@ class EntityRegistry:
|
|||
raise MaxLengthExceeded(domain, "domain", MAX_LENGTH_STATE_DOMAIN)
|
||||
|
||||
test_string = preferred_string[:MAX_LENGTH_STATE_ENTITY_ID]
|
||||
if not known_object_ids:
|
||||
if known_object_ids is None:
|
||||
known_object_ids = {}
|
||||
|
||||
tries = 1
|
||||
while (
|
||||
test_string in self.entities
|
||||
or test_string in known_object_ids
|
||||
or not self.hass.states.async_available(test_string)
|
||||
):
|
||||
while not self._entity_id_available(test_string, known_object_ids):
|
||||
tries += 1
|
||||
len_suffix = len(str(tries)) + 1
|
||||
test_string = (
|
||||
|
@ -630,7 +645,7 @@ class EntityRegistry:
|
|||
old_values[attr_name] = getattr(old, attr_name)
|
||||
|
||||
if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
|
||||
if self.async_is_registered(new_entity_id):
|
||||
if not self._entity_id_available(new_entity_id, None):
|
||||
raise ValueError("Entity with this ID is already registered")
|
||||
|
||||
if not valid_entity_id(new_entity_id):
|
||||
|
|
|
@ -6,7 +6,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
||||
from homeassistant.core import CoreState, callback
|
||||
from homeassistant.core import CoreState, HomeAssistant, callback
|
||||
from homeassistant.exceptions import MaxLengthExceeded
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.entity import EntityCategory
|
||||
|
@ -597,6 +597,56 @@ async def test_update_entity_unique_id_conflict(registry):
|
|||
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||
|
||||
|
||||
async def test_update_entity_entity_id(registry):
|
||||
"""Test entity's entity_id is updated."""
|
||||
entry = registry.async_get_or_create("light", "hue", "5678")
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||
|
||||
new_entity_id = "light.blah"
|
||||
assert new_entity_id != entry.entity_id
|
||||
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
|
||||
updated_entry = registry.async_update_entity(
|
||||
entry.entity_id, new_entity_id=new_entity_id
|
||||
)
|
||||
assert updated_entry != entry
|
||||
assert updated_entry.entity_id == new_entity_id
|
||||
assert mock_schedule_save.call_count == 1
|
||||
|
||||
assert registry.async_get(entry.entity_id) is None
|
||||
assert registry.async_get(new_entity_id) is not None
|
||||
|
||||
|
||||
async def test_update_entity_entity_id_entity_id(hass: HomeAssistant, registry):
|
||||
"""Test update raises when entity_id already in use."""
|
||||
entry = registry.async_get_or_create("light", "hue", "5678")
|
||||
entry2 = registry.async_get_or_create("light", "hue", "1234")
|
||||
state_entity_id = "light.blah"
|
||||
hass.states.async_set(state_entity_id, "on")
|
||||
assert entry.entity_id != state_entity_id
|
||||
assert entry2.entity_id != state_entity_id
|
||||
|
||||
# Try updating to a registered entity_id
|
||||
with patch.object(
|
||||
registry, "async_schedule_save"
|
||||
) as mock_schedule_save, pytest.raises(ValueError):
|
||||
registry.async_update_entity(entry.entity_id, new_entity_id=entry2.entity_id)
|
||||
assert mock_schedule_save.call_count == 0
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||
assert registry.async_get(entry.entity_id) is entry
|
||||
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||
assert registry.async_get(entry2.entity_id) is entry2
|
||||
|
||||
# Try updating to an entity_id which is in the state machine
|
||||
with patch.object(
|
||||
registry, "async_schedule_save"
|
||||
) as mock_schedule_save, pytest.raises(ValueError):
|
||||
registry.async_update_entity(entry.entity_id, new_entity_id=state_entity_id)
|
||||
assert mock_schedule_save.call_count == 0
|
||||
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||
assert registry.async_get(entry.entity_id) is entry
|
||||
assert registry.async_get(state_entity_id) is None
|
||||
|
||||
|
||||
async def test_update_entity(registry):
|
||||
"""Test updating entity."""
|
||||
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")
|
||||
|
|
Loading…
Reference in New Issue