Add support for using an entityfilter to subscribe_entities (#124641)
* Add support for using an entityfilter to subscribe_entities * filter init * fix * coveragepull/124696/head
parent
68d6f1c1aa
commit
d8161c431f
|
@ -36,6 +36,10 @@ from homeassistant.exceptions import (
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import config_validation as cv, entity, template
|
from homeassistant.helpers import config_validation as cv, entity, template
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
|
from homeassistant.helpers.entityfilter import (
|
||||||
|
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
|
||||||
|
convert_include_exclude_filter,
|
||||||
|
)
|
||||||
from homeassistant.helpers.event import (
|
from homeassistant.helpers.event import (
|
||||||
TrackTemplate,
|
TrackTemplate,
|
||||||
TrackTemplateResult,
|
TrackTemplateResult,
|
||||||
|
@ -366,14 +370,17 @@ def _send_handle_get_states_response(
|
||||||
@callback
|
@callback
|
||||||
def _forward_entity_changes(
|
def _forward_entity_changes(
|
||||||
send_message: Callable[[str | bytes | dict[str, Any]], None],
|
send_message: Callable[[str | bytes | dict[str, Any]], None],
|
||||||
entity_ids: set[str],
|
entity_ids: set[str] | None,
|
||||||
|
entity_filter: Callable[[str], bool] | None,
|
||||||
user: User,
|
user: User,
|
||||||
message_id_as_bytes: bytes,
|
message_id_as_bytes: bytes,
|
||||||
event: Event[EventStateChangedData],
|
event: Event[EventStateChangedData],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Forward entity state changed events to websocket."""
|
"""Forward entity state changed events to websocket."""
|
||||||
entity_id = event.data["entity_id"]
|
entity_id = event.data["entity_id"]
|
||||||
if entity_ids and entity_id not in entity_ids:
|
if (entity_ids and entity_id not in entity_ids) or (
|
||||||
|
entity_filter and not entity_filter(entity_id)
|
||||||
|
):
|
||||||
return
|
return
|
||||||
# We have to lookup the permissions again because the user might have
|
# We have to lookup the permissions again because the user might have
|
||||||
# changed since the subscription was created.
|
# changed since the subscription was created.
|
||||||
|
@ -381,7 +388,7 @@ def _forward_entity_changes(
|
||||||
if (
|
if (
|
||||||
not user.is_admin
|
not user.is_admin
|
||||||
and not permissions.access_all_entities(POLICY_READ)
|
and not permissions.access_all_entities(POLICY_READ)
|
||||||
and not permissions.check_entity(event.data["entity_id"], POLICY_READ)
|
and not permissions.check_entity(entity_id, POLICY_READ)
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
send_message(messages.cached_state_diff_message(message_id_as_bytes, event))
|
send_message(messages.cached_state_diff_message(message_id_as_bytes, event))
|
||||||
|
@ -392,43 +399,55 @@ def _forward_entity_changes(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "subscribe_entities",
|
vol.Required("type"): "subscribe_entities",
|
||||||
vol.Optional("entity_ids"): cv.entity_ids,
|
vol.Optional("entity_ids"): cv.entity_ids,
|
||||||
|
**INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.schema,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def handle_subscribe_entities(
|
def handle_subscribe_entities(
|
||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle subscribe entities command."""
|
"""Handle subscribe entities command."""
|
||||||
entity_ids = set(msg.get("entity_ids", []))
|
entity_ids = set(msg.get("entity_ids", [])) or None
|
||||||
|
_filter = convert_include_exclude_filter(msg)
|
||||||
|
entity_filter = None if _filter.empty_filter else _filter.get_filter()
|
||||||
# We must never await between sending the states and listening for
|
# We must never await between sending the states and listening for
|
||||||
# state changed events or we will introduce a race condition
|
# state changed events or we will introduce a race condition
|
||||||
# where some states are missed
|
# where some states are missed
|
||||||
states = _async_get_allowed_states(hass, connection)
|
states = _async_get_allowed_states(hass, connection)
|
||||||
message_id_as_bytes = str(msg["id"]).encode()
|
msg_id = msg["id"]
|
||||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
message_id_as_bytes = str(msg_id).encode()
|
||||||
|
connection.subscriptions[msg_id] = hass.bus.async_listen(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
partial(
|
partial(
|
||||||
_forward_entity_changes,
|
_forward_entity_changes,
|
||||||
connection.send_message,
|
connection.send_message,
|
||||||
entity_ids,
|
entity_ids,
|
||||||
|
entity_filter,
|
||||||
connection.user,
|
connection.user,
|
||||||
message_id_as_bytes,
|
message_id_as_bytes,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg_id)
|
||||||
|
|
||||||
# JSON serialize here so we can recover if it blows up due to the
|
# JSON serialize here so we can recover if it blows up due to the
|
||||||
# state machine containing unserializable data. This command is required
|
# state machine containing unserializable data. This command is required
|
||||||
# to succeed for the UI to show.
|
# to succeed for the UI to show.
|
||||||
try:
|
try:
|
||||||
|
if entity_ids or entity_filter:
|
||||||
serialized_states = [
|
serialized_states = [
|
||||||
state.as_compressed_state_json
|
state.as_compressed_state_json
|
||||||
for state in states
|
for state in states
|
||||||
if not entity_ids or state.entity_id in entity_ids
|
if (not entity_ids or state.entity_id in entity_ids)
|
||||||
|
and (not entity_filter or entity_filter(state.entity_id))
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
# Fast path when not filtering
|
||||||
|
serialized_states = [state.as_compressed_state_json for state in states]
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
_send_handle_entities_init_response(
|
||||||
|
connection, message_id_as_bytes, serialized_states
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
serialized_states = []
|
serialized_states = []
|
||||||
|
@ -443,18 +462,22 @@ def handle_subscribe_entities(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
_send_handle_entities_init_response(
|
||||||
|
connection, message_id_as_bytes, serialized_states
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _send_handle_entities_init_response(
|
def _send_handle_entities_init_response(
|
||||||
connection: ActiveConnection, msg_id: int, serialized_states: list[bytes]
|
connection: ActiveConnection,
|
||||||
|
message_id_as_bytes: bytes,
|
||||||
|
serialized_states: list[bytes],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send handle entities init response."""
|
"""Send handle entities init response."""
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
b"".join(
|
b"".join(
|
||||||
(
|
(
|
||||||
b'{"id":',
|
b'{"id":',
|
||||||
str(msg_id).encode(),
|
message_id_as_bytes,
|
||||||
b',"type":"event","event":{"a":{',
|
b',"type":"event","event":{"a":{',
|
||||||
b",".join(serialized_states),
|
b",".join(serialized_states),
|
||||||
b"}}}",
|
b"}}}",
|
||||||
|
|
|
@ -1262,6 +1262,54 @@ async def test_subscribe_unsubscribe_entities_specific_entities(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_unsubscribe_entities_with_filter(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
websocket_client: MockHAClientWebSocket,
|
||||||
|
hass_admin_user: MockUser,
|
||||||
|
) -> None:
|
||||||
|
"""Test subscribe/unsubscribe entities with an entity filter."""
|
||||||
|
hass.states.async_set("switch.not_included", "off")
|
||||||
|
hass.states.async_set("light.include", "off")
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{"id": 7, "type": "subscribe_entities", "include": {"domains": ["light"]}}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"a": {
|
||||||
|
"light.include": {
|
||||||
|
"a": {},
|
||||||
|
"c": ANY,
|
||||||
|
"lc": ANY,
|
||||||
|
"s": "off",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hass.states.async_set("switch.not_included", "on")
|
||||||
|
hass.states.async_set("light.include", "on")
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {
|
||||||
|
"light.include": {
|
||||||
|
"+": {
|
||||||
|
"c": ANY,
|
||||||
|
"lc": ANY,
|
||||||
|
"s": "on",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_render_template_renders_template(
|
async def test_render_template_renders_template(
|
||||||
hass: HomeAssistant, websocket_client
|
hass: HomeAssistant, websocket_client
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in New Issue