Subscribe to device registry changes from entities (#93601)
* Subscribe to device registry changes from entities * Use async_track_device_registry_updated_event * Fix unsubscribe * Fix logic, add testspull/93893/head
parent
204215e0f2
commit
59c6220b7c
|
@ -40,7 +40,10 @@ from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
|||
|
||||
from . import device_registry as dr, entity_registry as er
|
||||
from .device_registry import DeviceEntryType
|
||||
from .event import async_track_entity_registry_updated_event
|
||||
from .event import (
|
||||
async_track_device_registry_updated_event,
|
||||
async_track_entity_registry_updated_event,
|
||||
)
|
||||
from .typing import StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -265,6 +268,8 @@ class Entity(ABC):
|
|||
# Hold list for functions to call on remove.
|
||||
_on_remove: list[CALLBACK_TYPE] | None = None
|
||||
|
||||
_unsub_device_updates: CALLBACK_TYPE | None = None
|
||||
|
||||
# Context
|
||||
_context: Context | None = None
|
||||
_context_set: datetime | None = None
|
||||
|
@ -926,6 +931,7 @@ class Entity(ABC):
|
|||
self.hass, self.entity_id, self._async_registry_updated
|
||||
)
|
||||
)
|
||||
self._async_subscribe_device_updates()
|
||||
|
||||
async def async_internal_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass.
|
||||
|
@ -946,6 +952,9 @@ class Entity(ABC):
|
|||
if data["action"] != "update":
|
||||
return
|
||||
|
||||
if "device_id" in data["changes"]:
|
||||
self._async_subscribe_device_updates()
|
||||
|
||||
ent_reg = er.async_get(self.hass)
|
||||
old = self.registry_entry
|
||||
self.registry_entry = ent_reg.async_get(data["entity_id"])
|
||||
|
@ -967,6 +976,51 @@ class Entity(ABC):
|
|||
self.entity_id = self.registry_entry.entity_id
|
||||
await self.platform.async_add_entities([self])
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe_device_updates(self) -> None:
|
||||
"""Unsubscribe from device registry updates."""
|
||||
if not self._unsub_device_updates:
|
||||
return
|
||||
self._unsub_device_updates()
|
||||
self._unsub_device_updates = None
|
||||
|
||||
@callback
|
||||
def _async_subscribe_device_updates(self) -> None:
|
||||
"""Subscribe to device registry updates."""
|
||||
assert self.registry_entry
|
||||
|
||||
self._async_unsubscribe_device_updates()
|
||||
|
||||
if (device_id := self.registry_entry.device_id) is None:
|
||||
return
|
||||
|
||||
if not self.has_entity_name:
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_device_registry_updated(event: Event) -> None:
|
||||
"""Handle device registry update."""
|
||||
data = event.data
|
||||
|
||||
if data["action"] != "update":
|
||||
return
|
||||
|
||||
if "name" not in data["changes"] and "name_by_user" not in data["changes"]:
|
||||
return
|
||||
|
||||
self.async_write_ha_state()
|
||||
|
||||
self._unsub_device_updates = async_track_device_registry_updated_event(
|
||||
self.hass,
|
||||
device_id,
|
||||
async_device_registry_updated,
|
||||
)
|
||||
if (
|
||||
not self._on_remove
|
||||
or self._async_unsubscribe_device_updates not in self._on_remove
|
||||
):
|
||||
self.async_on_remove(self._async_unsubscribe_device_updates)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation."""
|
||||
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"
|
||||
|
|
|
@ -988,6 +988,89 @@ async def test_friendly_name(
|
|||
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"entity_name",
|
||||
"expected_friendly_name1",
|
||||
"expected_friendly_name2",
|
||||
"expected_friendly_name3",
|
||||
),
|
||||
(
|
||||
(
|
||||
"Entity Blu",
|
||||
"Device Bla Entity Blu",
|
||||
"Device Bla2 Entity Blu",
|
||||
"New Device Entity Blu",
|
||||
),
|
||||
(
|
||||
None,
|
||||
"Device Bla",
|
||||
"Device Bla2",
|
||||
"New Device",
|
||||
),
|
||||
),
|
||||
)
|
||||
async def test_friendly_name_updated(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
entity_name: str | None,
|
||||
expected_friendly_name1: str,
|
||||
expected_friendly_name2: str,
|
||||
expected_friendly_name3: str,
|
||||
) -> None:
|
||||
"""Test entity_id is influenced by entity name."""
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Mock setup entry method."""
|
||||
async_add_entities(
|
||||
[
|
||||
MockEntity(
|
||||
unique_id="qwer",
|
||||
device_info={
|
||||
"identifiers": {("hue", "1234")},
|
||||
"connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
|
||||
"name": "Device Bla",
|
||||
},
|
||||
has_entity_name=True,
|
||||
name=entity_name,
|
||||
),
|
||||
]
|
||||
)
|
||||
return True
|
||||
|
||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
config_entry = MockConfigEntry(entry_id="super-mock-id")
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, platform_name=config_entry.domain, platform=platform
|
||||
)
|
||||
|
||||
assert await entity_platform.async_setup_entry(config_entry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass.states.async_entity_ids()) == 1
|
||||
state = hass.states.async_all()[0]
|
||||
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name1
|
||||
|
||||
device = device_registry.async_get_device(identifiers={("hue", "1234")})
|
||||
device_registry.async_update_device(device.id, name_by_user="Device Bla2")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.async_all()[0]
|
||||
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name2
|
||||
|
||||
device = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("hue", "5678")},
|
||||
name="New Device",
|
||||
)
|
||||
entity_registry.async_update_entity(state.entity_id, device_id=device.id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.async_all()[0]
|
||||
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name3
|
||||
|
||||
|
||||
async def test_translation_key(hass: HomeAssistant) -> None:
|
||||
"""Test translation key property."""
|
||||
mock_entity1 = entity.Entity()
|
||||
|
|
Loading…
Reference in New Issue