Add support for using an entityfilter to subscribe_entities (#124641)
* Add support for using an entityfilter to subscribe_entities * filter init * fix * coveragepull/124696/head
parent
68d6f1c1aa
commit
d8161c431f
|
@ -36,6 +36,10 @@ from homeassistant.exceptions import (
|
|||
)
|
||||
from homeassistant.helpers import config_validation as cv, entity, template
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entityfilter import (
|
||||
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
|
||||
convert_include_exclude_filter,
|
||||
)
|
||||
from homeassistant.helpers.event import (
|
||||
TrackTemplate,
|
||||
TrackTemplateResult,
|
||||
|
@ -366,14 +370,17 @@ def _send_handle_get_states_response(
|
|||
@callback
|
||||
def _forward_entity_changes(
|
||||
send_message: Callable[[str | bytes | dict[str, Any]], None],
|
||||
entity_ids: set[str],
|
||||
entity_ids: set[str] | None,
|
||||
entity_filter: Callable[[str], bool] | None,
|
||||
user: User,
|
||||
message_id_as_bytes: bytes,
|
||||
event: Event[EventStateChangedData],
|
||||
) -> None:
|
||||
"""Forward entity state changed events to websocket."""
|
||||
entity_id = event.data["entity_id"]
|
||||
if entity_ids and entity_id not in entity_ids:
|
||||
if (entity_ids and entity_id not in entity_ids) or (
|
||||
entity_filter and not entity_filter(entity_id)
|
||||
):
|
||||
return
|
||||
# We have to lookup the permissions again because the user might have
|
||||
# changed since the subscription was created.
|
||||
|
@ -381,7 +388,7 @@ def _forward_entity_changes(
|
|||
if (
|
||||
not user.is_admin
|
||||
and not permissions.access_all_entities(POLICY_READ)
|
||||
and not permissions.check_entity(event.data["entity_id"], POLICY_READ)
|
||||
and not permissions.check_entity(entity_id, POLICY_READ)
|
||||
):
|
||||
return
|
||||
send_message(messages.cached_state_diff_message(message_id_as_bytes, event))
|
||||
|
@ -392,43 +399,55 @@ def _forward_entity_changes(
|
|||
{
|
||||
vol.Required("type"): "subscribe_entities",
|
||||
vol.Optional("entity_ids"): cv.entity_ids,
|
||||
**INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.schema,
|
||||
}
|
||||
)
|
||||
def handle_subscribe_entities(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe entities command."""
|
||||
entity_ids = set(msg.get("entity_ids", []))
|
||||
entity_ids = set(msg.get("entity_ids", [])) or None
|
||||
_filter = convert_include_exclude_filter(msg)
|
||||
entity_filter = None if _filter.empty_filter else _filter.get_filter()
|
||||
# We must never await between sending the states and listening for
|
||||
# state changed events or we will introduce a race condition
|
||||
# where some states are missed
|
||||
states = _async_get_allowed_states(hass, connection)
|
||||
message_id_as_bytes = str(msg["id"]).encode()
|
||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||
msg_id = msg["id"]
|
||||
message_id_as_bytes = str(msg_id).encode()
|
||||
connection.subscriptions[msg_id] = hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED,
|
||||
partial(
|
||||
_forward_entity_changes,
|
||||
connection.send_message,
|
||||
entity_ids,
|
||||
entity_filter,
|
||||
connection.user,
|
||||
message_id_as_bytes,
|
||||
),
|
||||
)
|
||||
connection.send_result(msg["id"])
|
||||
connection.send_result(msg_id)
|
||||
|
||||
# JSON serialize here so we can recover if it blows up due to the
|
||||
# state machine containing unserializable data. This command is required
|
||||
# to succeed for the UI to show.
|
||||
try:
|
||||
serialized_states = [
|
||||
state.as_compressed_state_json
|
||||
for state in states
|
||||
if not entity_ids or state.entity_id in entity_ids
|
||||
]
|
||||
if entity_ids or entity_filter:
|
||||
serialized_states = [
|
||||
state.as_compressed_state_json
|
||||
for state in states
|
||||
if (not entity_ids or state.entity_id in entity_ids)
|
||||
and (not entity_filter or entity_filter(state.entity_id))
|
||||
]
|
||||
else:
|
||||
# Fast path when not filtering
|
||||
serialized_states = [state.as_compressed_state_json for state in states]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||
_send_handle_entities_init_response(
|
||||
connection, message_id_as_bytes, serialized_states
|
||||
)
|
||||
return
|
||||
|
||||
serialized_states = []
|
||||
|
@ -443,18 +462,22 @@ def handle_subscribe_entities(
|
|||
),
|
||||
)
|
||||
|
||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||
_send_handle_entities_init_response(
|
||||
connection, message_id_as_bytes, serialized_states
|
||||
)
|
||||
|
||||
|
||||
def _send_handle_entities_init_response(
|
||||
connection: ActiveConnection, msg_id: int, serialized_states: list[bytes]
|
||||
connection: ActiveConnection,
|
||||
message_id_as_bytes: bytes,
|
||||
serialized_states: list[bytes],
|
||||
) -> None:
|
||||
"""Send handle entities init response."""
|
||||
connection.send_message(
|
||||
b"".join(
|
||||
(
|
||||
b'{"id":',
|
||||
str(msg_id).encode(),
|
||||
message_id_as_bytes,
|
||||
b',"type":"event","event":{"a":{',
|
||||
b",".join(serialized_states),
|
||||
b"}}}",
|
||||
|
|
|
@ -1262,6 +1262,54 @@ async def test_subscribe_unsubscribe_entities_specific_entities(
|
|||
}
|
||||
|
||||
|
||||
async def test_subscribe_unsubscribe_entities_with_filter(
|
||||
hass: HomeAssistant,
|
||||
websocket_client: MockHAClientWebSocket,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test subscribe/unsubscribe entities with an entity filter."""
|
||||
hass.states.async_set("switch.not_included", "off")
|
||||
hass.states.async_set("light.include", "off")
|
||||
await websocket_client.send_json(
|
||||
{"id": 7, "type": "subscribe_entities", "include": {"domains": ["light"]}}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"a": {
|
||||
"light.include": {
|
||||
"a": {},
|
||||
"c": ANY,
|
||||
"lc": ANY,
|
||||
"s": "off",
|
||||
}
|
||||
}
|
||||
}
|
||||
hass.states.async_set("switch.not_included", "on")
|
||||
hass.states.async_set("light.include", "on")
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {
|
||||
"light.include": {
|
||||
"+": {
|
||||
"c": ANY,
|
||||
"lc": ANY,
|
||||
"s": "on",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def test_render_template_renders_template(
|
||||
hass: HomeAssistant, websocket_client
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue