Speed up processing subscribe_events and subscribe_entites when user has read all permissions (#93611)

Speed up processing subscribe_events and subscribe_entites when user the read all permissions
pull/93677/head^2
J. Nick Koston 2023-05-27 18:59:46 -05:00 committed by GitHub
parent 67d9fa8b22
commit e1b7d68134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 14 deletions

View File

@ -109,15 +109,18 @@ def handle_subscribe_events(
raise Unauthorized raise Unauthorized
if event_type == EVENT_STATE_CHANGED: if event_type == EVENT_STATE_CHANGED:
user = connection.user
@callback @callback
def forward_events(event: Event) -> None: def forward_events(event: Event) -> None:
"""Forward state changed events to websocket.""" """Forward state changed events to websocket."""
if not connection.user.permissions.check_entity( # We have to lookup the permissions again because the user might have
event.data["entity_id"], POLICY_READ # changed since the subscription was created.
): permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return return
connection.send_message(messages.cached_event_message(msg["id"], event)) connection.send_message(messages.cached_event_message(msg["id"], event))
else: else:
@ -227,13 +230,13 @@ async def handle_call_service(
def _async_get_allowed_states( def _async_get_allowed_states(
hass: HomeAssistant, connection: ActiveConnection hass: HomeAssistant, connection: ActiveConnection
) -> list[State]: ) -> list[State]:
if connection.user.permissions.access_all_entities("read"): if connection.user.permissions.access_all_entities(POLICY_READ):
return hass.states.async_all() return hass.states.async_all()
entity_perm = connection.user.permissions.check_entity entity_perm = connection.user.permissions.check_entity
return [ return [
state state
for state in hass.states.async_all() for state in hass.states.async_all()
if entity_perm(state.entity_id, "read") if entity_perm(state.entity_id, POLICY_READ)
] ]
@ -289,17 +292,21 @@ def handle_subscribe_entities(
) -> None: ) -> None:
"""Handle subscribe entities command.""" """Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", [])) entity_ids = set(msg.get("entity_ids", []))
user = connection.user
@callback @callback
def forward_entity_changes(event: Event) -> None: def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket.""" """Forward entity state changed events to websocket."""
if not connection.user.permissions.check_entity( entity_id = event.data["entity_id"]
event.data["entity_id"], POLICY_READ if entity_ids and entity_id not in entity_ids:
):
return return
if entity_ids and event.data["entity_id"] not in entity_ids: # We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return return
connection.send_message(messages.cached_state_diff_message(msg["id"], event)) connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for # We must never await between sending the states and listening for
@ -541,13 +548,13 @@ def handle_entity_source(
entity_perm = connection.user.permissions.check_entity entity_perm = connection.user.permissions.check_entity
if "entity_id" not in msg: if "entity_id" not in msg:
if connection.user.permissions.access_all_entities("read"): if connection.user.permissions.access_all_entities(POLICY_READ):
sources = raw_sources sources = raw_sources
else: else:
sources = { sources = {
entity_id: source entity_id: source
for entity_id, source in raw_sources.items() for entity_id, source in raw_sources.items()
if entity_perm(entity_id, "read") if entity_perm(entity_id, POLICY_READ)
} }
connection.send_result(msg["id"], sources) connection.send_result(msg["id"], sources)
@ -556,7 +563,7 @@ def handle_entity_source(
sources = {} sources = {}
for entity_id in msg["entity_id"]: for entity_id in msg["entity_id"]:
if not entity_perm(entity_id, "read"): if not entity_perm(entity_id, POLICY_READ):
raise Unauthorized( raise Unauthorized(
context=connection.context(msg), context=connection.context(msg),
permission=POLICY_READ, permission=POLICY_READ,