From 9a27f1437d84e1ab87a8089f910fc14e384581c9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 17 Jul 2022 07:25:19 -0500 Subject: [PATCH] Use default encoder when saving storage (#75319) --- homeassistant/util/json.py | 27 ++++++-------------- tests/helpers/test_storage.py | 48 ++++++++++++++++++++++++++++++++++- tests/util/test_json.py | 23 +++++------------ 3 files changed, 62 insertions(+), 36 deletions(-) diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index 68273c89743..1413f6d9b15 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -49,13 +49,6 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict: return {} if default is None else default -def _orjson_encoder(data: Any) -> str: - """JSON encoder that uses orjson.""" - return orjson.dumps( - data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS - ).decode("utf-8") - - def _orjson_default_encoder(data: Any) -> str: """JSON encoder that uses orjson with hass defaults.""" return orjson.dumps( @@ -79,21 +72,17 @@ def save_json( """ dump: Callable[[Any], Any] try: - if encoder: - # For backwards compatibility, if they pass in the - # default json encoder we use _orjson_default_encoder - # which is the orjson equivalent to the default encoder. - if encoder is DefaultHASSJSONEncoder: - dump = _orjson_default_encoder - json_data = _orjson_default_encoder(data) + # For backwards compatibility, if they pass in the + # default json encoder we use _orjson_default_encoder + # which is the orjson equivalent to the default encoder. + if encoder and encoder is not DefaultHASSJSONEncoder: # If they pass a custom encoder that is not the # DefaultHASSJSONEncoder, we use the slow path of json.dumps - else: - dump = json.dumps - json_data = json.dumps(data, indent=2, cls=encoder) + dump = json.dumps + json_data = json.dumps(data, indent=2, cls=encoder) else: - dump = _orjson_encoder - json_data = _orjson_encoder(data) + dump = _orjson_default_encoder + json_data = _orjson_default_encoder(data) except TypeError as error: msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}" _LOGGER.error(msg) diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index 53c1b8a4677..ca5cb92bfd5 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -2,6 +2,7 @@ import asyncio from datetime import timedelta import json +from typing import NamedTuple from unittest.mock import Mock, patch import pytest @@ -13,8 +14,9 @@ from homeassistant.const import ( from homeassistant.core import CoreState from homeassistant.helpers import storage from homeassistant.util import dt +from homeassistant.util.color import RGBColor -from tests.common import async_fire_time_changed +from tests.common import async_fire_time_changed, async_test_home_assistant MOCK_VERSION = 1 MOCK_VERSION_2 = 2 @@ -460,3 +462,47 @@ async def test_changing_delayed_written_data(hass, store, hass_storage): "key": MOCK_KEY, "data": {"hello": "world"}, } + + +async def test_saving_load_round_trip(tmpdir): + """Test saving and loading round trip.""" + loop = asyncio.get_running_loop() + hass = await async_test_home_assistant(loop) + + hass.config.config_dir = await hass.async_add_executor_job( + tmpdir.mkdir, "temp_storage" + ) + + class NamedTupleSubclass(NamedTuple): + """A NamedTuple subclass.""" + + name: str + + nts = NamedTupleSubclass("a") + + data = { + "named_tuple_subclass": nts, + "rgb_color": RGBColor(255, 255, 0), + "set": {1, 2, 3}, + "list": [1, 2, 3], + "tuple": (1, 2, 3), + "dict_with_int": {1: 1, 2: 2}, + "dict_with_named_tuple": {1: nts, 2: nts}, + } + + store = storage.Store( + hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1 + ) + await store.async_save(data) + load = await store.async_load() + assert load == { + "dict_with_int": {"1": 1, "2": 2}, + "dict_with_named_tuple": {"1": ["a"], "2": ["a"]}, + "list": [1, 2, 3], + "named_tuple_subclass": ["a"], + "rgb_color": [255, 255, 0], + "set": [1, 2, 3], + "tuple": [1, 2, 3], + } + + await hass.async_stop(force=True) diff --git a/tests/util/test_json.py b/tests/util/test_json.py index 28d321036c5..509c0376fae 100644 --- a/tests/util/test_json.py +++ b/tests/util/test_json.py @@ -12,7 +12,6 @@ import pytest from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder -from homeassistant.helpers.template import TupleWrapper from homeassistant.util.json import ( SerializationError, find_paths_unserializable_data, @@ -83,23 +82,15 @@ def test_overwrite_and_reload(atomic_writes): def test_save_bad_data(): """Test error from trying to save unserializable data.""" + + class CannotSerializeMe: + """Cannot serialize this.""" + with pytest.raises(SerializationError) as excinfo: - save_json("test4", {"hello": set()}) + save_json("test4", {"hello": CannotSerializeMe()}) - assert ( - "Failed to serialize to JSON: test4. Bad data at $.hello=set()(" - in str(excinfo.value) - ) - - -def test_save_bad_data_tuple_wrapper(): - """Test error from trying to save unserializable data.""" - with pytest.raises(SerializationError) as excinfo: - save_json("test4", {"hello": TupleWrapper(("4", "5"))}) - - assert ( - "Failed to serialize to JSON: test4. Bad data at $.hello=('4', '5')(" - in str(excinfo.value) + assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str( + excinfo.value )