Speed up reconnects by caching state serialize (#93050)
parent
9c039a17ea
commit
99265a983a
|
@ -2,7 +2,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
from functools import lru_cache
|
||||
import json
|
||||
|
@ -50,6 +49,17 @@ from . import const, decorators, messages
|
|||
from .connection import ActiveConnection
|
||||
from .const import ERR_NOT_FOUND
|
||||
|
||||
_STATES_TEMPLATE = "__STATES__"
|
||||
_STATES_JSON_TEMPLATE = '"__STATES__"'
|
||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
|
||||
messages.event_message(
|
||||
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
|
||||
)
|
||||
)
|
||||
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
|
||||
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_commands(
|
||||
|
@ -242,33 +252,43 @@ def handle_get_states(
|
|||
"""Handle get states command."""
|
||||
states = _async_get_allowed_states(hass, connection)
|
||||
|
||||
# JSON serialize here so we can recover if it blows up due to the
|
||||
# state machine containing unserializable data. This command is required
|
||||
# to succeed for the UI to show.
|
||||
response = messages.result_message(msg["id"], states)
|
||||
try:
|
||||
connection.send_message(JSON_DUMP(response))
|
||||
return
|
||||
serialized_states = [state.as_dict_json() for state in states]
|
||||
except (ValueError, TypeError):
|
||||
connection.logger.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(response, dump=JSON_DUMP)
|
||||
),
|
||||
)
|
||||
del response
|
||||
pass
|
||||
else:
|
||||
_send_handle_get_states_response(connection, msg["id"], serialized_states)
|
||||
return
|
||||
|
||||
# If we can't serialize, we'll filter out unserializable states
|
||||
serialized = []
|
||||
serialized_states = []
|
||||
for state in states:
|
||||
# Error is already logged above
|
||||
with suppress(ValueError, TypeError):
|
||||
serialized.append(JSON_DUMP(state))
|
||||
try:
|
||||
serialized_states.append(state.as_dict_json())
|
||||
except (ValueError, TypeError):
|
||||
connection.logger.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(state, dump=JSON_DUMP)
|
||||
),
|
||||
)
|
||||
|
||||
# We now have partially serialized states. Craft some JSON.
|
||||
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
|
||||
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
|
||||
connection.send_message(response2)
|
||||
_send_handle_get_states_response(connection, msg["id"], serialized_states)
|
||||
|
||||
|
||||
def _send_handle_get_states_response(
|
||||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||
) -> None:
|
||||
"""Send handle get states response."""
|
||||
connection.send_message(
|
||||
_HANDLE_GET_STATES_TEMPLATE.replace(
|
||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||
).replace(
|
||||
_STATES_JSON_TEMPLATE,
|
||||
"[" + ",".join(serialized_states) + "]",
|
||||
1,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -304,42 +324,50 @@ def handle_subscribe_entities(
|
|||
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
|
||||
)
|
||||
connection.send_result(msg["id"])
|
||||
data: dict[str, dict[str, dict]] = {
|
||||
messages.ENTITY_EVENT_ADD: {
|
||||
state.entity_id: state.as_compressed_state()
|
||||
for state in states
|
||||
if not entity_ids or state.entity_id in entity_ids
|
||||
}
|
||||
}
|
||||
|
||||
# JSON serialize here so we can recover if it blows up due to the
|
||||
# state machine containing unserializable data. This command is required
|
||||
# to succeed for the UI to show.
|
||||
response = messages.event_message(msg["id"], data)
|
||||
try:
|
||||
connection.send_message(JSON_DUMP(response))
|
||||
return
|
||||
serialized_states = [
|
||||
state.as_compressed_state_json()
|
||||
for state in states
|
||||
if not entity_ids or state.entity_id in entity_ids
|
||||
]
|
||||
except (ValueError, TypeError):
|
||||
connection.logger.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(response, dump=JSON_DUMP)
|
||||
),
|
||||
)
|
||||
del response
|
||||
pass
|
||||
else:
|
||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||
return
|
||||
|
||||
add_entities = data[messages.ENTITY_EVENT_ADD]
|
||||
cannot_serialize: list[str] = []
|
||||
for entity_id, state_dict in add_entities.items():
|
||||
serialized_states = []
|
||||
for state in states:
|
||||
try:
|
||||
JSON_DUMP(state_dict)
|
||||
serialized_states.append(state.as_compressed_state_json())
|
||||
except (ValueError, TypeError):
|
||||
cannot_serialize.append(entity_id)
|
||||
connection.logger.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(state, dump=JSON_DUMP)
|
||||
),
|
||||
)
|
||||
|
||||
for entity_id in cannot_serialize:
|
||||
del add_entities[entity_id]
|
||||
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||
|
||||
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))
|
||||
|
||||
def _send_handle_entities_init_response(
|
||||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||
) -> None:
|
||||
"""Send handle entities init response."""
|
||||
connection.send_message(
|
||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
|
||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||
).replace(
|
||||
_STATES_JSON_TEMPLATE,
|
||||
"{" + ",".join(serialized_states) + "}",
|
||||
1,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||
|
|
|
@ -44,7 +44,7 @@ ENTITY_EVENT_REMOVE = "r"
|
|||
ENTITY_EVENT_CHANGE = "c"
|
||||
|
||||
|
||||
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
||||
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
|
||||
"""Return a success result message."""
|
||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||
|
||||
|
|
|
@ -80,6 +80,7 @@ from .exceptions import (
|
|||
Unauthorized,
|
||||
)
|
||||
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
|
||||
from .helpers.json import json_dumps
|
||||
from .util import dt as dt_util, location, ulid as ulid_util
|
||||
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
|
||||
from .util.read_only_dict import ReadOnlyDict
|
||||
|
@ -1224,6 +1225,8 @@ class State:
|
|||
"object_id",
|
||||
"_as_dict",
|
||||
"_as_compressed_state",
|
||||
"_as_dict_json",
|
||||
"_as_compressed_state_json",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -1260,6 +1263,8 @@ class State:
|
|||
self.domain, self.object_id = split_entity_id(self.entity_id)
|
||||
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
|
||||
self._as_compressed_state: dict[str, Any] | None = None
|
||||
self._as_dict_json: str | None = None
|
||||
self._as_compressed_state_json: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -1294,6 +1299,12 @@ class State:
|
|||
)
|
||||
return self._as_dict
|
||||
|
||||
def as_dict_json(self) -> str:
|
||||
"""Return a JSON string of the State."""
|
||||
if not self._as_dict_json:
|
||||
self._as_dict_json = json_dumps(self.as_dict())
|
||||
return self._as_dict_json
|
||||
|
||||
def as_compressed_state(self) -> dict[str, Any]:
|
||||
"""Build a compressed dict of a state for adds.
|
||||
|
||||
|
@ -1321,6 +1332,19 @@ class State:
|
|||
self._as_compressed_state = compressed_state
|
||||
return compressed_state
|
||||
|
||||
def as_compressed_state_json(self) -> str:
|
||||
"""Build a compressed JSON key value pair of a state for adds.
|
||||
|
||||
The JSON string is a key value pair of the entity_id and the compressed state.
|
||||
|
||||
It is used for sending multiple states in a single message.
|
||||
"""
|
||||
if not self._as_compressed_state_json:
|
||||
self._as_compressed_state_json = json_dumps(
|
||||
{self.entity_id: self.as_compressed_state()}
|
||||
)[1:-1]
|
||||
return self._as_compressed_state_json
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
|
||||
"""Initialize a state from a dict.
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any, Final
|
|||
|
||||
import orjson
|
||||
|
||||
from homeassistant.core import Event, State
|
||||
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
|
||||
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
|
||||
JSON_DECODE_EXCEPTIONS,
|
||||
|
@ -189,6 +188,11 @@ def find_paths_unserializable_data(
|
|||
|
||||
This method is slow! Only use for error handling.
|
||||
"""
|
||||
from homeassistant.core import ( # pylint: disable=import-outside-toplevel
|
||||
Event,
|
||||
State,
|
||||
)
|
||||
|
||||
to_process = deque([(bad_data, "$")])
|
||||
invalid = {}
|
||||
|
||||
|
|
|
@ -188,10 +188,9 @@ async def test_non_json_message(
|
|||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == []
|
||||
assert (
|
||||
f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(<class 'object'>"
|
||||
in caplog.text
|
||||
)
|
||||
assert "Unable to serialize to JSON. Bad data found" in caplog.text
|
||||
assert "State: test_domain.entity" in caplog.text
|
||||
assert "bad=<object" in caplog.text
|
||||
|
||||
|
||||
async def test_prepare_fail(
|
||||
|
|
|
@ -466,6 +466,29 @@ def test_state_as_dict() -> None:
|
|||
assert state.as_dict() is as_dict_1
|
||||
|
||||
|
||||
def test_state_as_dict_json() -> None:
|
||||
"""Test a State as JSON."""
|
||||
last_time = datetime(1984, 12, 8, 12, 0, 0)
|
||||
state = ha.State(
|
||||
"happy.happy",
|
||||
"on",
|
||||
{"pig": "dog"},
|
||||
last_updated=last_time,
|
||||
last_changed=last_time,
|
||||
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
|
||||
)
|
||||
expected = (
|
||||
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
|
||||
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
|
||||
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
|
||||
)
|
||||
as_dict_json_1 = state.as_dict_json()
|
||||
assert as_dict_json_1 == expected
|
||||
# 2nd time to verify cache
|
||||
assert state.as_dict_json() == expected
|
||||
assert state.as_dict_json() is as_dict_json_1
|
||||
|
||||
|
||||
def test_state_as_compressed_state() -> None:
|
||||
"""Test a State as compressed state."""
|
||||
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
|
||||
|
@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None:
|
|||
assert state.as_compressed_state() is as_compressed_state
|
||||
|
||||
|
||||
def test_state_as_compressed_state_json() -> None:
|
||||
"""Test a State as a JSON compressed state."""
|
||||
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
|
||||
state = ha.State(
|
||||
"happy.happy",
|
||||
"on",
|
||||
{"pig": "dog"},
|
||||
last_updated=last_time,
|
||||
last_changed=last_time,
|
||||
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
|
||||
)
|
||||
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
|
||||
as_compressed_state = state.as_compressed_state_json()
|
||||
# We are not too concerned about these being ReadOnlyDict
|
||||
# since we don't expect them to be called by external callers
|
||||
assert as_compressed_state == expected
|
||||
# 2nd time to verify cache
|
||||
assert state.as_compressed_state_json() == expected
|
||||
assert state.as_compressed_state_json() is as_compressed_state
|
||||
|
||||
|
||||
async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
|
||||
"""Test remove_listener method."""
|
||||
old_count = len(hass.bus.async_listeners())
|
||||
|
|
Loading…
Reference in New Issue