Speed up processing subscribe_events and subscribe_entites when user has read all permissions (#93611)

Speed up processing subscribe_events and subscribe_entites when user the read all permissions
pull/93677/head^2
J. Nick Koston 2023-05-27 18:59:46 -05:00 committed by GitHub
parent 67d9fa8b22
commit e1b7d68134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 14 deletions

View File

@ -109,15 +109,18 @@ def handle_subscribe_events(
raise Unauthorized
if event_type == EVENT_STATE_CHANGED:
user = connection.user
@callback
def forward_events(event: Event) -> None:
"""Forward state changed events to websocket."""
if not connection.user.permissions.check_entity(
event.data["entity_id"], POLICY_READ
):
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_event_message(msg["id"], event))
else:
@ -227,13 +230,13 @@ async def handle_call_service(
def _async_get_allowed_states(
hass: HomeAssistant, connection: ActiveConnection
) -> list[State]:
if connection.user.permissions.access_all_entities("read"):
if connection.user.permissions.access_all_entities(POLICY_READ):
return hass.states.async_all()
entity_perm = connection.user.permissions.check_entity
return [
state
for state in hass.states.async_all()
if entity_perm(state.entity_id, "read")
if entity_perm(state.entity_id, POLICY_READ)
]
@ -289,17 +292,21 @@ def handle_subscribe_entities(
) -> None:
"""Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", []))
user = connection.user
@callback
def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket."""
if not connection.user.permissions.check_entity(
event.data["entity_id"], POLICY_READ
):
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
if entity_ids and event.data["entity_id"] not in entity_ids:
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for
@ -541,13 +548,13 @@ def handle_entity_source(
entity_perm = connection.user.permissions.check_entity
if "entity_id" not in msg:
if connection.user.permissions.access_all_entities("read"):
if connection.user.permissions.access_all_entities(POLICY_READ):
sources = raw_sources
else:
sources = {
entity_id: source
for entity_id, source in raw_sources.items()
if entity_perm(entity_id, "read")
if entity_perm(entity_id, POLICY_READ)
}
connection.send_result(msg["id"], sources)
@ -556,7 +563,7 @@ def handle_entity_source(
sources = {}
for entity_id in msg["entity_id"]:
if not entity_perm(entity_id, "read"):
if not entity_perm(entity_id, POLICY_READ):
raise Unauthorized(
context=connection.context(msg),
permission=POLICY_READ,