284 lines
8.6 KiB
Python
284 lines
8.6 KiB
Python
"""Helpers to help with encoding Home Assistant objects in JSON."""
|
|
|
|
from collections import deque
|
|
from collections.abc import Callable
|
|
import datetime
|
|
from functools import partial
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Final
|
|
|
|
import orjson
|
|
|
|
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
|
|
from homeassistant.util.json import ( # noqa: F401
|
|
JSON_DECODE_EXCEPTIONS as _JSON_DECODE_EXCEPTIONS,
|
|
JSON_ENCODE_EXCEPTIONS as _JSON_ENCODE_EXCEPTIONS,
|
|
SerializationError,
|
|
format_unserializable_data,
|
|
json_loads as _json_loads,
|
|
)
|
|
|
|
from .deprecation import (
|
|
DeprecatedConstant,
|
|
all_with_deprecated_constants,
|
|
check_if_deprecated_constant,
|
|
deprecated_function,
|
|
dir_with_deprecated_constants,
|
|
)
|
|
|
|
# Deprecated re-exports: these names now live in homeassistant.util.json.
# Each DeprecatedConstant records the value, the replacement name to use and
# the "2025.8" version string. Presumably check_if_deprecated_constant (wired
# up as __getattr__ below) resolves e.g. `JSON_DECODE_EXCEPTIONS` through the
# `_DEPRECATED_` prefix — confirm in homeassistant.helpers.deprecation.
_DEPRECATED_JSON_DECODE_EXCEPTIONS = DeprecatedConstant(
    _JSON_DECODE_EXCEPTIONS, "homeassistant.util.json.JSON_DECODE_EXCEPTIONS", "2025.8"
)
_DEPRECATED_JSON_ENCODE_EXCEPTIONS = DeprecatedConstant(
    _JSON_ENCODE_EXCEPTIONS, "homeassistant.util.json.JSON_ENCODE_EXCEPTIONS", "2025.8"
)
# json_loads remains importable from this module, but it is wrapped by
# deprecated_function, which points callers at
# homeassistant.util.json.json_loads (presumably emitting a deprecation
# warning on use — TODO confirm).
json_loads = deprecated_function(
    "homeassistant.util.json.json_loads", breaks_in_ha_version="2025.8"
)(_json_loads)

# These can be removed once no deprecated constants remain in this module.
# A module-level __getattr__ (PEP 562) is only consulted for names that are
# not found in the module's globals.
__getattr__ = partial(check_if_deprecated_constant, module_globals=globals())
# NOTE(review): `[*globals().keys()]` is evaluated right here, so names that
# are defined later in this module (JSONEncoder, json_dumps, save_json, ...)
# are not in the list handed to dir_with_deprecated_constants. The same
# applies to __all__ below if all_with_deprecated_constants reads globals()
# eagerly (not visible here). Consider moving these three assignments to the
# end of the module.
__dir__ = partial(
    dir_with_deprecated_constants, module_globals_keys=[*globals().keys()]
)
__all__ = all_with_deprecated_constants(globals())


# Module logger; save_json uses it to report serialization failures.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class JSONEncoder(json.JSONEncoder):
    """Standard-library JSON encoder extended for Home Assistant objects."""

    def default(self, o: Any) -> Any:
        """Return a JSON-compatible stand-in for ``o``.

        Sets become lists, datetimes become ISO 8601 strings, and objects that
        expose ``as_dict`` are replaced by its result. Anything else is handed
        to ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(o, set):
            return list(o)
        if isinstance(o, datetime.datetime):
            return o.isoformat()
        if not hasattr(o, "as_dict"):
            # Deliberately call the stdlib implementation directly rather than
            # going through super(), matching the original dispatch.
            return json.JSONEncoder.default(self, o)
        return o.as_dict()
|
|
|
|
|
|
def json_encoder_default(obj: Any) -> Any:
    """Convert Home Assistant objects into values orjson can serialize.

    This is the ``default`` hook passed to ``orjson.dumps``; orjson calls it
    for objects it cannot serialize natively.

    Raises:
        TypeError: if ``obj`` is not one of the supported types. The message
            names the offending type, matching the standard library's
            ``json.JSONEncoder.default``.

    """
    if hasattr(obj, "json_fragment"):
        # Objects carrying a pre-built fragment hand it back unchanged.
        return obj.json_fragment
    if isinstance(obj, (set, tuple)):
        return list(obj)
    if isinstance(obj, float):
        # A float subclass reaching this hook is coerced to a plain float.
        return float(obj)
    if hasattr(obj, "as_dict"):
        return obj.as_dict()
    if isinstance(obj, Path):
        return obj.as_posix()
    if isinstance(obj, datetime.datetime):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
|
|
|
|
|
if TYPE_CHECKING:
    # Static type checkers see an ordinary function signature; at runtime
    # the partial in the else branch is used instead.

    def json_bytes(obj: Any) -> bytes:
        """Dump json bytes."""

else:
    # orjson.dumps with Home Assistant defaults bound: non-string dict keys
    # are accepted and unsupported objects go through json_encoder_default.
    json_bytes = partial(
        orjson.dumps, default=json_encoder_default, option=orjson.OPT_NON_STR_KEYS
    )
    """Dump json bytes."""
|
|
|
|
|
|
class ExtendedJSONEncoder(JSONEncoder):
    """JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""

    def default(self, o: Any) -> Any:
        """Convert ``o`` into something JSON can represent.

        Datetimes become ISO 8601 strings. Timedeltas, dates and times become
        dicts tagged with their type under ``"__type"``. Anything the parent
        encoder rejects with ``TypeError`` becomes a tagged dict holding its
        ``repr``.
        """
        if isinstance(o, datetime.datetime):
            return o.isoformat()
        if isinstance(o, datetime.timedelta):
            seconds = o.total_seconds()
            return {"__type": str(type(o)), "total_seconds": seconds}
        # datetime is a subclass of date, so this check must stay after the
        # datetime branch above.
        if isinstance(o, (datetime.date, datetime.time)):
            return {"__type": str(type(o)), "isoformat": o.isoformat()}
        try:
            converted = super().default(o)
        except TypeError:
            converted = {"__type": str(type(o)), "repr": repr(o)}
        return converted
|
|
|
|
|
|
def _strip_null(obj: Any) -> Any:
|
|
"""Strip NUL from an object."""
|
|
if isinstance(obj, str):
|
|
return obj.split("\0", 1)[0]
|
|
if isinstance(obj, dict):
|
|
return {key: _strip_null(o) for key, o in obj.items()}
|
|
if isinstance(obj, list):
|
|
return [_strip_null(o) for o in obj]
|
|
return obj
|
|
|
|
|
|
def json_bytes_strip_null(data: Any) -> bytes:
    """Encode ``data`` to JSON bytes, truncating strings at their first NUL.

    Only string values (top-level, in dict values and in list items) are
    truncated; dict keys are left as they are.
    """
    encoded = json_bytes(data)
    # NUL characters are expected to be rare, so encode optimistically and
    # only take the slow path when an escaped NUL shows up in the output.
    if b"\\u0000" in encoded:
        # Strip the decoded output rather than ``data`` itself: the round
        # trip has already turned sets, tuples and other Home Assistant
        # extensions into plain JSON types.
        encoded = json_bytes(_strip_null(orjson.loads(encoded)))
    return encoded
|
|
|
|
|
|
# Re-export of orjson.Fragment. Per orjson's documentation, a Fragment wraps
# already-serialized JSON that orjson.dumps inserts without re-encoding;
# json_encoder_default returns an object's ``json_fragment`` attribute as-is.
json_fragment = orjson.Fragment
|
|
|
|
|
|
def json_dumps(data: Any) -> str:
    """Return ``data`` serialized to a JSON string.

    orjson serializes dataclasses natively, so data that is already held in
    a dataclass does not need an ``as_dict`` method, as long as every field
    can itself be serialized.

    If that ever becomes a problem, add ``orjson.OPT_PASSTHROUGH_DATACLASS``
    to the options used by ``json_bytes``. orjson will then hand dataclasses
    to ``json_encoder_default``, which uses ``as_dict``.
    """
    encoded: bytes = json_bytes(data)
    return str(encoded, "utf-8")
|
|
|
|
|
|
# Same as json_bytes, but object keys are sorted, so equal dicts encode to
# identical bytes regardless of insertion order.
json_bytes_sorted = partial(
    orjson.dumps,
    default=json_encoder_default,
    option=orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
"""Dump json bytes with keys sorted."""
|
|
|
|
|
|
def json_dumps_sorted(data: Any) -> str:
    """Return ``data`` as a JSON string with object keys sorted."""
    encoded = json_bytes_sorted(data)
    return str(encoded, "utf-8")
|
|
|
|
|
|
# Constant alias of json_dumps; presumably kept for callers that pass a dump
# function around by name — TODO confirm remaining usages before removing.
JSON_DUMP: Final = json_dumps
|
|
|
|
|
|
def _orjson_default_encoder(data: Any) -> str:
    """Text counterpart of ``_orjson_bytes_default_encoder``.

    Encodes with the same orjson options and returns the result decoded as
    UTF-8.
    """
    raw = _orjson_bytes_default_encoder(data)
    return raw.decode("utf-8")
|
|
|
|
|
|
def _orjson_bytes_default_encoder(data: Any) -> bytes:
    """Encode ``data`` with orjson and the Home Assistant defaults, as bytes.

    Output is indented by two spaces (save_json writes it to disk as-is),
    non-string dict keys are allowed, and unsupported objects are routed
    through ``json_encoder_default``.
    """
    options = orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2
    return orjson.dumps(data, default=json_encoder_default, option=options)
|
|
|
|
|
|
def save_json(
    filename: str,
    data: list | dict,
    private: bool = False,
    *,
    encoder: type[json.JSONEncoder] | None = None,
    atomic_writes: bool = False,
) -> None:
    """Save JSON data to a file.

    Args:
        filename: Path of the file to write.
        data: The list or dict to serialize.
        private: Passed through to the file writer (presumably restricts the
            file's permissions — confirm in homeassistant.util.file).
        encoder: Optional ``json.JSONEncoder`` subclass. When omitted or set
            to ``JSONEncoder``, the orjson-based equivalent is used instead.
        atomic_writes: Write via ``write_utf8_file_atomic`` instead of
            ``write_utf8_file``.

    Raises:
        SerializationError: if ``data`` cannot be serialized. The message
            lists the paths of the offending values.

    """
    dump: Callable[[Any], Any]
    try:
        # For backwards compatibility, if they pass in the
        # default json encoder we use _orjson_default_encoder
        # which is the orjson equivalent to the default encoder.
        if encoder and encoder is not JSONEncoder:
            # If they pass a custom encoder that is not the
            # default JSONEncoder, we use the slow path of json.dumps
            mode = "w"
            # The path search below must probe with the same encoder;
            # otherwise values this encoder handles fine would be reported
            # as bad data.
            dump = partial(json.dumps, cls=encoder)
            json_data: str | bytes = json.dumps(data, indent=2, cls=encoder)
        else:
            mode = "wb"
            dump = _orjson_default_encoder
            json_data = _orjson_bytes_default_encoder(data)
    except TypeError as error:
        formatted_data = format_unserializable_data(
            find_paths_unserializable_data(data, dump=dump)
        )
        msg = f"Failed to serialize to JSON: {filename}. Bad data at {formatted_data}"
        _LOGGER.error(msg)
        raise SerializationError(msg) from error

    method = write_utf8_file_atomic if atomic_writes else write_utf8_file
    method(filename, json_data, private, mode=mode)
|
|
|
|
|
|
def _describe_as_dict_object(obj: Any) -> str:
    """Return the path label for an object that is expanded via ``as_dict``."""
    # Imported lazily rather than at module level (presumably to avoid an
    # import cycle with homeassistant.core — TODO confirm).
    from homeassistant.core import (  # pylint: disable=import-outside-toplevel
        Event,
        State,
    )

    desc = obj.__class__.__name__
    if isinstance(obj, State):
        desc += f": {obj.entity_id}"
    elif isinstance(obj, Event):
        desc += f": {obj.event_type}"
    return desc


def find_paths_unserializable_data(
    bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
) -> dict[str, Any]:
    """Find the paths to unserializable data.

    This method is slow! Only use for error handling.

    Walks ``bad_data`` breadth-first and uses ``dump`` to probe each value.
    Only values that fail to dump are descended into.

    Returns:
        A mapping from location to the offending value. Locations use this
        notation:

        * ``$`` is the root.
        * ``.key`` is a dict value; ``<key: k>`` is an unserializable dict key.
        * ``[i]`` is a list item.
        * ``(Name)`` marks an object expanded via ``as_dict``.
        * ``<circular reference>`` marks a container reached again from
          inside itself.

    """
    # Each entry: (object, its path, containers on the path leading to it).
    # Ancestors are held by reference and compared by identity, so a
    # self-referencing structure is reported instead of looping forever.
    to_process: deque[tuple[Any, str, tuple[Any, ...]]] = deque(
        [(bad_data, "$", ())]
    )
    invalid: dict[str, Any] = {}

    while to_process:
        obj, obj_path, ancestors = to_process.popleft()

        try:
            dump(obj)
            continue
        except (ValueError, TypeError):
            pass

        if any(obj is ancestor for ancestor in ancestors):
            invalid[f"{obj_path}<circular reference>"] = obj
            continue

        child_ancestors = (*ancestors, obj)

        # We convert objects with as_dict to their dict values
        # so we can find bad data inside it
        if hasattr(obj, "as_dict"):
            obj_path += f"({_describe_as_dict_object(obj)})"
            obj = obj.as_dict()
            child_ancestors = (*child_ancestors, obj)

        if isinstance(obj, dict):
            for key, value in obj.items():
                try:
                    # Is key valid?
                    dump({key: None})
                except TypeError:
                    invalid[f"{obj_path}<key: {key}>"] = key
                else:
                    # Process value
                    to_process.append((value, f"{obj_path}.{key}", child_ancestors))
        elif isinstance(obj, list):
            for idx, value in enumerate(obj):
                to_process.append((value, f"{obj_path}[{idx}]", child_ancestors))
        else:
            invalid[obj_path] = obj

    return invalid
|