Websocket api to subscribe to entities (payloads reduced by ~80%+ vs state_changed events) (#67891)

pull/68026/head
J. Nick Koston 2022-03-11 18:54:49 -10:00 committed by GitHub
parent 6526b4eae5
commit 0d8f649bd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 590 additions and 12 deletions

View File

@ -16,7 +16,7 @@ from homeassistant.const import (
MATCH_ALL,
SIGNAL_BOOTSTRAP_INTEGRATONS,
)
from homeassistant.core import Context, Event, HomeAssistant, callback
from homeassistant.core import Context, Event, HomeAssistant, State, callback
from homeassistant.exceptions import (
HomeAssistantError,
ServiceNotFound,
@ -68,6 +68,7 @@ def async_register_commands(
async_reg(hass, handle_test_condition)
async_reg(hass, handle_unsubscribe_events)
async_reg(hass, handle_validate_config)
async_reg(hass, handle_subscribe_entities)
def pong_message(iden: int) -> dict[str, Any]:
@ -213,21 +214,27 @@ async def handle_call_service(
connection.send_error(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))
@callback
def _async_get_allowed_states(
hass: HomeAssistant, connection: ActiveConnection
) -> list[State]:
if connection.user.permissions.access_all_entities("read"):
return hass.states.async_all()
entity_perm = connection.user.permissions.check_entity
return [
state
for state in hass.states.async_all()
if entity_perm(state.entity_id, "read")
]
@callback
@decorators.websocket_command({vol.Required("type"): "get_states"})
def handle_get_states(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get states command."""
if connection.user.permissions.access_all_entities("read"):
states = hass.states.async_all()
else:
entity_perm = connection.user.permissions.check_entity
states = [
state
for state in hass.states.async_all()
if entity_perm(state.entity_id, "read")
]
states = _async_get_allowed_states(hass, connection)
# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
@ -260,6 +267,84 @@ def handle_get_states(
connection.send_message(response2)
@callback
@decorators.websocket_command(
{
vol.Required("type"): "subscribe_entities",
vol.Optional("entity_ids"): cv.entity_ids,
}
)
def handle_subscribe_entities(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe entities command."""
# Circular dep
# pylint: disable=import-outside-toplevel
from .permissions import SUBSCRIBE_ALLOWLIST
if "state_changed" not in SUBSCRIBE_ALLOWLIST and not connection.user.is_admin:
raise Unauthorized
entity_ids = set(msg.get("entity_ids", []))
@callback
def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket."""
if not connection.user.permissions.check_entity(
event.data["entity_id"], POLICY_READ
):
return
if entity_ids and event.data["entity_id"] not in entity_ids:
return
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for
# state changed events or we will introduce a race condition
# where some states are missed
states = _async_get_allowed_states(hass, connection)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
"state_changed", forward_entity_changes
)
connection.send_result(msg["id"])
data: dict[str, dict[str, dict]] = {
messages.ENTITY_EVENT_ADD: {
state.entity_id: messages.compressed_state_dict_add(state)
for state in states
if not entity_ids or state.entity_id in entity_ids
}
}
# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(const.JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
),
)
del response
add_entities = data[messages.ENTITY_EVENT_ADD]
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
try:
const.JSON_DUMP(state_dict)
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
for entity_id in cannot_serialize:
del add_entities[entity_id]
connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
@decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(

View File

@ -7,7 +7,7 @@ from typing import Any, Final
import voluptuous as vol
from homeassistant.core import Event
from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv
from homeassistant.util.json import (
find_paths_unserializable_data,
@ -31,6 +31,19 @@ BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive
IDEN_TEMPLATE: Final = "__IDEN__"
IDEN_JSON_TEMPLATE: Final = '"__IDEN__"'
COMPRESSED_STATE_STATE = "s"
COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"
COMPRESSED_STATE_LAST_CHANGED = "lc"
COMPRESSED_STATE_LAST_UPDATED = "lu"
STATE_DIFF_ADDITIONS = "+"
STATE_DIFF_REMOVALS = "-"
ENTITY_EVENT_ADD = "a"
ENTITY_EVENT_REMOVE = "r"
ENTITY_EVENT_CHANGE = "c"
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
@ -74,6 +87,110 @@ def _cached_event_message(event: Event) -> str:
return message_to_json(event_message(IDEN_TEMPLATE, event))
def cached_state_diff_message(iden: int, event: Event) -> str:
"""Return an event message.
Serialize to json once per message.
Since we can have many clients connected that are
all getting many of the same events (mostly state changed)
we can avoid serializing the same data for each connection.
"""
return _cached_state_diff_message(event).replace(IDEN_JSON_TEMPLATE, str(iden), 1)
@lru_cache(maxsize=128)
def _cached_state_diff_message(event: Event) -> str:
"""Cache and serialize the event to json.
The IDEN_TEMPLATE is used which will be replaced
with the actual iden in cached_event_message
"""
return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event)))
def _state_diff_event(event: Event) -> dict:
"""Convert a state_changed event to the minimal version.
State update example
{
"a": {entity_id: compressed_state,}
"c": {entity_id: diff,}
"r": [entity_id,]
}
"""
if (event_new_state := event.data["new_state"]) is None:
return {ENTITY_EVENT_REMOVE: [event.data["entity_id"]]}
assert isinstance(event_new_state, State)
if (event_old_state := event.data["old_state"]) is None:
return {
ENTITY_EVENT_ADD: {
event_new_state.entity_id: compressed_state_dict_add(event_new_state)
}
}
assert isinstance(event_old_state, State)
return _state_diff(event_old_state, event_new_state)
def _state_diff(
old_state: State, new_state: State
) -> dict[str, dict[str, dict[str, dict[str, str | list[str]]]]]:
"""Create a diff dict that can be used to overlay changes."""
diff: dict = {STATE_DIFF_ADDITIONS: {}}
additions = diff[STATE_DIFF_ADDITIONS]
if old_state.state != new_state.state:
additions[COMPRESSED_STATE_STATE] = new_state.state
if old_state.last_changed != new_state.last_changed:
additions[COMPRESSED_STATE_LAST_CHANGED] = new_state.last_changed.timestamp()
elif old_state.last_updated != new_state.last_updated:
additions[COMPRESSED_STATE_LAST_UPDATED] = new_state.last_updated.timestamp()
if old_state.context.parent_id != new_state.context.parent_id:
additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[
"parent_id"
] = new_state.context.parent_id
if old_state.context.user_id != new_state.context.user_id:
additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[
"user_id"
] = new_state.context.user_id
if old_state.context.id != new_state.context.id:
if COMPRESSED_STATE_CONTEXT in additions:
additions[COMPRESSED_STATE_CONTEXT]["id"] = new_state.context.id
else:
additions[COMPRESSED_STATE_CONTEXT] = new_state.context.id
old_attributes = old_state.attributes
for key, value in new_state.attributes.items():
if old_attributes.get(key) != value:
additions.setdefault(COMPRESSED_STATE_ATTRIBUTES, {})[key] = value
if removed := set(old_attributes).difference(new_state.attributes):
diff[STATE_DIFF_REMOVALS] = {COMPRESSED_STATE_ATTRIBUTES: removed}
return {ENTITY_EVENT_CHANGE: {new_state.entity_id: diff}}
def compressed_state_dict_add(state: State) -> dict[str, Any]:
"""Build a compressed dict of a state for adds.
Omits the lu (last_updated) if it matches (lc) last_changed.
Sends c (context) as a string if it only contains an id.
"""
if state.context.parent_id is None and state.context.user_id is None:
context: dict[str, Any] | str = state.context.id # type: ignore[unreachable]
else:
context = state.context.as_dict()
compressed_state: dict[str, Any] = {
COMPRESSED_STATE_STATE: state.state,
COMPRESSED_STATE_ATTRIBUTES: state.attributes,
COMPRESSED_STATE_CONTEXT: context,
}
if state.last_changed == state.last_updated:
compressed_state[COMPRESSED_STATE_LAST_CHANGED] = state.last_changed.timestamp()
else:
compressed_state[COMPRESSED_STATE_LAST_CHANGED] = state.last_changed.timestamp()
compressed_state[COMPRESSED_STATE_LAST_UPDATED] = state.last_updated.timestamp()
return compressed_state
def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json."""
try:

View File

@ -1,4 +1,5 @@
"""Tests for WebSocket API commands."""
from copy import deepcopy
import datetime
from unittest.mock import ANY, patch
@ -14,7 +15,7 @@ from homeassistant.components.websocket_api.auth import (
)
from homeassistant.components.websocket_api.const import URL
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers.dispatcher import async_dispatcher_send
@ -23,6 +24,38 @@ from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
from tests.common import MockEntity, MockEntityPlatform, async_mock_service
STATE_KEY_SHORT_NAMES = {
"entity_id": "e",
"state": "s",
"last_changed": "lc",
"last_updated": "lu",
"context": "c",
"attributes": "a",
}
STATE_KEY_LONG_NAMES = {v: k for k, v in STATE_KEY_SHORT_NAMES.items()}
def _apply_entities_changes(state_dict: dict, change_dict: dict) -> None:
"""Apply a diff set to a dict.
Port of the client side merging
"""
additions = change_dict.get("+", {})
if "lc" in additions:
additions["lu"] = additions["lc"]
if attributes := additions.pop("a", None):
state_dict["attributes"].update(attributes)
if context := additions.pop("c", None):
if isinstance(context, str):
state_dict["context"]["id"] = context
else:
state_dict["context"].update(context)
for k, v in additions.items():
state_dict[STATE_KEY_LONG_NAMES[k]] = v
for key, items in change_dict.get("-", {}).items():
for item in items:
del state_dict[STATE_KEY_LONG_NAMES[key]][item]
async def test_fire_event(hass, websocket_client):
"""Test fire event command."""
@ -666,6 +699,349 @@ async def test_subscribe_unsubscribe_events_state_changed(
assert msg["event"]["data"]["entity_id"] == "light.permitted"
async def test_subscribe_entities_with_unserializable_state(
hass, websocket_client, hass_admin_user
):
"""Test subscribe entities with an unserializeable state."""
class CannotSerializeMe:
"""Cannot serialize this."""
def __init__(self):
"""Init cannot serialize this."""
hass.states.async_set("light.permitted", "off", {"color": "red"})
hass.states.async_set(
"light.cannot_serialize",
"off",
{"color": "red", "cannot_serialize": CannotSerializeMe()},
)
original_state = hass.states.get("light.cannot_serialize")
assert isinstance(original_state, State)
state_dict = {
"attributes": dict(original_state.attributes),
"context": dict(original_state.context.as_dict()),
"entity_id": original_state.entity_id,
"last_changed": original_state.last_changed.isoformat(),
"last_updated": original_state.last_updated.isoformat(),
"state": original_state.state,
}
hass_admin_user.groups = []
hass_admin_user.mock_policy(
{
"entities": {
"entity_ids": {"light.permitted": True, "light.cannot_serialize": True}
}
}
)
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {
"a": {"color": "red"},
"c": ANY,
"lc": ANY,
"s": "off",
}
}
}
hass.states.async_set("light.permitted", "on", {"effect": "help"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.permitted": {
"+": {
"a": {"effect": "help"},
"c": ANY,
"lc": ANY,
"s": "on",
},
"-": {"a": ["color"]},
}
}
}
hass.states.async_set("light.cannot_serialize", "on", {"effect": "help"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
# Order does not matter
msg["event"]["c"]["light.cannot_serialize"]["-"]["a"] = set(
msg["event"]["c"]["light.cannot_serialize"]["-"]["a"]
)
assert msg["event"] == {
"c": {
"light.cannot_serialize": {
"+": {"a": {"effect": "help"}, "c": ANY, "lc": ANY, "s": "on"},
"-": {"a": {"color", "cannot_serialize"}},
}
}
}
change_set = msg["event"]["c"]["light.cannot_serialize"]
_apply_entities_changes(state_dict, change_set)
assert state_dict == {
"attributes": {"effect": "help"},
"context": {
"id": ANY,
"parent_id": None,
"user_id": None,
},
"entity_id": "light.cannot_serialize",
"last_changed": ANY,
"last_updated": ANY,
"state": "on",
}
hass.states.async_set(
"light.cannot_serialize",
"off",
{"color": "red", "cannot_serialize": CannotSerializeMe()},
)
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "result"
assert msg["error"] == {
"code": "unknown_error",
"message": "Invalid JSON in response",
}
async def test_subscribe_unsubscribe_entities(hass, websocket_client, hass_admin_user):
"""Test subscribe/unsubscribe entities."""
hass.states.async_set("light.permitted", "off", {"color": "red"})
original_state = hass.states.get("light.permitted")
assert isinstance(original_state, State)
state_dict = {
"attributes": dict(original_state.attributes),
"context": dict(original_state.context.as_dict()),
"entity_id": original_state.entity_id,
"last_changed": original_state.last_changed.isoformat(),
"last_updated": original_state.last_updated.isoformat(),
"state": original_state.state,
}
hass_admin_user.groups = []
hass_admin_user.mock_policy({"entities": {"entity_ids": {"light.permitted": True}}})
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert isinstance(msg["event"]["a"]["light.permitted"]["c"], str)
assert msg["event"] == {
"a": {
"light.permitted": {
"a": {"color": "red"},
"c": ANY,
"lc": ANY,
"s": "off",
}
}
}
hass.states.async_set("light.not_permitted", "on")
hass.states.async_set("light.permitted", "on", {"color": "blue"})
hass.states.async_set("light.permitted", "on", {"effect": "help"})
hass.states.async_set(
"light.permitted", "on", {"effect": "help", "color": ["blue", "green"]}
)
hass.states.async_remove("light.permitted")
hass.states.async_set("light.permitted", "on", {"effect": "help", "color": "blue"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.permitted": {
"+": {
"a": {"color": "blue"},
"c": ANY,
"lc": ANY,
"s": "on",
}
}
}
}
change_set = msg["event"]["c"]["light.permitted"]
additions = deepcopy(change_set["+"])
_apply_entities_changes(state_dict, change_set)
assert state_dict == {
"attributes": {"color": "blue"},
"context": {
"id": additions["c"],
"parent_id": None,
"user_id": None,
},
"entity_id": "light.permitted",
"last_changed": additions["lc"],
"last_updated": additions["lc"],
"state": "on",
}
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.permitted": {
"+": {
"a": {"effect": "help"},
"c": ANY,
"lu": ANY,
},
"-": {"a": ["color"]},
}
}
}
change_set = msg["event"]["c"]["light.permitted"]
additions = deepcopy(change_set["+"])
_apply_entities_changes(state_dict, change_set)
assert state_dict == {
"attributes": {"effect": "help"},
"context": {
"id": additions["c"],
"parent_id": None,
"user_id": None,
},
"entity_id": "light.permitted",
"last_changed": ANY,
"last_updated": additions["lu"],
"state": "on",
}
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.permitted": {
"+": {
"a": {"color": ["blue", "green"]},
"c": ANY,
"lu": ANY,
}
}
}
}
change_set = msg["event"]["c"]["light.permitted"]
additions = deepcopy(change_set["+"])
_apply_entities_changes(state_dict, change_set)
assert state_dict == {
"attributes": {"effect": "help", "color": ["blue", "green"]},
"context": {
"id": additions["c"],
"parent_id": None,
"user_id": None,
},
"entity_id": "light.permitted",
"last_changed": ANY,
"last_updated": additions["lu"],
"state": "on",
}
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {"r": ["light.permitted"]}
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {
"a": {"color": "blue", "effect": "help"},
"c": ANY,
"lc": ANY,
"s": "on",
}
}
}
async def test_subscribe_unsubscribe_entities_specific_entities(
hass, websocket_client, hass_admin_user
):
"""Test subscribe/unsubscribe entities with a list of entity ids."""
hass.states.async_set("light.permitted", "off", {"color": "red"})
hass.states.async_set("light.not_intrested", "off", {"color": "blue"})
original_state = hass.states.get("light.permitted")
assert isinstance(original_state, State)
hass_admin_user.groups = []
hass_admin_user.mock_policy(
{
"entities": {
"entity_ids": {"light.permitted": True, "light.not_intrested": True}
}
}
)
await websocket_client.send_json(
{"id": 7, "type": "subscribe_entities", "entity_ids": ["light.permitted"]}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert isinstance(msg["event"]["a"]["light.permitted"]["c"], str)
assert msg["event"] == {
"a": {
"light.permitted": {
"a": {"color": "red"},
"c": ANY,
"lc": ANY,
"s": "off",
}
}
}
hass.states.async_set("light.not_intrested", "on", {"effect": "help"})
hass.states.async_set("light.not_permitted", "on")
hass.states.async_set("light.permitted", "on", {"color": "blue"})
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {
"light.permitted": {
"+": {
"a": {"color": "blue"},
"c": ANY,
"lc": ANY,
"s": "on",
}
}
}
}
async def test_render_template_renders_template(hass, websocket_client):
"""Test simple template is rendered and updated."""
hass.states.async_set("light.test", "on")