Speed up processing subscribe_events and subscribe_entites when user has read all permissions (#93611)
Speed up processing subscribe_events and subscribe_entites when user the read all permissionspull/93677/head^2
parent
67d9fa8b22
commit
e1b7d68134
|
@ -109,15 +109,18 @@ def handle_subscribe_events(
|
||||||
raise Unauthorized
|
raise Unauthorized
|
||||||
|
|
||||||
if event_type == EVENT_STATE_CHANGED:
|
if event_type == EVENT_STATE_CHANGED:
|
||||||
|
user = connection.user
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_events(event: Event) -> None:
|
def forward_events(event: Event) -> None:
|
||||||
"""Forward state changed events to websocket."""
|
"""Forward state changed events to websocket."""
|
||||||
if not connection.user.permissions.check_entity(
|
# We have to lookup the permissions again because the user might have
|
||||||
event.data["entity_id"], POLICY_READ
|
# changed since the subscription was created.
|
||||||
):
|
permissions = user.permissions
|
||||||
|
if not permissions.access_all_entities(
|
||||||
|
POLICY_READ
|
||||||
|
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message(messages.cached_event_message(msg["id"], event))
|
connection.send_message(messages.cached_event_message(msg["id"], event))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -227,13 +230,13 @@ async def handle_call_service(
|
||||||
def _async_get_allowed_states(
|
def _async_get_allowed_states(
|
||||||
hass: HomeAssistant, connection: ActiveConnection
|
hass: HomeAssistant, connection: ActiveConnection
|
||||||
) -> list[State]:
|
) -> list[State]:
|
||||||
if connection.user.permissions.access_all_entities("read"):
|
if connection.user.permissions.access_all_entities(POLICY_READ):
|
||||||
return hass.states.async_all()
|
return hass.states.async_all()
|
||||||
entity_perm = connection.user.permissions.check_entity
|
entity_perm = connection.user.permissions.check_entity
|
||||||
return [
|
return [
|
||||||
state
|
state
|
||||||
for state in hass.states.async_all()
|
for state in hass.states.async_all()
|
||||||
if entity_perm(state.entity_id, "read")
|
if entity_perm(state.entity_id, POLICY_READ)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -289,17 +292,21 @@ def handle_subscribe_entities(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle subscribe entities command."""
|
"""Handle subscribe entities command."""
|
||||||
entity_ids = set(msg.get("entity_ids", []))
|
entity_ids = set(msg.get("entity_ids", []))
|
||||||
|
user = connection.user
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_entity_changes(event: Event) -> None:
|
def forward_entity_changes(event: Event) -> None:
|
||||||
"""Forward entity state changed events to websocket."""
|
"""Forward entity state changed events to websocket."""
|
||||||
if not connection.user.permissions.check_entity(
|
entity_id = event.data["entity_id"]
|
||||||
event.data["entity_id"], POLICY_READ
|
if entity_ids and entity_id not in entity_ids:
|
||||||
):
|
|
||||||
return
|
return
|
||||||
if entity_ids and event.data["entity_id"] not in entity_ids:
|
# We have to lookup the permissions again because the user might have
|
||||||
|
# changed since the subscription was created.
|
||||||
|
permissions = user.permissions
|
||||||
|
if not permissions.access_all_entities(
|
||||||
|
POLICY_READ
|
||||||
|
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
|
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
|
||||||
|
|
||||||
# We must never await between sending the states and listening for
|
# We must never await between sending the states and listening for
|
||||||
|
@ -541,13 +548,13 @@ def handle_entity_source(
|
||||||
entity_perm = connection.user.permissions.check_entity
|
entity_perm = connection.user.permissions.check_entity
|
||||||
|
|
||||||
if "entity_id" not in msg:
|
if "entity_id" not in msg:
|
||||||
if connection.user.permissions.access_all_entities("read"):
|
if connection.user.permissions.access_all_entities(POLICY_READ):
|
||||||
sources = raw_sources
|
sources = raw_sources
|
||||||
else:
|
else:
|
||||||
sources = {
|
sources = {
|
||||||
entity_id: source
|
entity_id: source
|
||||||
for entity_id, source in raw_sources.items()
|
for entity_id, source in raw_sources.items()
|
||||||
if entity_perm(entity_id, "read")
|
if entity_perm(entity_id, POLICY_READ)
|
||||||
}
|
}
|
||||||
|
|
||||||
connection.send_result(msg["id"], sources)
|
connection.send_result(msg["id"], sources)
|
||||||
|
@ -556,7 +563,7 @@ def handle_entity_source(
|
||||||
sources = {}
|
sources = {}
|
||||||
|
|
||||||
for entity_id in msg["entity_id"]:
|
for entity_id in msg["entity_id"]:
|
||||||
if not entity_perm(entity_id, "read"):
|
if not entity_perm(entity_id, POLICY_READ):
|
||||||
raise Unauthorized(
|
raise Unauthorized(
|
||||||
context=connection.context(msg),
|
context=connection.context(msg),
|
||||||
permission=POLICY_READ,
|
permission=POLICY_READ,
|
||||||
|
|
Loading…
Reference in New Issue