Subscribe to device registry changes from entities (#93601)

* Subscribe to device registry changes from entities

* Use async_track_device_registry_updated_event

* Fix unsubscribe

* Fix logic, add tests
pull/93893/head
Erik Montnemery 2023-05-31 11:01:55 +02:00 committed by GitHub
parent 204215e0f2
commit 59c6220b7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 138 additions and 1 deletions

View File

@ -40,7 +40,10 @@ from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from . import device_registry as dr, entity_registry as er
from .device_registry import DeviceEntryType
from .event import async_track_entity_registry_updated_event
from .event import (
async_track_device_registry_updated_event,
async_track_entity_registry_updated_event,
)
from .typing import StateType
if TYPE_CHECKING:
@ -265,6 +268,8 @@ class Entity(ABC):
# Hold list for functions to call on remove.
_on_remove: list[CALLBACK_TYPE] | None = None
_unsub_device_updates: CALLBACK_TYPE | None = None
# Context
_context: Context | None = None
_context_set: datetime | None = None
@ -926,6 +931,7 @@ class Entity(ABC):
self.hass, self.entity_id, self._async_registry_updated
)
)
self._async_subscribe_device_updates()
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.
@ -946,6 +952,9 @@ class Entity(ABC):
if data["action"] != "update":
return
if "device_id" in data["changes"]:
self._async_subscribe_device_updates()
ent_reg = er.async_get(self.hass)
old = self.registry_entry
self.registry_entry = ent_reg.async_get(data["entity_id"])
@ -967,6 +976,51 @@ class Entity(ABC):
self.entity_id = self.registry_entry.entity_id
await self.platform.async_add_entities([self])
@callback
def _async_unsubscribe_device_updates(self) -> None:
"""Unsubscribe from device registry updates."""
if not self._unsub_device_updates:
return
self._unsub_device_updates()
self._unsub_device_updates = None
@callback
def _async_subscribe_device_updates(self) -> None:
"""Subscribe to device registry updates."""
assert self.registry_entry
self._async_unsubscribe_device_updates()
if (device_id := self.registry_entry.device_id) is None:
return
if not self.has_entity_name:
return
@callback
def async_device_registry_updated(event: Event) -> None:
"""Handle device registry update."""
data = event.data
if data["action"] != "update":
return
if "name" not in data["changes"] and "name_by_user" not in data["changes"]:
return
self.async_write_ha_state()
self._unsub_device_updates = async_track_device_registry_updated_event(
self.hass,
device_id,
async_device_registry_updated,
)
if (
not self._on_remove
or self._async_unsubscribe_device_updates not in self._on_remove
):
self.async_on_remove(self._async_unsubscribe_device_updates)
def __repr__(self) -> str:
"""Return the representation."""
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"

View File

@ -988,6 +988,89 @@ async def test_friendly_name(
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name
@pytest.mark.parametrize(
(
"entity_name",
"expected_friendly_name1",
"expected_friendly_name2",
"expected_friendly_name3",
),
(
(
"Entity Blu",
"Device Bla Entity Blu",
"Device Bla2 Entity Blu",
"New Device Entity Blu",
),
(
None,
"Device Bla",
"Device Bla2",
"New Device",
),
),
)
async def test_friendly_name_updated(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
entity_name: str | None,
expected_friendly_name1: str,
expected_friendly_name2: str,
expected_friendly_name3: str,
) -> None:
"""Test entity_id is influenced by entity name."""
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities(
[
MockEntity(
unique_id="qwer",
device_info={
"identifiers": {("hue", "1234")},
"connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
"name": "Device Bla",
},
has_entity_name=True,
name=entity_name,
),
]
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(entry_id="super-mock-id")
entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 1
state = hass.states.async_all()[0]
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name1
device = device_registry.async_get_device(identifiers={("hue", "1234")})
device_registry.async_update_device(device.id, name_by_user="Device Bla2")
await hass.async_block_till_done()
state = hass.states.async_all()[0]
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name2
device = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("hue", "5678")},
name="New Device",
)
entity_registry.async_update_entity(state.entity_id, device_id=device.id)
await hass.async_block_till_done()
state = hass.states.async_all()[0]
assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name3
async def test_translation_key(hass: HomeAssistant) -> None:
"""Test translation key property."""
mock_entity1 = entity.Entity()