Add custom JSONEncoder for subscribe_trigger WS endpoint (#48664)

pull/48982/head
Jason 2021-04-09 20:47:10 -07:00 committed by GitHub
parent 324dd12db8
commit 7cc857a298
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 69 deletions

View File

@ -1,9 +1,5 @@
"""Helpers for script and automation tracing and debugging."""
from collections import OrderedDict
from datetime import timedelta
from typing import Any
from homeassistant.helpers.json import JSONEncoder as HAJSONEncoder
class LimitedSizeDict(OrderedDict):
@ -25,19 +21,3 @@ class LimitedSizeDict(OrderedDict):
if self.size_limit is not None:
while len(self) > self.size_limit:
self.popitem(last=False)
class TraceJSONEncoder(HAJSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
def default(self, o: Any) -> Any:
"""Convert certain objects.
Fall back to repr(o).
"""
if isinstance(o, timedelta):
return {"__type": str(type(o)), "total_seconds": o.total_seconds()}
try:
return super().default(o)
except TypeError:
return {"__type": str(type(o)), "repr": repr(o)}

View File

@ -11,6 +11,7 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.script import (
SCRIPT_BREAKPOINT_HIT,
SCRIPT_DEBUG_CONTINUE_ALL,
@ -24,7 +25,6 @@ from homeassistant.helpers.script import (
)
from .const import DATA_TRACE
from .utils import TraceJSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs
@ -71,7 +71,9 @@ def websocket_trace_get(hass, connection, msg):
message = websocket_api.messages.result_message(msg["id"], trace)
connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
connection.send_message(
json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False)
)
def get_debug_traces(hass, key):

View File

@ -1,5 +1,6 @@
"""Commands part of Websocket API."""
import asyncio
import json
import voluptuous as vol
@ -17,6 +18,7 @@ from homeassistant.exceptions import (
from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
@ -417,10 +419,11 @@ async def handle_subscribe_trigger(hass, connection, msg):
@callback
def forward_triggers(variables, context=None):
"""Forward events to websocket."""
message = messages.event_message(
msg["id"], {"variables": variables, "context": context}
)
connection.send_message(
messages.event_message(
msg["id"], {"variables": variables, "context": context}
)
json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False)
)
connection.subscriptions[msg["id"]] = (

View File

@ -1,5 +1,5 @@
"""Helpers to help with encoding Home Assistant objects in JSON."""
from datetime import datetime
from datetime import datetime, timedelta
import json
from typing import Any
@ -20,3 +20,19 @@ class JSONEncoder(json.JSONEncoder):
return o.as_dict()
return json.JSONEncoder.default(self, o)
class ExtendedJSONEncoder(JSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
def default(self, o: Any) -> Any:
"""Convert certain objects.
Fall back to repr(o).
"""
if isinstance(o, timedelta):
return {"__type": str(type(o)), "total_seconds": o.total_seconds()}
try:
return super().default(o)
except TypeError:
return {"__type": str(type(o)), "repr": repr(o)}

View File

@ -1,42 +0,0 @@
"""Test trace helpers."""
from datetime import timedelta
from homeassistant import core
from homeassistant.components import trace
from homeassistant.util import dt as dt_util
def test_json_encoder(hass):
"""Test the Trace JSON Encoder."""
ha_json_enc = trace.utils.TraceJSONEncoder()
state = core.State("test.test", "hello")
# Test serializing a datetime
now = dt_util.utcnow()
assert ha_json_enc.default(now) == now.isoformat()
# Test serializing a timedelta
data = timedelta(
days=50,
seconds=27,
microseconds=10,
milliseconds=29000,
minutes=5,
hours=8,
weeks=2,
)
assert ha_json_enc.default(data) == {
"__type": str(type(data)),
"total_seconds": data.total_seconds(),
}
# Test serializing a set()
data = {"milk", "beer"}
assert sorted(ha_json_enc.default(data)) == sorted(data)
# Test serializong object which implements as_dict
assert ha_json_enc.default(state) == state.as_dict()
# Default method falls back to repr(o)
o = object()
assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}

View File

@ -1,8 +1,10 @@
"""Test Home Assistant remote methods and classes."""
from datetime import timedelta
import pytest
from homeassistant import core
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import ExtendedJSONEncoder, JSONEncoder
from homeassistant.util import dt as dt_util
@ -25,3 +27,39 @@ def test_json_encoder(hass):
# Default method raises TypeError if non HA object
with pytest.raises(TypeError):
ha_json_enc.default(1)
def test_trace_json_encoder(hass):
"""Test the Trace JSON Encoder."""
ha_json_enc = ExtendedJSONEncoder()
state = core.State("test.test", "hello")
# Test serializing a datetime
now = dt_util.utcnow()
assert ha_json_enc.default(now) == now.isoformat()
# Test serializing a timedelta
data = timedelta(
days=50,
seconds=27,
microseconds=10,
milliseconds=29000,
minutes=5,
hours=8,
weeks=2,
)
assert ha_json_enc.default(data) == {
"__type": str(type(data)),
"total_seconds": data.total_seconds(),
}
# Test serializing a set()
data = {"milk", "beer"}
assert sorted(ha_json_enc.default(data)) == sorted(data)
# Test serializong object which implements as_dict
assert ha_json_enc.default(state) == state.as_dict()
# Default method falls back to repr(o)
o = object()
assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}