Add support for using an entityfilter to subscribe_entities (#124641)

* Add support for using an entityfilter to subscribe_entities

* filter init

* fix

* coverage
pull/124696/head
J. Nick Koston 2024-08-26 23:17:05 -10:00 committed by GitHub
parent 68d6f1c1aa
commit d8161c431f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 87 additions and 16 deletions

View File

@ -36,6 +36,10 @@ from homeassistant.exceptions import (
)
from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entityfilter import (
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
convert_include_exclude_filter,
)
from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
@ -366,14 +370,17 @@ def _send_handle_get_states_response(
@callback
def _forward_entity_changes(
send_message: Callable[[str | bytes | dict[str, Any]], None],
entity_ids: set[str],
entity_ids: set[str] | None,
entity_filter: Callable[[str], bool] | None,
user: User,
message_id_as_bytes: bytes,
event: Event[EventStateChangedData],
) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
if (entity_ids and entity_id not in entity_ids) or (
entity_filter and not entity_filter(entity_id)
):
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
@ -381,7 +388,7 @@ def _forward_entity_changes(
if (
not user.is_admin
and not permissions.access_all_entities(POLICY_READ)
and not permissions.check_entity(event.data["entity_id"], POLICY_READ)
and not permissions.check_entity(entity_id, POLICY_READ)
):
return
send_message(messages.cached_state_diff_message(message_id_as_bytes, event))
@ -392,43 +399,55 @@ def _forward_entity_changes(
{
vol.Required("type"): "subscribe_entities",
vol.Optional("entity_ids"): cv.entity_ids,
**INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.schema,
}
)
def handle_subscribe_entities(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", []))
entity_ids = set(msg.get("entity_ids", [])) or None
_filter = convert_include_exclude_filter(msg)
entity_filter = None if _filter.empty_filter else _filter.get_filter()
# We must never await between sending the states and listening for
# state changed events or we will introduce a race condition
# where some states are missed
states = _async_get_allowed_states(hass, connection)
message_id_as_bytes = str(msg["id"]).encode()
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
msg_id = msg["id"]
message_id_as_bytes = str(msg_id).encode()
connection.subscriptions[msg_id] = hass.bus.async_listen(
EVENT_STATE_CHANGED,
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
entity_filter,
connection.user,
message_id_as_bytes,
),
)
connection.send_result(msg["id"])
connection.send_result(msg_id)
# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
try:
serialized_states = [
state.as_compressed_state_json
for state in states
if not entity_ids or state.entity_id in entity_ids
]
if entity_ids or entity_filter:
serialized_states = [
state.as_compressed_state_json
for state in states
if (not entity_ids or state.entity_id in entity_ids)
and (not entity_filter or entity_filter(state.entity_id))
]
else:
# Fast path when not filtering
serialized_states = [state.as_compressed_state_json for state in states]
except (ValueError, TypeError):
pass
else:
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
_send_handle_entities_init_response(
connection, message_id_as_bytes, serialized_states
)
return
serialized_states = []
@ -443,18 +462,22 @@ def handle_subscribe_entities(
),
)
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
_send_handle_entities_init_response(
connection, message_id_as_bytes, serialized_states
)
def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[bytes]
connection: ActiveConnection,
message_id_as_bytes: bytes,
serialized_states: list[bytes],
) -> None:
"""Send handle entities init response."""
connection.send_message(
b"".join(
(
b'{"id":',
str(msg_id).encode(),
message_id_as_bytes,
b',"type":"event","event":{"a":{',
b",".join(serialized_states),
b"}}}",

View File

@ -1262,6 +1262,54 @@ async def test_subscribe_unsubscribe_entities_specific_entities(
}
async def test_subscribe_unsubscribe_entities_with_filter(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
hass_admin_user: MockUser,
) -> None:
"""Test subscribe/unsubscribe entities with an entity filter."""
hass.states.async_set("switch.not_included", "off")
hass.states.async_set("light.include", "off")
await websocket_client.send_json(
{"id": 7, "type": "subscribe_entities", "include": {"domains": ["light"]}}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.include": {
"a": {},
"c": ANY,
"lc": ANY,
"s": "off",
}
}
}
hass.states.async_set("switch.not_included", "on")
hass.states.async_set("light.include", "on")
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.include": {
"+": {
"c": ANY,
"lc": ANY,
"s": "on",
}
}
}
}
async def test_render_template_renders_template(
hass: HomeAssistant, websocket_client
) -> None: