Use the orjson equivalent default encoder when save_json is passed the default encoder (#74377)

pull/75528/head
J. Nick Koston 2022-07-04 08:41:23 -05:00 committed by Franck Nijhof
parent 97b6912856
commit 7f43064f36
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
2 changed files with 42 additions and 3 deletions

View File

@ -11,6 +11,10 @@ import orjson
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.json import (
JSONEncoder as DefaultHASSJSONEncoder,
json_encoder_default as default_hass_orjson_encoder,
)
from .file import write_utf8_file, write_utf8_file_atomic
@ -52,6 +56,15 @@ def _orjson_encoder(data: Any) -> str:
).decode("utf-8")
def _orjson_default_encoder(data: Any) -> str:
"""JSON encoder that uses orjson with hass defaults."""
return orjson.dumps(
data,
option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS,
default=default_hass_orjson_encoder,
).decode("utf-8")
def save_json(
filename: str,
data: list | dict,
@ -64,10 +77,20 @@ def save_json(
Returns True on success.
"""
dump: Callable[[Any], Any] = json.dumps
dump: Callable[[Any], Any]
try:
if encoder:
json_data = json.dumps(data, indent=2, cls=encoder)
# For backwards compatibility, if they pass in the
# default json encoder we use _orjson_default_encoder
# which is the orjson equivalent to the default encoder.
if encoder is DefaultHASSJSONEncoder:
dump = _orjson_default_encoder
json_data = _orjson_default_encoder(data)
# If they pass a custom encoder that is not the
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
else:
dump = json.dumps
json_data = json.dumps(data, indent=2, cls=encoder)
else:
dump = _orjson_encoder
json_data = _orjson_encoder(data)

View File

@ -5,12 +5,13 @@ from json import JSONEncoder, dumps
import math
import os
from tempfile import mkdtemp
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
from homeassistant.helpers.template import TupleWrapper
from homeassistant.util.json import (
SerializationError,
@ -127,6 +128,21 @@ def test_custom_encoder():
assert data == "9"
def test_default_encoder_is_passed():
"""Test we use orjson if they pass in the default encoder."""
fname = _path_for("test6")
with patch(
"homeassistant.util.json.orjson.dumps", return_value=b"{}"
) as mock_orjson_dumps:
save_json(fname, {"any": 1}, encoder=DefaultHASSJSONEncoder)
assert len(mock_orjson_dumps.mock_calls) == 1
# Patch json.dumps to make sure we are using the orjson path
with patch("homeassistant.util.json.json.dumps", side_effect=Exception):
save_json(fname, {"any": {1}}, encoder=DefaultHASSJSONEncoder)
data = load_json(fname)
assert data == {"any": [1]}
def test_find_unserializable_data():
"""Find unserializeable data."""
assert find_paths_unserializable_data(1) == {}