Improve check of new_entity_id in entity_registry.async_update_entity (#78276)
* Improve check of new_entity_id in entity_registry.async_update_entity * Fix race in rfxtrx config flow * Make sure Event is created on time * Rename poorly named variable * Fix typing * Correct typing of _handle_state_changepull/79211/head
parent
69bf77be12
commit
c38b1e7727
|
@ -136,22 +136,18 @@ def websocket_update_entity(hass, connection, msg):
|
||||||
|
|
||||||
changes = {}
|
changes = {}
|
||||||
|
|
||||||
for key in ("area_id", "device_class", "disabled_by", "hidden_by", "icon", "name"):
|
for key in (
|
||||||
|
"area_id",
|
||||||
|
"device_class",
|
||||||
|
"disabled_by",
|
||||||
|
"hidden_by",
|
||||||
|
"icon",
|
||||||
|
"name",
|
||||||
|
"new_entity_id",
|
||||||
|
):
|
||||||
if key in msg:
|
if key in msg:
|
||||||
changes[key] = msg[key]
|
changes[key] = msg[key]
|
||||||
|
|
||||||
if "new_entity_id" in msg and msg["new_entity_id"] != entity_id:
|
|
||||||
changes["new_entity_id"] = msg["new_entity_id"]
|
|
||||||
if hass.states.get(msg["new_entity_id"]) is not None:
|
|
||||||
connection.send_message(
|
|
||||||
websocket_api.error_message(
|
|
||||||
msg["id"],
|
|
||||||
"invalid_info",
|
|
||||||
"Entity with this ID is already registered",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if "disabled_by" in msg and msg["disabled_by"] is None:
|
if "disabled_by" in msg and msg["disabled_by"] is None:
|
||||||
# Don't allow enabling an entity of a disabled device
|
# Don't allow enabling an entity of a disabled device
|
||||||
if entity_entry.device_id:
|
if entity_entry.device_id:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Config flow for RFXCOM RFXtrx integration."""
|
"""Config flow for RFXCOM RFXtrx integration."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
@ -23,12 +24,13 @@ from homeassistant.const import (
|
||||||
CONF_PORT,
|
CONF_PORT,
|
||||||
CONF_TYPE,
|
CONF_TYPE,
|
||||||
)
|
)
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import State, callback
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.event import async_track_state_change
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
@ -343,9 +345,35 @@ class OptionsFlow(config_entries.OptionsFlow):
|
||||||
if new_entity_id is not None:
|
if new_entity_id is not None:
|
||||||
entity_migration_map[new_entity_id] = entry
|
entity_migration_map[new_entity_id] = entry
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _handle_state_change(
|
||||||
|
entity_id: str, old_state: State | None, new_state: State | None
|
||||||
|
) -> None:
|
||||||
|
# Wait for entities to finish cleanup
|
||||||
|
if new_state is None and entity_id in pending_entities:
|
||||||
|
pending_entities.remove(entity_id)
|
||||||
|
if not pending_entities:
|
||||||
|
wait_for_entities.set()
|
||||||
|
|
||||||
|
# Create a set with entities to be removed which are currently in the state
|
||||||
|
# machine
|
||||||
|
pending_entities = {
|
||||||
|
entry.entity_id
|
||||||
|
for entry in entity_migration_map.values()
|
||||||
|
if not self.hass.states.async_available(entry.entity_id)
|
||||||
|
}
|
||||||
|
wait_for_entities = asyncio.Event()
|
||||||
|
remove_track_state_changes = async_track_state_change(
|
||||||
|
self.hass, pending_entities, _handle_state_change
|
||||||
|
)
|
||||||
|
|
||||||
for entry in entity_migration_map.values():
|
for entry in entity_migration_map.values():
|
||||||
entity_registry.async_remove(entry.entity_id)
|
entity_registry.async_remove(entry.entity_id)
|
||||||
|
|
||||||
|
# Wait for entities to finish cleanup
|
||||||
|
await wait_for_entities.wait()
|
||||||
|
remove_track_state_changes()
|
||||||
|
|
||||||
for entity_id, entry in entity_migration_map.items():
|
for entity_id, entry in entity_migration_map.items():
|
||||||
entity_registry.async_update_entity(
|
entity_registry.async_update_entity(
|
||||||
entity_id,
|
entity_id,
|
||||||
|
|
|
@ -335,6 +335,25 @@ class EntityRegistry:
|
||||||
"""Check if an entity_id is currently registered."""
|
"""Check if an entity_id is currently registered."""
|
||||||
return self.entities.get_entity_id((domain, platform, unique_id))
|
return self.entities.get_entity_id((domain, platform, unique_id))
|
||||||
|
|
||||||
|
def _entity_id_available(
|
||||||
|
self, entity_id: str, known_object_ids: Iterable[str] | None
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if the entity_id is available.
|
||||||
|
|
||||||
|
An entity_id is available if:
|
||||||
|
- It's not registered
|
||||||
|
- It's not known by the entity component adding the entity
|
||||||
|
- It's not in the state machine
|
||||||
|
"""
|
||||||
|
if known_object_ids is None:
|
||||||
|
known_object_ids = {}
|
||||||
|
|
||||||
|
return (
|
||||||
|
entity_id not in self.entities
|
||||||
|
and entity_id not in known_object_ids
|
||||||
|
and self.hass.states.async_available(entity_id)
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_generate_entity_id(
|
def async_generate_entity_id(
|
||||||
self,
|
self,
|
||||||
|
@ -352,15 +371,11 @@ class EntityRegistry:
|
||||||
raise MaxLengthExceeded(domain, "domain", MAX_LENGTH_STATE_DOMAIN)
|
raise MaxLengthExceeded(domain, "domain", MAX_LENGTH_STATE_DOMAIN)
|
||||||
|
|
||||||
test_string = preferred_string[:MAX_LENGTH_STATE_ENTITY_ID]
|
test_string = preferred_string[:MAX_LENGTH_STATE_ENTITY_ID]
|
||||||
if not known_object_ids:
|
if known_object_ids is None:
|
||||||
known_object_ids = {}
|
known_object_ids = {}
|
||||||
|
|
||||||
tries = 1
|
tries = 1
|
||||||
while (
|
while not self._entity_id_available(test_string, known_object_ids):
|
||||||
test_string in self.entities
|
|
||||||
or test_string in known_object_ids
|
|
||||||
or not self.hass.states.async_available(test_string)
|
|
||||||
):
|
|
||||||
tries += 1
|
tries += 1
|
||||||
len_suffix = len(str(tries)) + 1
|
len_suffix = len(str(tries)) + 1
|
||||||
test_string = (
|
test_string = (
|
||||||
|
@ -630,7 +645,7 @@ class EntityRegistry:
|
||||||
old_values[attr_name] = getattr(old, attr_name)
|
old_values[attr_name] = getattr(old, attr_name)
|
||||||
|
|
||||||
if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
|
if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
|
||||||
if self.async_is_registered(new_entity_id):
|
if not self._entity_id_available(new_entity_id, None):
|
||||||
raise ValueError("Entity with this ID is already registered")
|
raise ValueError("Entity with this ID is already registered")
|
||||||
|
|
||||||
if not valid_entity_id(new_entity_id):
|
if not valid_entity_id(new_entity_id):
|
||||||
|
|
|
@ -6,7 +6,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
||||||
from homeassistant.core import CoreState, callback
|
from homeassistant.core import CoreState, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import MaxLengthExceeded
|
from homeassistant.exceptions import MaxLengthExceeded
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
from homeassistant.helpers.entity import EntityCategory
|
from homeassistant.helpers.entity import EntityCategory
|
||||||
|
@ -597,6 +597,56 @@ async def test_update_entity_unique_id_conflict(registry):
|
||||||
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_entity_entity_id(registry):
|
||||||
|
"""Test entity's entity_id is updated."""
|
||||||
|
entry = registry.async_get_or_create("light", "hue", "5678")
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||||
|
|
||||||
|
new_entity_id = "light.blah"
|
||||||
|
assert new_entity_id != entry.entity_id
|
||||||
|
with patch.object(registry, "async_schedule_save") as mock_schedule_save:
|
||||||
|
updated_entry = registry.async_update_entity(
|
||||||
|
entry.entity_id, new_entity_id=new_entity_id
|
||||||
|
)
|
||||||
|
assert updated_entry != entry
|
||||||
|
assert updated_entry.entity_id == new_entity_id
|
||||||
|
assert mock_schedule_save.call_count == 1
|
||||||
|
|
||||||
|
assert registry.async_get(entry.entity_id) is None
|
||||||
|
assert registry.async_get(new_entity_id) is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_entity_entity_id_entity_id(hass: HomeAssistant, registry):
|
||||||
|
"""Test update raises when entity_id already in use."""
|
||||||
|
entry = registry.async_get_or_create("light", "hue", "5678")
|
||||||
|
entry2 = registry.async_get_or_create("light", "hue", "1234")
|
||||||
|
state_entity_id = "light.blah"
|
||||||
|
hass.states.async_set(state_entity_id, "on")
|
||||||
|
assert entry.entity_id != state_entity_id
|
||||||
|
assert entry2.entity_id != state_entity_id
|
||||||
|
|
||||||
|
# Try updating to a registered entity_id
|
||||||
|
with patch.object(
|
||||||
|
registry, "async_schedule_save"
|
||||||
|
) as mock_schedule_save, pytest.raises(ValueError):
|
||||||
|
registry.async_update_entity(entry.entity_id, new_entity_id=entry2.entity_id)
|
||||||
|
assert mock_schedule_save.call_count == 0
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||||
|
assert registry.async_get(entry.entity_id) is entry
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "1234") == entry2.entity_id
|
||||||
|
assert registry.async_get(entry2.entity_id) is entry2
|
||||||
|
|
||||||
|
# Try updating to an entity_id which is in the state machine
|
||||||
|
with patch.object(
|
||||||
|
registry, "async_schedule_save"
|
||||||
|
) as mock_schedule_save, pytest.raises(ValueError):
|
||||||
|
registry.async_update_entity(entry.entity_id, new_entity_id=state_entity_id)
|
||||||
|
assert mock_schedule_save.call_count == 0
|
||||||
|
assert registry.async_get_entity_id("light", "hue", "5678") == entry.entity_id
|
||||||
|
assert registry.async_get(entry.entity_id) is entry
|
||||||
|
assert registry.async_get(state_entity_id) is None
|
||||||
|
|
||||||
|
|
||||||
async def test_update_entity(registry):
|
async def test_update_entity(registry):
|
||||||
"""Test updating entity."""
|
"""Test updating entity."""
|
||||||
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")
|
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")
|
||||||
|
|
Loading…
Reference in New Issue