From e1b7d681349414df21df73a3c494a0b22d082eb1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 May 2023 18:59:46 -0500 Subject: [PATCH] Speed up processing subscribe_events and subscribe_entites when user has read all permissions (#93611) Speed up processing subscribe_events and subscribe_entites when user the read all permissions --- .../components/websocket_api/commands.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index b5dabf8b733..bdb087069f8 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -109,15 +109,18 @@ def handle_subscribe_events( raise Unauthorized if event_type == EVENT_STATE_CHANGED: + user = connection.user @callback def forward_events(event: Event) -> None: """Forward state changed events to websocket.""" - if not connection.user.permissions.check_entity( - event.data["entity_id"], POLICY_READ - ): + # We have to lookup the permissions again because the user might have + # changed since the subscription was created. + permissions = user.permissions + if not permissions.access_all_entities( + POLICY_READ + ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): return - connection.send_message(messages.cached_event_message(msg["id"], event)) else: @@ -227,13 +230,13 @@ async def handle_call_service( def _async_get_allowed_states( hass: HomeAssistant, connection: ActiveConnection ) -> list[State]: - if connection.user.permissions.access_all_entities("read"): + if connection.user.permissions.access_all_entities(POLICY_READ): return hass.states.async_all() entity_perm = connection.user.permissions.check_entity return [ state for state in hass.states.async_all() - if entity_perm(state.entity_id, "read") + if entity_perm(state.entity_id, POLICY_READ) ] @@ -289,17 +292,21 @@ def handle_subscribe_entities( ) -> None: """Handle subscribe entities command.""" entity_ids = set(msg.get("entity_ids", [])) + user = connection.user @callback def forward_entity_changes(event: Event) -> None: """Forward entity state changed events to websocket.""" - if not connection.user.permissions.check_entity( - event.data["entity_id"], POLICY_READ - ): + entity_id = event.data["entity_id"] + if entity_ids and entity_id not in entity_ids: return - if entity_ids and event.data["entity_id"] not in entity_ids: + # We have to lookup the permissions again because the user might have + # changed since the subscription was created. + permissions = user.permissions + if not permissions.access_all_entities( + POLICY_READ + ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): return - connection.send_message(messages.cached_state_diff_message(msg["id"], event)) # We must never await between sending the states and listening for @@ -541,13 +548,13 @@ def handle_entity_source( entity_perm = connection.user.permissions.check_entity if "entity_id" not in msg: - if connection.user.permissions.access_all_entities("read"): + if connection.user.permissions.access_all_entities(POLICY_READ): sources = raw_sources else: sources = { entity_id: source for entity_id, source in raw_sources.items() - if entity_perm(entity_id, "read") + if entity_perm(entity_id, POLICY_READ) } connection.send_result(msg["id"], sources) @@ -556,7 +563,7 @@ def handle_entity_source( sources = {} for entity_id in msg["entity_id"]: - if not entity_perm(entity_id, "read"): + if not entity_perm(entity_id, POLICY_READ): raise Unauthorized( context=connection.context(msg), permission=POLICY_READ,