Revert "Initial orjson support (#72754)" (#72789)

This was causing the wheels to fail to build. We need
to workout why when we don't have release pressure

This reverts commit d9d22a9556.
pull/72793/head
J. Nick Koston 2022-05-31 10:51:55 -10:00 committed by GitHub
parent 9cea936c22
commit c365454afb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 67 additions and 127 deletions

View File

@ -24,10 +24,10 @@ from homeassistant.components.recorder.statistics import (
) )
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util

View File

@ -14,9 +14,9 @@ from homeassistant.components import websocket_api
from homeassistant.components.recorder import get_instance from homeassistant.components.recorder import get_instance
from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .helpers import ( from .helpers import (

View File

@ -1,11 +1,12 @@
"""Recorder constants.""" """Recorder constants."""
from functools import partial
import json
from typing import Final
from homeassistant.backports.enum import StrEnum from homeassistant.backports.enum import StrEnum
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import from homeassistant.helpers.json import JSONEncoder
JSON_DUMP,
)
DATA_INSTANCE = "recorder_instance" DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://" SQLITE_URL_PREFIX = "sqlite://"
@ -26,6 +27,7 @@ MAX_ROWS_TO_PURGE = 998
DB_WORKER_PREFIX = "DbWorker" DB_WORKER_PREFIX = "DbWorker"
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))
ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES} ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES}

View File

@ -744,12 +744,11 @@ class Recorder(threading.Thread):
return return
try: try:
shared_data_bytes = EventData.shared_data_bytes_from_event(event) shared_data = EventData.shared_data_from_event(event)
except (TypeError, ValueError) as ex: except (TypeError, ValueError) as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex) _LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return return
shared_data = shared_data_bytes.decode("utf-8")
# Matching attributes found in the pending commit # Matching attributes found in the pending commit
if pending_event_data := self._pending_event_data.get(shared_data): if pending_event_data := self._pending_event_data.get(shared_data):
dbevent.event_data_rel = pending_event_data dbevent.event_data_rel = pending_event_data
@ -757,7 +756,7 @@ class Recorder(threading.Thread):
elif data_id := self._event_data_ids.get(shared_data): elif data_id := self._event_data_ids.get(shared_data):
dbevent.data_id = data_id dbevent.data_id = data_id
else: else:
data_hash = EventData.hash_shared_data_bytes(shared_data_bytes) data_hash = EventData.hash_shared_data(shared_data)
# Matching attributes found in the database # Matching attributes found in the database
if data_id := self._find_shared_data_in_db(data_hash, shared_data): if data_id := self._find_shared_data_in_db(data_hash, shared_data):
self._event_data_ids[shared_data] = dbevent.data_id = data_id self._event_data_ids[shared_data] = dbevent.data_id = data_id
@ -776,7 +775,7 @@ class Recorder(threading.Thread):
assert self.event_session is not None assert self.event_session is not None
try: try:
dbstate = States.from_event(event) dbstate = States.from_event(event)
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( shared_attrs = StateAttributes.shared_attrs_from_event(
event, self._exclude_attributes_by_domain event, self._exclude_attributes_by_domain
) )
except (TypeError, ValueError) as ex: except (TypeError, ValueError) as ex:
@ -787,7 +786,6 @@ class Recorder(threading.Thread):
) )
return return
shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None dbstate.attributes = None
# Matching attributes found in the pending commit # Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs): if pending_attributes := self._pending_state_attributes.get(shared_attrs):
@ -796,7 +794,7 @@ class Recorder(threading.Thread):
elif attributes_id := self._state_attributes_ids.get(shared_attrs): elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id dbstate.attributes_id = attributes_id
else: else:
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
# Matching attributes found in the database # Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs): if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id dbstate.attributes_id = attributes_id

View File

@ -3,12 +3,12 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json
import logging import logging
from typing import Any, TypedDict, cast, overload from typing import Any, TypedDict, cast, overload
import ciso8601 import ciso8601
from fnvhash import fnv1a_32 from fnvhash import fnv1a_32
import orjson
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
BigInteger, BigInteger,
@ -46,10 +46,9 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE, MAX_LENGTH_STATE_STATE,
) )
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import ALL_DOMAIN_EXCLUDE_ATTRS from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
# SQLAlchemy Schema # SQLAlchemy Schema
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -133,7 +132,7 @@ class JSONLiteral(JSON): # type: ignore[misc]
def process(value: Any) -> str: def process(value: Any) -> str:
"""Dump json.""" """Dump json."""
return JSON_DUMP(value) return json.dumps(value)
return process return process
@ -200,7 +199,7 @@ class Events(Base): # type: ignore[misc,valid-type]
try: try:
return Event( return Event(
self.event_type, self.event_type,
orjson.loads(self.event_data) if self.event_data else {}, json.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin) EventOrigin(self.origin)
if self.origin if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx], else EVENT_ORIGIN_ORDER[self.origin_idx],
@ -208,7 +207,7 @@ class Events(Base): # type: ignore[misc,valid-type]
context=context, context=context,
) )
except ValueError: except ValueError:
# When orjson.loads fails # When json.loads fails
_LOGGER.exception("Error converting to event: %s", self) _LOGGER.exception("Error converting to event: %s", self)
return None return None
@ -236,26 +235,25 @@ class EventData(Base): # type: ignore[misc,valid-type]
@staticmethod @staticmethod
def from_event(event: Event) -> EventData: def from_event(event: Event) -> EventData:
"""Create object from an event.""" """Create object from an event."""
shared_data = json_bytes(event.data) shared_data = JSON_DUMP(event.data)
return EventData( return EventData(
shared_data=shared_data.decode("utf-8"), shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
hash=EventData.hash_shared_data_bytes(shared_data),
) )
@staticmethod @staticmethod
def shared_data_bytes_from_event(event: Event) -> bytes: def shared_data_from_event(event: Event) -> str:
"""Create shared_data from an event.""" """Create shared_attrs from an event."""
return json_bytes(event.data) return JSON_DUMP(event.data)
@staticmethod @staticmethod
def hash_shared_data_bytes(shared_data_bytes: bytes) -> int: def hash_shared_data(shared_data: str) -> int:
"""Return the hash of json encoded shared data.""" """Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data_bytes)) return cast(int, fnv1a_32(shared_data.encode("utf-8")))
def to_native(self) -> dict[str, Any]: def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object.""" """Convert to an HA state object."""
try: try:
return cast(dict[str, Any], orjson.loads(self.shared_data)) return cast(dict[str, Any], json.loads(self.shared_data))
except ValueError: except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self) _LOGGER.exception("Error converting row to event data: %s", self)
return {} return {}
@ -342,9 +340,9 @@ class States(Base): # type: ignore[misc,valid-type]
parent_id=self.context_parent_id, parent_id=self.context_parent_id,
) )
try: try:
attrs = orjson.loads(self.attributes) if self.attributes else {} attrs = json.loads(self.attributes) if self.attributes else {}
except ValueError: except ValueError:
# When orjson.loads fails # When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self) _LOGGER.exception("Error converting row to state: %s", self)
return None return None
if self.last_changed is None or self.last_changed == self.last_updated: if self.last_changed is None or self.last_changed == self.last_updated:
@ -390,39 +388,40 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
"""Create object from a state_changed event.""" """Create object from a state_changed event."""
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine # None state means the state was removed from the state machine
attr_bytes = b"{}" if state is None else json_bytes(state.attributes) dbstate = StateAttributes(
dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8")) shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes) )
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
return dbstate return dbstate
@staticmethod @staticmethod
def shared_attrs_bytes_from_event( def shared_attrs_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]] event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> bytes: ) -> str:
"""Create shared_attrs from a state_changed event.""" """Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine # None state means the state was removed from the state machine
if state is None: if state is None:
return b"{}" return "{}"
domain = split_entity_id(state.entity_id)[0] domain = split_entity_id(state.entity_id)[0]
exclude_attrs = ( exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
) )
return json_bytes( return JSON_DUMP(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs} {k: v for k, v in state.attributes.items() if k not in exclude_attrs}
) )
@staticmethod @staticmethod
def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of orjson encoded shared attributes.""" """Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs_bytes)) return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]: def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object.""" """Convert to an HA state object."""
try: try:
return cast(dict[str, Any], orjson.loads(self.shared_attrs)) return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError: except ValueError:
# When orjson.loads fails # When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self) _LOGGER.exception("Error converting row to state attributes: %s", self)
return {} return {}
@ -836,7 +835,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT: if not source or source == EMPTY_JSON_OBJECT:
return {} return {}
try: try:
attr_cache[source] = attributes = orjson.loads(source) attr_cache[source] = attributes = json.loads(source)
except ValueError: except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source) _LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {} attr_cache[source] = attributes = {}

View File

@ -29,7 +29,7 @@ from homeassistant.helpers.event import (
TrackTemplateResult, TrackTemplateResult,
async_track_template_result, async_track_template_result,
) )
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
@ -241,13 +241,13 @@ def handle_get_states(
# to succeed for the UI to show. # to succeed for the UI to show.
response = messages.result_message(msg["id"], states) response = messages.result_message(msg["id"], states)
try: try:
connection.send_message(JSON_DUMP(response)) connection.send_message(const.JSON_DUMP(response))
return return
except (ValueError, TypeError): except (ValueError, TypeError):
connection.logger.error( connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s", "Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data( format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP) find_paths_unserializable_data(response, dump=const.JSON_DUMP)
), ),
) )
del response del response
@ -256,13 +256,13 @@ def handle_get_states(
serialized = [] serialized = []
for state in states: for state in states:
try: try:
serialized.append(JSON_DUMP(state)) serialized.append(const.JSON_DUMP(state))
except (ValueError, TypeError): except (ValueError, TypeError):
# Error is already logged above # Error is already logged above
pass pass
# We now have partially serialized states. Craft some JSON. # We now have partially serialized states. Craft some JSON.
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized)) response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2) connection.send_message(response2)
@ -315,13 +315,13 @@ def handle_subscribe_entities(
# to succeed for the UI to show. # to succeed for the UI to show.
response = messages.event_message(msg["id"], data) response = messages.event_message(msg["id"], data)
try: try:
connection.send_message(JSON_DUMP(response)) connection.send_message(const.JSON_DUMP(response))
return return
except (ValueError, TypeError): except (ValueError, TypeError):
connection.logger.error( connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s", "Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data( format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP) find_paths_unserializable_data(response, dump=const.JSON_DUMP)
), ),
) )
del response del response
@ -330,14 +330,14 @@ def handle_subscribe_entities(
cannot_serialize: list[str] = [] cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items(): for entity_id, state_dict in add_entities.items():
try: try:
JSON_DUMP(state_dict) const.JSON_DUMP(state_dict)
except (ValueError, TypeError): except (ValueError, TypeError):
cannot_serialize.append(entity_id) cannot_serialize.append(entity_id)
for entity_id in cannot_serialize: for entity_id in cannot_serialize:
del add_entities[entity_id] del add_entities[entity_id]
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data))) connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
@decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.websocket_command({vol.Required("type"): "get_services"})

View File

@ -11,7 +11,6 @@ import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.json import JSON_DUMP
from . import const, messages from . import const, messages
@ -57,7 +56,7 @@ class ActiveConnection:
async def send_big_result(self, msg_id: int, result: Any) -> None: async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize.""" """Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job( content = await self.hass.async_add_executor_job(
JSON_DUMP, messages.result_message(msg_id, result) const.JSON_DUMP, messages.result_message(msg_id, result)
) )
self.send_message(content) self.send_message(content)

View File

@ -4,9 +4,12 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from concurrent import futures from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
if TYPE_CHECKING: if TYPE_CHECKING:
from .connection import ActiveConnection # noqa: F401 from .connection import ActiveConnection # noqa: F401
@ -50,6 +53,10 @@ SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
# Data used to store the current connection list # Data used to store the current connection list
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections" DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
JSON_DUMP: Final = partial(
json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":")
)
COMPRESSED_STATE_STATE = "s" COMPRESSED_STATE_STATE = "s"
COMPRESSED_STATE_ATTRIBUTES = "a" COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c" COMPRESSED_STATE_CONTEXT = "c"

View File

@ -9,7 +9,6 @@ import voluptuous as vol
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.util.json import ( from homeassistant.util.json import (
find_paths_unserializable_data, find_paths_unserializable_data,
format_unserializable_data, format_unserializable_data,
@ -194,15 +193,15 @@ def compressed_state_dict_add(state: State) -> dict[str, Any]:
def message_to_json(message: dict[str, Any]) -> str: def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json.""" """Serialize a websocket message to json."""
try: try:
return JSON_DUMP(message) return const.JSON_DUMP(message)
except (ValueError, TypeError): except (ValueError, TypeError):
_LOGGER.error( _LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s", "Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data( format_unserializable_data(
find_paths_unserializable_data(message, dump=JSON_DUMP) find_paths_unserializable_data(message, dump=const.JSON_DUMP)
), ),
) )
return JSON_DUMP( return const.JSON_DUMP(
error_message( error_message(
message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response" message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response"
) )

View File

@ -14,7 +14,6 @@ from aiohttp import web
from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout import async_timeout
import orjson
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
@ -98,7 +97,6 @@ def _async_create_clientsession(
"""Create a new ClientSession with kwargs, i.e. for cookies.""" """Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession( clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl), connector=_async_get_connector(hass, verify_ssl),
json_serialize=lambda x: orjson.dumps(x).decode("utf-8"),
**kwargs, **kwargs,
) )
# Prevent packages accidentally overriding our default headers # Prevent packages accidentally overriding our default headers

View File

@ -1,10 +1,7 @@
"""Helpers to help with encoding Home Assistant objects in JSON.""" """Helpers to help with encoding Home Assistant objects in JSON."""
import datetime import datetime
import json import json
from pathlib import Path from typing import Any
from typing import Any, Final
import orjson
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
@ -25,20 +22,6 @@ class JSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, o)
def json_encoder_default(obj: Any) -> Any:
"""Convert Home Assistant objects.
Hand other objects to the original method.
"""
if isinstance(obj, set):
return list(obj)
if hasattr(obj, "as_dict"):
return obj.as_dict()
if isinstance(obj, Path):
return obj.as_posix()
raise TypeError
class ExtendedJSONEncoder(JSONEncoder): class ExtendedJSONEncoder(JSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" """JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
@ -57,31 +40,3 @@ class ExtendedJSONEncoder(JSONEncoder):
return super().default(o) return super().default(o)
except TypeError: except TypeError:
return {"__type": str(type(o)), "repr": repr(o)} return {"__type": str(type(o)), "repr": repr(o)}
def json_bytes(data: Any) -> bytes:
"""Dump json bytes."""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
)
def json_dumps(data: Any) -> str:
"""Dump json string.
orjson supports serializing dataclasses natively which
eliminates the need to implement as_dict in many places
when the data is already in a dataclass. This works
well as long as all the data in the dataclass can also
be serialized.
If it turns out to be a problem we can disable this
with option |= orjson.OPT_PASSTHROUGH_DATACLASS and it
will fallback to as_dict
"""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
).decode("utf-8")
JSON_DUMP: Final = json_dumps

View File

@ -20,7 +20,6 @@ httpx==0.23.0
ifaddr==0.1.7 ifaddr==0.1.7
jinja2==3.1.2 jinja2==3.1.2
lru-dict==1.1.7 lru-dict==1.1.7
orjson==3.6.8
paho-mqtt==1.6.1 paho-mqtt==1.6.1
pillow==9.1.1 pillow==9.1.1
pip>=21.0,<22.2 pip>=21.0,<22.2

View File

@ -12,13 +12,14 @@ from timeit import default_timer as timer
from typing import TypeVar from typing import TypeVar
from homeassistant import core from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.helpers.entityfilter import convert_include_exclude_filter from homeassistant.helpers.entityfilter import convert_include_exclude_filter
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
async_track_state_change, async_track_state_change,
async_track_state_change_event, async_track_state_change_event,
) )
from homeassistant.helpers.json import JSON_DUMP, JSONEncoder from homeassistant.helpers.json import JSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# mypy: no-warn-return-any # mypy: no-warn-return-any

View File

@ -7,8 +7,6 @@ import json
import logging import logging
from typing import Any from typing import Any
import orjson
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -32,7 +30,7 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
""" """
try: try:
with open(filename, encoding="utf-8") as fdesc: with open(filename, encoding="utf-8") as fdesc:
return orjson.loads(fdesc.read()) # type: ignore[no-any-return] return json.loads(fdesc.read()) # type: ignore[no-any-return]
except FileNotFoundError: except FileNotFoundError:
# This is not a fatal error # This is not a fatal error
_LOGGER.debug("JSON file not found: %s", filename) _LOGGER.debug("JSON file not found: %s", filename)
@ -58,10 +56,7 @@ def save_json(
Returns True on success. Returns True on success.
""" """
try: try:
if encoder: json_data = json.dumps(data, indent=4, cls=encoder)
json_data = json.dumps(data, indent=2, cls=encoder)
else:
json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8")
except TypeError as error: except TypeError as error:
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}" msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}"
_LOGGER.error(msg) _LOGGER.error(msg)

View File

@ -41,7 +41,6 @@ dependencies = [
"PyJWT==2.4.0", "PyJWT==2.4.0",
# PyJWT has loose dependency. We want the latest one. # PyJWT has loose dependency. We want the latest one.
"cryptography==36.0.2", "cryptography==36.0.2",
"orjson==3.6.8",
"pip>=21.0,<22.2", "pip>=21.0,<22.2",
"python-slugify==4.0.1", "python-slugify==4.0.1",
"pyyaml==6.0", "pyyaml==6.0",
@ -120,7 +119,6 @@ extension-pkg-allow-list = [
"av.audio.stream", "av.audio.stream",
"av.stream", "av.stream",
"ciso8601", "ciso8601",
"orjson",
"cv2", "cv2",
] ]

View File

@ -15,7 +15,6 @@ ifaddr==0.1.7
jinja2==3.1.2 jinja2==3.1.2
PyJWT==2.4.0 PyJWT==2.4.0
cryptography==36.0.2 cryptography==36.0.2
orjson==3.6.8
pip>=21.0,<22.2 pip>=21.0,<22.2
python-slugify==4.0.1 python-slugify==4.0.1
pyyaml==6.0 pyyaml==6.0

View File

@ -4,7 +4,6 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.components.energy import async_get_manager, validate from homeassistant.components.energy import async_get_manager, validate
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -409,11 +408,7 @@ async def test_validation_grid(
}, },
) )
result = await validate.async_validate(hass) assert (await validate.async_validate(hass)).as_dict() == {
# verify its also json serializable
JSON_DUMP(result)
assert result.as_dict() == {
"energy_sources": [ "energy_sources": [
[ [
{ {

View File

@ -619,15 +619,12 @@ async def test_states_filters_visible(hass, hass_admin_user, websocket_client):
async def test_get_states_not_allows_nan(hass, websocket_client): async def test_get_states_not_allows_nan(hass, websocket_client):
"""Test get_states command converts NaN to None.""" """Test get_states command not allows NaN floats."""
hass.states.async_set("greeting.hello", "world") hass.states.async_set("greeting.hello", "world")
hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")}) hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")})
hass.states.async_set("greeting.bye", "universe") hass.states.async_set("greeting.bye", "universe")
await websocket_client.send_json({"id": 5, "type": "get_states"}) await websocket_client.send_json({"id": 5, "type": "get_states"})
bad = dict(hass.states.get("greeting.bad").as_dict())
bad["attributes"] = dict(bad["attributes"])
bad["attributes"]["hello"] = None
msg = await websocket_client.receive_json() msg = await websocket_client.receive_json()
assert msg["id"] == 5 assert msg["id"] == 5
@ -635,7 +632,6 @@ async def test_get_states_not_allows_nan(hass, websocket_client):
assert msg["success"] assert msg["success"]
assert msg["result"] == [ assert msg["result"] == [
hass.states.get("greeting.hello").as_dict(), hass.states.get("greeting.hello").as_dict(),
bad,
hass.states.get("greeting.bye").as_dict(), hass.states.get("greeting.bye").as_dict(),
] ]