Add custom JSONEncoder for subscribe_trigger WS endpoint (#48664)
parent
324dd12db8
commit
7cc857a298
|
@ -1,9 +1,5 @@
|
|||
"""Helpers for script and automation tracing and debugging."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.helpers.json import JSONEncoder as HAJSONEncoder
|
||||
|
||||
|
||||
class LimitedSizeDict(OrderedDict):
|
||||
|
@ -25,19 +21,3 @@ class LimitedSizeDict(OrderedDict):
|
|||
if self.size_limit is not None:
|
||||
while len(self) > self.size_limit:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
class TraceJSONEncoder(HAJSONEncoder):
|
||||
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
|
||||
|
||||
def default(self, o: Any) -> Any:
|
||||
"""Convert certain objects.
|
||||
|
||||
Fall back to repr(o).
|
||||
"""
|
||||
if isinstance(o, timedelta):
|
||||
return {"__type": str(type(o)), "total_seconds": o.total_seconds()}
|
||||
try:
|
||||
return super().default(o)
|
||||
except TypeError:
|
||||
return {"__type": str(type(o)), "repr": repr(o)}
|
||||
|
|
|
@ -11,6 +11,7 @@ from homeassistant.helpers.dispatcher import (
|
|||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||
from homeassistant.helpers.script import (
|
||||
SCRIPT_BREAKPOINT_HIT,
|
||||
SCRIPT_DEBUG_CONTINUE_ALL,
|
||||
|
@ -24,7 +25,6 @@ from homeassistant.helpers.script import (
|
|||
)
|
||||
|
||||
from .const import DATA_TRACE
|
||||
from .utils import TraceJSONEncoder
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
|
@ -71,7 +71,9 @@ def websocket_trace_get(hass, connection, msg):
|
|||
|
||||
message = websocket_api.messages.result_message(msg["id"], trace)
|
||||
|
||||
connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
|
||||
connection.send_message(
|
||||
json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False)
|
||||
)
|
||||
|
||||
|
||||
def get_debug_traces(hass, key):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Commands part of Websocket API."""
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -17,6 +18,7 @@ from homeassistant.exceptions import (
|
|||
from homeassistant.helpers import config_validation as cv, entity, template
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
|
||||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
|
||||
|
@ -417,10 +419,11 @@ async def handle_subscribe_trigger(hass, connection, msg):
|
|||
@callback
|
||||
def forward_triggers(variables, context=None):
|
||||
"""Forward events to websocket."""
|
||||
message = messages.event_message(
|
||||
msg["id"], {"variables": variables, "context": context}
|
||||
)
|
||||
connection.send_message(
|
||||
messages.event_message(
|
||||
msg["id"], {"variables": variables, "context": context}
|
||||
)
|
||||
json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False)
|
||||
)
|
||||
|
||||
connection.subscriptions[msg["id"]] = (
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Helpers to help with encoding Home Assistant objects in JSON."""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
@ -20,3 +20,19 @@ class JSONEncoder(json.JSONEncoder):
|
|||
return o.as_dict()
|
||||
|
||||
return json.JSONEncoder.default(self, o)
|
||||
|
||||
|
||||
class ExtendedJSONEncoder(JSONEncoder):
|
||||
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
|
||||
|
||||
def default(self, o: Any) -> Any:
|
||||
"""Convert certain objects.
|
||||
|
||||
Fall back to repr(o).
|
||||
"""
|
||||
if isinstance(o, timedelta):
|
||||
return {"__type": str(type(o)), "total_seconds": o.total_seconds()}
|
||||
try:
|
||||
return super().default(o)
|
||||
except TypeError:
|
||||
return {"__type": str(type(o)), "repr": repr(o)}
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
"""Test trace helpers."""
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant import core
|
||||
from homeassistant.components import trace
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
||||
def test_json_encoder(hass):
|
||||
"""Test the Trace JSON Encoder."""
|
||||
ha_json_enc = trace.utils.TraceJSONEncoder()
|
||||
state = core.State("test.test", "hello")
|
||||
|
||||
# Test serializing a datetime
|
||||
now = dt_util.utcnow()
|
||||
assert ha_json_enc.default(now) == now.isoformat()
|
||||
|
||||
# Test serializing a timedelta
|
||||
data = timedelta(
|
||||
days=50,
|
||||
seconds=27,
|
||||
microseconds=10,
|
||||
milliseconds=29000,
|
||||
minutes=5,
|
||||
hours=8,
|
||||
weeks=2,
|
||||
)
|
||||
assert ha_json_enc.default(data) == {
|
||||
"__type": str(type(data)),
|
||||
"total_seconds": data.total_seconds(),
|
||||
}
|
||||
|
||||
# Test serializing a set()
|
||||
data = {"milk", "beer"}
|
||||
assert sorted(ha_json_enc.default(data)) == sorted(data)
|
||||
|
||||
# Test serializong object which implements as_dict
|
||||
assert ha_json_enc.default(state) == state.as_dict()
|
||||
|
||||
# Default method falls back to repr(o)
|
||||
o = object()
|
||||
assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}
|
|
@ -1,8 +1,10 @@
|
|||
"""Test Home Assistant remote methods and classes."""
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import core
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.json import ExtendedJSONEncoder, JSONEncoder
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
||||
|
@ -25,3 +27,39 @@ def test_json_encoder(hass):
|
|||
# Default method raises TypeError if non HA object
|
||||
with pytest.raises(TypeError):
|
||||
ha_json_enc.default(1)
|
||||
|
||||
|
||||
def test_trace_json_encoder(hass):
|
||||
"""Test the Trace JSON Encoder."""
|
||||
ha_json_enc = ExtendedJSONEncoder()
|
||||
state = core.State("test.test", "hello")
|
||||
|
||||
# Test serializing a datetime
|
||||
now = dt_util.utcnow()
|
||||
assert ha_json_enc.default(now) == now.isoformat()
|
||||
|
||||
# Test serializing a timedelta
|
||||
data = timedelta(
|
||||
days=50,
|
||||
seconds=27,
|
||||
microseconds=10,
|
||||
milliseconds=29000,
|
||||
minutes=5,
|
||||
hours=8,
|
||||
weeks=2,
|
||||
)
|
||||
assert ha_json_enc.default(data) == {
|
||||
"__type": str(type(data)),
|
||||
"total_seconds": data.total_seconds(),
|
||||
}
|
||||
|
||||
# Test serializing a set()
|
||||
data = {"milk", "beer"}
|
||||
assert sorted(ha_json_enc.default(data)) == sorted(data)
|
||||
|
||||
# Test serializong object which implements as_dict
|
||||
assert ha_json_enc.default(state) == state.as_dict()
|
||||
|
||||
# Default method falls back to repr(o)
|
||||
o = object()
|
||||
assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}
|
||||
|
|
Loading…
Reference in New Issue