Use default encoder when saving storage (#75319)

pull/75359/head
J. Nick Koston 2022-07-17 07:25:19 -05:00 committed by GitHub
parent 2eebda63fd
commit 9a27f1437d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 36 deletions

View File

@ -49,13 +49,6 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
return {} if default is None else default
def _orjson_encoder(data: Any) -> str:
"""JSON encoder that uses orjson."""
return orjson.dumps(
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
).decode("utf-8")
def _orjson_default_encoder(data: Any) -> str:
"""JSON encoder that uses orjson with hass defaults."""
return orjson.dumps(
@ -79,21 +72,17 @@ def save_json(
"""
dump: Callable[[Any], Any]
try:
if encoder:
# For backwards compatibility, if they pass in the
# default json encoder we use _orjson_default_encoder
# which is the orjson equivalent to the default encoder.
if encoder is DefaultHASSJSONEncoder:
dump = _orjson_default_encoder
json_data = _orjson_default_encoder(data)
# For backwards compatibility, if they pass in the
# default json encoder we use _orjson_default_encoder
# which is the orjson equivalent to the default encoder.
if encoder and encoder is not DefaultHASSJSONEncoder:
# If they pass a custom encoder that is not the
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
else:
dump = json.dumps
json_data = json.dumps(data, indent=2, cls=encoder)
dump = json.dumps
json_data = json.dumps(data, indent=2, cls=encoder)
else:
dump = _orjson_encoder
json_data = _orjson_encoder(data)
dump = _orjson_default_encoder
json_data = _orjson_default_encoder(data)
except TypeError as error:
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}"
_LOGGER.error(msg)

View File

@ -2,6 +2,7 @@
import asyncio
from datetime import timedelta
import json
from typing import NamedTuple
from unittest.mock import Mock, patch
import pytest
@ -13,8 +14,9 @@ from homeassistant.const import (
from homeassistant.core import CoreState
from homeassistant.helpers import storage
from homeassistant.util import dt
from homeassistant.util.color import RGBColor
from tests.common import async_fire_time_changed
from tests.common import async_fire_time_changed, async_test_home_assistant
MOCK_VERSION = 1
MOCK_VERSION_2 = 2
@ -460,3 +462,47 @@ async def test_changing_delayed_written_data(hass, store, hass_storage):
"key": MOCK_KEY,
"data": {"hello": "world"},
}
async def test_saving_load_round_trip(tmpdir):
"""Test saving and loading round trip."""
loop = asyncio.get_running_loop()
hass = await async_test_home_assistant(loop)
hass.config.config_dir = await hass.async_add_executor_job(
tmpdir.mkdir, "temp_storage"
)
class NamedTupleSubclass(NamedTuple):
"""A NamedTuple subclass."""
name: str
nts = NamedTupleSubclass("a")
data = {
"named_tuple_subclass": nts,
"rgb_color": RGBColor(255, 255, 0),
"set": {1, 2, 3},
"list": [1, 2, 3],
"tuple": (1, 2, 3),
"dict_with_int": {1: 1, 2: 2},
"dict_with_named_tuple": {1: nts, 2: nts},
}
store = storage.Store(
hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
)
await store.async_save(data)
load = await store.async_load()
assert load == {
"dict_with_int": {"1": 1, "2": 2},
"dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
"list": [1, 2, 3],
"named_tuple_subclass": ["a"],
"rgb_color": [255, 255, 0],
"set": [1, 2, 3],
"tuple": [1, 2, 3],
}
await hass.async_stop(force=True)

View File

@ -12,7 +12,6 @@ import pytest
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
from homeassistant.helpers.template import TupleWrapper
from homeassistant.util.json import (
SerializationError,
find_paths_unserializable_data,
@ -83,23 +82,15 @@ def test_overwrite_and_reload(atomic_writes):
def test_save_bad_data():
"""Test error from trying to save unserializable data."""
class CannotSerializeMe:
"""Cannot serialize this."""
with pytest.raises(SerializationError) as excinfo:
save_json("test4", {"hello": set()})
save_json("test4", {"hello": CannotSerializeMe()})
assert (
"Failed to serialize to JSON: test4. Bad data at $.hello=set()(<class 'set'>"
in str(excinfo.value)
)
def test_save_bad_data_tuple_wrapper():
"""Test error from trying to save unserializable data."""
with pytest.raises(SerializationError) as excinfo:
save_json("test4", {"hello": TupleWrapper(("4", "5"))})
assert (
"Failed to serialize to JSON: test4. Bad data at $.hello=('4', '5')(<class 'homeassistant.helpers.template.TupleWrapper'>"
in str(excinfo.value)
assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
excinfo.value
)