diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index ea21b7b5eba..66866197081 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -11,7 +11,7 @@ from typing import Any, cast import voluptuous as vol from homeassistant.auth.models import User -from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ +from homeassistant.auth.permissions.const import POLICY_READ from homeassistant.const import ( EVENT_STATE_CHANGED, MATCH_ALL, @@ -52,7 +52,6 @@ from homeassistant.util.json import format_unserializable_data from . import const, decorators, messages from .connection import ActiveConnection -from .const import ERR_NOT_FOUND from .messages import construct_event_message, construct_result_message ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json" @@ -596,47 +595,35 @@ async def handle_render_template( hass.loop.call_soon_threadsafe(info.async_refresh) +def _serialize_entity_sources( + entity_infos: dict[str, dict[str, str]] +) -> dict[str, Any]: + """Prepare a websocket response from a dict of entity sources.""" + result = {} + for entity_id, entity_info in entity_infos.items(): + result[entity_id] = {"domain": entity_info["domain"]} + return result + + @callback -@decorators.websocket_command( - {vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]} -) +@decorators.websocket_command({vol.Required("type"): "entity/source"}) def handle_entity_source( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle entity source command.""" - raw_sources = entity.entity_sources(hass) + all_entity_sources = entity.entity_sources(hass) entity_perm = connection.user.permissions.check_entity - if "entity_id" not in msg: - if connection.user.permissions.access_all_entities(POLICY_READ): - sources = raw_sources - else: - sources = { - entity_id: source - for entity_id, source in raw_sources.items() - if entity_perm(entity_id, POLICY_READ) - } + if connection.user.permissions.access_all_entities(POLICY_READ): + entity_sources = all_entity_sources + else: + entity_sources = { + entity_id: source + for entity_id, source in all_entity_sources.items() + if entity_perm(entity_id, POLICY_READ) + } - connection.send_result(msg["id"], sources) - return - - sources = {} - - for entity_id in msg["entity_id"]: - if not entity_perm(entity_id, POLICY_READ): - raise Unauthorized( - context=connection.context(msg), - permission=POLICY_READ, - perm_category=CAT_ENTITIES, - ) - - if (source := raw_sources.get(entity_id)) is None: - connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found") - return - - sources[entity_id] = source - - connection.send_result(msg["id"], sources) + connection.send_result(msg["id"], _serialize_entity_sources(entity_sources)) @decorators.websocket_command( diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index b1b2027c65d..8cd5e23ce29 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -20,7 +20,7 @@ from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAG from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr, entity +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.loader import async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_setup_component @@ -1941,76 +1941,10 @@ async def test_entity_source_admin( assert msg["type"] == const.TYPE_RESULT assert msg["success"] assert msg["result"] == { - "test_domain.entity_1": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, - "test_domain.entity_2": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, + "test_domain.entity_1": {"domain": "test_platform"}, + "test_domain.entity_2": {"domain": "test_platform"}, } - # Fetch one - await websocket_client.send_json( - {"id": 7, "type": "entity/source", "entity_id": ["test_domain.entity_2"]} - ) - - msg = await websocket_client.receive_json() - assert msg["id"] == 7 - assert msg["type"] == const.TYPE_RESULT - assert msg["success"] - assert msg["result"] == { - "test_domain.entity_2": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, - } - - # Fetch two - await websocket_client.send_json( - { - "id": 8, - "type": "entity/source", - "entity_id": ["test_domain.entity_2", "test_domain.entity_1"], - } - ) - - msg = await websocket_client.receive_json() - assert msg["id"] == 8 - assert msg["type"] == const.TYPE_RESULT - assert msg["success"] - assert msg["result"] == { - "test_domain.entity_1": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, - "test_domain.entity_2": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, - } - - # Fetch non existing - await websocket_client.send_json( - { - "id": 9, - "type": "entity/source", - "entity_id": ["test_domain.entity_2", "test_domain.non_existing"], - } - ) - - msg = await websocket_client.receive_json() - assert msg["id"] == 9 - assert msg["type"] == const.TYPE_RESULT - assert not msg["success"] - assert msg["error"]["code"] == const.ERR_NOT_FOUND - # Mock policy hass_admin_user.groups = [] hass_admin_user.mock_policy( @@ -2025,24 +1959,9 @@ async def test_entity_source_admin( assert msg["type"] == const.TYPE_RESULT assert msg["success"] assert msg["result"] == { - "test_domain.entity_2": { - "custom_component": False, - "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, - }, + "test_domain.entity_2": {"domain": "test_platform"}, } - # Fetch unauthorized - await websocket_client.send_json( - {"id": 11, "type": "entity/source", "entity_id": ["test_domain.entity_1"]} - ) - - msg = await websocket_client.receive_json() - assert msg["id"] == 11 - assert msg["type"] == const.TYPE_RESULT - assert not msg["success"] - assert msg["error"]["code"] == const.ERR_UNAUTHORIZED - async def test_subscribe_trigger(hass: HomeAssistant, websocket_client) -> None: """Test subscribing to a trigger."""