Use the orjson equivalent default encoder when save_json is passed the default encoder (#74377)
parent
97b6912856
commit
7f43064f36
|
@ -11,6 +11,10 @@ import orjson
|
|||
|
||||
from homeassistant.core import Event, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.json import (
|
||||
JSONEncoder as DefaultHASSJSONEncoder,
|
||||
json_encoder_default as default_hass_orjson_encoder,
|
||||
)
|
||||
|
||||
from .file import write_utf8_file, write_utf8_file_atomic
|
||||
|
||||
|
@ -52,6 +56,15 @@ def _orjson_encoder(data: Any) -> str:
|
|||
).decode("utf-8")
|
||||
|
||||
|
||||
def _orjson_default_encoder(data: Any) -> str:
|
||||
"""JSON encoder that uses orjson with hass defaults."""
|
||||
return orjson.dumps(
|
||||
data,
|
||||
option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS,
|
||||
default=default_hass_orjson_encoder,
|
||||
).decode("utf-8")
|
||||
|
||||
|
||||
def save_json(
|
||||
filename: str,
|
||||
data: list | dict,
|
||||
|
@ -64,10 +77,20 @@ def save_json(
|
|||
|
||||
Returns True on success.
|
||||
"""
|
||||
dump: Callable[[Any], Any] = json.dumps
|
||||
dump: Callable[[Any], Any]
|
||||
try:
|
||||
if encoder:
|
||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||
# For backwards compatibility, if they pass in the
|
||||
# default json encoder we use _orjson_default_encoder
|
||||
# which is the orjson equivalent to the default encoder.
|
||||
if encoder is DefaultHASSJSONEncoder:
|
||||
dump = _orjson_default_encoder
|
||||
json_data = _orjson_default_encoder(data)
|
||||
# If they pass a custom encoder that is not the
|
||||
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
|
||||
else:
|
||||
dump = json.dumps
|
||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||
else:
|
||||
dump = _orjson_encoder
|
||||
json_data = _orjson_encoder(data)
|
||||
|
|
|
@ -5,12 +5,13 @@ from json import JSONEncoder, dumps
|
|||
import math
|
||||
import os
|
||||
from tempfile import mkdtemp
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import Event, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
|
||||
from homeassistant.helpers.template import TupleWrapper
|
||||
from homeassistant.util.json import (
|
||||
SerializationError,
|
||||
|
@ -127,6 +128,21 @@ def test_custom_encoder():
|
|||
assert data == "9"
|
||||
|
||||
|
||||
def test_default_encoder_is_passed():
|
||||
"""Test we use orjson if they pass in the default encoder."""
|
||||
fname = _path_for("test6")
|
||||
with patch(
|
||||
"homeassistant.util.json.orjson.dumps", return_value=b"{}"
|
||||
) as mock_orjson_dumps:
|
||||
save_json(fname, {"any": 1}, encoder=DefaultHASSJSONEncoder)
|
||||
assert len(mock_orjson_dumps.mock_calls) == 1
|
||||
# Patch json.dumps to make sure we are using the orjson path
|
||||
with patch("homeassistant.util.json.json.dumps", side_effect=Exception):
|
||||
save_json(fname, {"any": {1}}, encoder=DefaultHASSJSONEncoder)
|
||||
data = load_json(fname)
|
||||
assert data == {"any": [1]}
|
||||
|
||||
|
||||
def test_find_unserializable_data():
|
||||
"""Find unserializeable data."""
|
||||
assert find_paths_unserializable_data(1) == {}
|
||||
|
|
Loading…
Reference in New Issue