Simplify WS command entity/source (#99439)

pull/100207/head
Erik Montnemery 2023-09-12 15:39:11 +02:00 committed by GitHub
parent e143bdf2f5
commit fabb098ec3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 120 deletions

View File

@ -11,7 +11,7 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.auth.models import User
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.const import (
EVENT_STATE_CHANGED,
MATCH_ALL,
@ -52,7 +52,6 @@ from homeassistant.util.json import format_unserializable_data
from . import const, decorators, messages
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND
from .messages import construct_event_message, construct_result_message
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
@ -596,47 +595,35 @@ async def handle_render_template(
hass.loop.call_soon_threadsafe(info.async_refresh)
def _serialize_entity_sources(
entity_infos: dict[str, dict[str, str]]
) -> dict[str, Any]:
"""Prepare a websocket response from a dict of entity sources."""
result = {}
for entity_id, entity_info in entity_infos.items():
result[entity_id] = {"domain": entity_info["domain"]}
return result
@callback
@decorators.websocket_command(
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
)
@decorators.websocket_command({vol.Required("type"): "entity/source"})
def handle_entity_source(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle entity source command."""
raw_sources = entity.entity_sources(hass)
all_entity_sources = entity.entity_sources(hass)
entity_perm = connection.user.permissions.check_entity
if "entity_id" not in msg:
if connection.user.permissions.access_all_entities(POLICY_READ):
sources = raw_sources
else:
sources = {
entity_id: source
for entity_id, source in raw_sources.items()
if entity_perm(entity_id, POLICY_READ)
}
if connection.user.permissions.access_all_entities(POLICY_READ):
entity_sources = all_entity_sources
else:
entity_sources = {
entity_id: source
for entity_id, source in all_entity_sources.items()
if entity_perm(entity_id, POLICY_READ)
}
connection.send_result(msg["id"], sources)
return
sources = {}
for entity_id in msg["entity_id"]:
if not entity_perm(entity_id, POLICY_READ):
raise Unauthorized(
context=connection.context(msg),
permission=POLICY_READ,
perm_category=CAT_ENTITIES,
)
if (source := raw_sources.get(entity_id)) is None:
connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found")
return
sources[entity_id] = source
connection.send_result(msg["id"], sources)
connection.send_result(msg["id"], _serialize_entity_sources(entity_sources))
@decorators.websocket_command(

View File

@ -20,7 +20,7 @@ from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAG
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS
from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.loader import async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
@ -1941,76 +1941,10 @@ async def test_entity_source_admin(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_1": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
"test_domain.entity_2": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
"test_domain.entity_1": {"domain": "test_platform"},
"test_domain.entity_2": {"domain": "test_platform"},
}
# Fetch one
await websocket_client.send_json(
{"id": 7, "type": "entity/source", "entity_id": ["test_domain.entity_2"]}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_2": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
}
# Fetch two
await websocket_client.send_json(
{
"id": 8,
"type": "entity/source",
"entity_id": ["test_domain.entity_2", "test_domain.entity_1"],
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 8
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_1": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
"test_domain.entity_2": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
}
# Fetch non existing
await websocket_client.send_json(
{
"id": 9,
"type": "entity/source",
"entity_id": ["test_domain.entity_2", "test_domain.non_existing"],
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 9
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_NOT_FOUND
# Mock policy
hass_admin_user.groups = []
hass_admin_user.mock_policy(
@ -2025,24 +1959,9 @@ async def test_entity_source_admin(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_2": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
"test_domain.entity_2": {"domain": "test_platform"},
}
# Fetch unauthorized
await websocket_client.send_json(
{"id": 11, "type": "entity/source", "entity_id": ["test_domain.entity_1"]}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 11
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_UNAUTHORIZED
async def test_subscribe_trigger(hass: HomeAssistant, websocket_client) -> None:
"""Test subscribing to a trigger."""