diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 5f484c10472..445ca96c8b0 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -1,11 +1,18 @@ """HTTP views to interact with the entity registry.""" +from __future__ import annotations + import voluptuous as vol from homeassistant import config_entries from homeassistant.components import websocket_api from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.components.websocket_api.decorators import require_admin -from homeassistant.core import callback +from homeassistant.components.websocket_api.messages import ( + IDEN_JSON_TEMPLATE, + IDEN_TEMPLATE, + message_to_json, +) +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -13,8 +20,40 @@ from homeassistant.helpers import ( ) -async def async_setup(hass): +async def async_setup(hass: HomeAssistant) -> bool: """Enable the Entity Registry views.""" + + cached_list_entities: str | None = None + + @callback + def _async_clear_list_entities_cache(event: Event) -> None: + nonlocal cached_list_entities + cached_list_entities = None + + @websocket_api.websocket_command( + {vol.Required("type"): "config/entity_registry/list"} + ) + @callback + def websocket_list_entities(hass, connection, msg): + """Handle list registry entries command.""" + nonlocal cached_list_entities + if not cached_list_entities: + registry = er.async_get(hass) + cached_list_entities = message_to_json( + websocket_api.result_message( + IDEN_TEMPLATE, + [_entry_dict(entry) for entry in registry.entities.values()], + ) + ) + connection.send_message( + cached_list_entities.replace(IDEN_JSON_TEMPLATE, str(msg["id"]), 1) + ) + + hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, + _async_clear_list_entities_cache, + run_immediately=True, + ) websocket_api.async_register_command(hass, websocket_list_entities) websocket_api.async_register_command(hass, websocket_get_entity) websocket_api.async_register_command(hass, websocket_update_entity) @@ -22,33 +61,6 @@ async def async_setup(hass): return True -@websocket_api.websocket_command({vol.Required("type"): "config/entity_registry/list"}) -@callback -def websocket_list_entities(hass, connection, msg): - """Handle list registry entries command.""" - registry = er.async_get(hass) - connection.send_message( - websocket_api.result_message( - msg["id"], - [ - { - "area_id": entry.area_id, - "config_entry_id": entry.config_entry_id, - "device_id": entry.device_id, - "disabled_by": entry.disabled_by, - "entity_category": entry.entity_category, - "entity_id": entry.entity_id, - "hidden_by": entry.hidden_by, - "icon": entry.icon, - "name": entry.name, - "platform": entry.platform, - } - for entry in registry.entities.values() - ], - ) - ) - - @websocket_api.websocket_command( { vol.Required("type"): "config/entity_registry/get", @@ -211,7 +223,7 @@ def websocket_remove_entity(hass, connection, msg): @callback -def _entry_ext_dict(entry): +def _entry_dict(entry): """Convert entry to API format.""" return { "area_id": entry.area_id, @@ -224,12 +236,19 @@ def _entry_ext_dict(entry): "icon": entry.icon, "name": entry.name, "platform": entry.platform, - "capabilities": entry.capabilities, - "device_class": entry.device_class, - "has_entity_name": entry.has_entity_name, - "options": entry.options, - "original_device_class": entry.original_device_class, - "original_icon": entry.original_icon, - "original_name": entry.original_name, - "unique_id": entry.unique_id, } + + +@callback +def _entry_ext_dict(entry): + """Convert entry to API format.""" + data = _entry_dict(entry) + data["capabilities"] = entry.capabilities + data["device_class"] = entry.device_class + data["has_entity_name"] = entry.has_entity_name + data["options"] = entry.options + data["original_device_class"] = entry.original_device_class + data["original_icon"] = entry.original_icon + data["original_name"] = entry.original_name + data["unique_id"] = entry.unique_id + return data diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 69744817a27..9c5984f751e 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -5,6 +5,7 @@ from homeassistant.components.config import entity_registry from homeassistant.const import ATTR_ICON from homeassistant.helpers.device_registry import DeviceEntryDisabler from homeassistant.helpers.entity_registry import ( + EVENT_ENTITY_REGISTRY_UPDATED, RegistryEntry, RegistryEntryDisabler, RegistryEntryHider, @@ -81,6 +82,40 @@ async def test_list_entities(hass, client): }, ] + mock_registry( + hass, + { + "test_domain.name": RegistryEntry( + entity_id="test_domain.name", + unique_id="1234", + platform="test_platform", + name="Hello World", + ), + }, + ) + + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, + {"action": "create", "entity_id": "test_domain.no_name"}, + ) + await client.send_json({"id": 6, "type": "config/entity_registry/list"}) + msg = await client.receive_json() + + assert msg["result"] == [ + { + "config_entry_id": None, + "device_id": None, + "area_id": None, + "disabled_by": None, + "entity_id": "test_domain.name", + "hidden_by": None, + "name": "Hello World", + "icon": None, + "platform": "test_platform", + "entity_category": None, + }, + ] + async def test_get_entity(hass, client): """Test get entry."""