Use default encoder when saving storage (#75319)
parent
2eebda63fd
commit
9a27f1437d
|
@ -49,13 +49,6 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
|
|||
return {} if default is None else default
|
||||
|
||||
|
||||
def _orjson_encoder(data: Any) -> str:
|
||||
"""JSON encoder that uses orjson."""
|
||||
return orjson.dumps(
|
||||
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
|
||||
).decode("utf-8")
|
||||
|
||||
|
||||
def _orjson_default_encoder(data: Any) -> str:
|
||||
"""JSON encoder that uses orjson with hass defaults."""
|
||||
return orjson.dumps(
|
||||
|
@ -79,21 +72,17 @@ def save_json(
|
|||
"""
|
||||
dump: Callable[[Any], Any]
|
||||
try:
|
||||
if encoder:
|
||||
# For backwards compatibility, if they pass in the
|
||||
# default json encoder we use _orjson_default_encoder
|
||||
# which is the orjson equivalent to the default encoder.
|
||||
if encoder is DefaultHASSJSONEncoder:
|
||||
dump = _orjson_default_encoder
|
||||
json_data = _orjson_default_encoder(data)
|
||||
# For backwards compatibility, if they pass in the
|
||||
# default json encoder we use _orjson_default_encoder
|
||||
# which is the orjson equivalent to the default encoder.
|
||||
if encoder and encoder is not DefaultHASSJSONEncoder:
|
||||
# If they pass a custom encoder that is not the
|
||||
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
|
||||
else:
|
||||
dump = json.dumps
|
||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||
dump = json.dumps
|
||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||
else:
|
||||
dump = _orjson_encoder
|
||||
json_data = _orjson_encoder(data)
|
||||
dump = _orjson_default_encoder
|
||||
json_data = _orjson_default_encoder(data)
|
||||
except TypeError as error:
|
||||
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}"
|
||||
_LOGGER.error(msg)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import asyncio
|
||||
from datetime import timedelta
|
||||
import json
|
||||
from typing import NamedTuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -13,8 +14,9 @@ from homeassistant.const import (
|
|||
from homeassistant.core import CoreState
|
||||
from homeassistant.helpers import storage
|
||||
from homeassistant.util import dt
|
||||
from homeassistant.util.color import RGBColor
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.common import async_fire_time_changed, async_test_home_assistant
|
||||
|
||||
MOCK_VERSION = 1
|
||||
MOCK_VERSION_2 = 2
|
||||
|
@ -460,3 +462,47 @@ async def test_changing_delayed_written_data(hass, store, hass_storage):
|
|||
"key": MOCK_KEY,
|
||||
"data": {"hello": "world"},
|
||||
}
|
||||
|
||||
|
||||
async def test_saving_load_round_trip(tmpdir):
|
||||
"""Test saving and loading round trip."""
|
||||
loop = asyncio.get_running_loop()
|
||||
hass = await async_test_home_assistant(loop)
|
||||
|
||||
hass.config.config_dir = await hass.async_add_executor_job(
|
||||
tmpdir.mkdir, "temp_storage"
|
||||
)
|
||||
|
||||
class NamedTupleSubclass(NamedTuple):
|
||||
"""A NamedTuple subclass."""
|
||||
|
||||
name: str
|
||||
|
||||
nts = NamedTupleSubclass("a")
|
||||
|
||||
data = {
|
||||
"named_tuple_subclass": nts,
|
||||
"rgb_color": RGBColor(255, 255, 0),
|
||||
"set": {1, 2, 3},
|
||||
"list": [1, 2, 3],
|
||||
"tuple": (1, 2, 3),
|
||||
"dict_with_int": {1: 1, 2: 2},
|
||||
"dict_with_named_tuple": {1: nts, 2: nts},
|
||||
}
|
||||
|
||||
store = storage.Store(
|
||||
hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
|
||||
)
|
||||
await store.async_save(data)
|
||||
load = await store.async_load()
|
||||
assert load == {
|
||||
"dict_with_int": {"1": 1, "2": 2},
|
||||
"dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
|
||||
"list": [1, 2, 3],
|
||||
"named_tuple_subclass": ["a"],
|
||||
"rgb_color": [255, 255, 0],
|
||||
"set": [1, 2, 3],
|
||||
"tuple": [1, 2, 3],
|
||||
}
|
||||
|
||||
await hass.async_stop(force=True)
|
||||
|
|
|
@ -12,7 +12,6 @@ import pytest
|
|||
from homeassistant.core import Event, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
|
||||
from homeassistant.helpers.template import TupleWrapper
|
||||
from homeassistant.util.json import (
|
||||
SerializationError,
|
||||
find_paths_unserializable_data,
|
||||
|
@ -83,23 +82,15 @@ def test_overwrite_and_reload(atomic_writes):
|
|||
|
||||
def test_save_bad_data():
|
||||
"""Test error from trying to save unserializable data."""
|
||||
|
||||
class CannotSerializeMe:
|
||||
"""Cannot serialize this."""
|
||||
|
||||
with pytest.raises(SerializationError) as excinfo:
|
||||
save_json("test4", {"hello": set()})
|
||||
save_json("test4", {"hello": CannotSerializeMe()})
|
||||
|
||||
assert (
|
||||
"Failed to serialize to JSON: test4. Bad data at $.hello=set()(<class 'set'>"
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
def test_save_bad_data_tuple_wrapper():
|
||||
"""Test error from trying to save unserializable data."""
|
||||
with pytest.raises(SerializationError) as excinfo:
|
||||
save_json("test4", {"hello": TupleWrapper(("4", "5"))})
|
||||
|
||||
assert (
|
||||
"Failed to serialize to JSON: test4. Bad data at $.hello=('4', '5')(<class 'homeassistant.helpers.template.TupleWrapper'>"
|
||||
in str(excinfo.value)
|
||||
assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
|
||||
excinfo.value
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue