Additional strict typing for additional recorder internals (#68689)

* Strict typing for additional recorder internals

* revert

* fix refactoring error
pull/68832/head
J. Nick Koston 2022-03-28 21:45:25 -10:00 committed by GitHub
parent 05ddd773ff
commit d7634d1cb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 42 deletions

View File

@ -164,11 +164,14 @@ homeassistant.components.pure_energie.*
homeassistant.components.rainmachine.*
homeassistant.components.rdw.*
homeassistant.components.recollect_waste.*
homeassistant.components.recorder.models
homeassistant.components.recorder.history
homeassistant.components.recorder.pool
homeassistant.components.recorder.purge
homeassistant.components.recorder.repack
homeassistant.components.recorder.statistics
homeassistant.components.recorder.util
homeassistant.components.recorder.websocket_api
homeassistant.components.remote.*
homeassistant.components.renault.*
homeassistant.components.ridwell.*

View File

@ -232,8 +232,7 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
There is also the run that covers point_in_time.
"""
run_info = run_information_from_instance(hass, point_in_time)
if run_info:
if run_info := run_information_from_instance(hass, point_in_time):
return run_info
with session_scope(hass=hass) as session:
@ -1028,8 +1027,7 @@ class Recorder(threading.Thread):
try:
if event.event_type == EVENT_STATE_CHANGED:
dbevent = Events.from_event(event, event_data="{}")
dbevent.event_data = None
dbevent = Events.from_event(event, event_data=None)
else:
dbevent = Events.from_event(event)
except (TypeError, ValueError):

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import datetime, timedelta
import json
import logging
from typing import Any, TypedDict, overload
from typing import Any, TypedDict, cast, overload
from fnvhash import fnv1a_32
from sqlalchemy import (
@ -35,6 +35,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
from .const import JSON_DUMP
@ -113,11 +114,13 @@ class Events(Base): # type: ignore[misc,valid-type]
)
@staticmethod
def from_event(event, event_data=None):
def from_event(
event: Event, event_data: UndefinedType | None = UNDEFINED
) -> Events:
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=event_data or JSON_DUMP(event.data),
event_data=JSON_DUMP(event.data) if event_data is UNDEFINED else event_data,
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
@ -125,7 +128,7 @@ class Events(Base): # type: ignore[misc,valid-type]
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id=True):
def to_native(self, validate_entity_id: bool = True) -> Event | None:
"""Convert to a native HA Event."""
context = Context(
id=self.context_id,
@ -185,7 +188,7 @@ class States(Base): # type: ignore[misc,valid-type]
)
@staticmethod
def from_event(event) -> States:
def from_event(event: Event) -> States:
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state")
@ -266,12 +269,12 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return fnv1a_32(shared_attrs.encode("utf-8"))
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return json.loads(self.shared_attrs)
return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
@ -311,8 +314,8 @@ class StatisticsBase:
id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr
def metadata_id(self):
@declared_attr # type: ignore[misc]
def metadata_id(self) -> Column:
"""Define the metadata_id column for sub classes."""
return Column(
Integer,
@ -329,7 +332,7 @@ class StatisticsBase:
sum = Column(DOUBLE_TYPE)
@classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData):
def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
"""Create object from a statistics."""
return cls( # type: ignore[call-arg,misc]
metadata_id=metadata_id,
@ -422,7 +425,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
f")>"
)
def entity_ids(self, point_in_time=None):
def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
@ -443,7 +446,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
return [row[0] for row in query]
def to_native(self, validate_entity_id=True):
def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
"""Return self, native format is this model."""
return self
@ -540,16 +543,16 @@ class LazyState(State):
) -> None:
"""Init the lazy state."""
self._row = row
self.entity_id = self._row.entity_id
self.entity_id: str = self._row.entity_id
self.state = self._row.state or ""
self._attributes = None
self._last_changed = None
self._last_updated = None
self._context = None
self._attributes: dict[str, Any] | None = None
self._last_changed: datetime | None = None
self._last_updated: datetime | None = None
self._context: Context | None = None
self._attr_cache = attr_cache
@property # type: ignore[override]
def attributes(self):
def attributes(self) -> dict[str, Any]: # type: ignore[override]
"""State attributes."""
if self._attributes is None:
source = self._row.shared_attrs or self._row.attributes
@ -574,47 +577,47 @@ class LazyState(State):
return self._attributes
@attributes.setter
def attributes(self, value):
def attributes(self, value: dict[str, Any]) -> None:
"""Set attributes."""
self._attributes = value
@property # type: ignore[override]
def context(self):
def context(self) -> Context: # type: ignore[override]
"""State context."""
if not self._context:
self._context = Context(id=None)
if self._context is None:
self._context = Context(id=None) # type: ignore[arg-type]
return self._context
@context.setter
def context(self, value):
def context(self, value: Context) -> None:
"""Set context."""
self._context = value
@property # type: ignore[override]
def last_changed(self):
def last_changed(self) -> datetime: # type: ignore[override]
"""Last changed datetime."""
if not self._last_changed:
if self._last_changed is None:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value):
def last_changed(self, value: datetime) -> None:
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore[override]
def last_updated(self):
def last_updated(self) -> datetime: # type: ignore[override]
"""Last updated datetime."""
if not self._last_updated:
if self._last_updated is None:
self._last_updated = process_timestamp(self._row.last_updated)
return self._last_updated
@last_updated.setter
def last_updated(self, value):
def last_updated(self, value: datetime) -> None:
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self):
def as_dict(self) -> dict[str, Any]: # type: ignore[override]
"""Return a dict representation of the LazyState.
Async friendly.
@ -645,7 +648,7 @@ class LazyState(State):
"last_updated": last_updated_isoformat,
}
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]

View File

@ -1,5 +1,6 @@
"""A pool for sqlite connections."""
import threading
from typing import Any
from sqlalchemy.pool import NullPool, SingletonThreadPool
@ -10,14 +11,16 @@ from .const import DB_WORKER_PREFIX
POOL_SIZE = 5
class RecorderPool(SingletonThreadPool, NullPool):
class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
"""A hybrid of NullPool and SingletonThreadPool.
When called from the creating thread or db executor acts like SingletonThreadPool
When called from any other thread, acts like NullPool
"""
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
def __init__( # pylint: disable=super-init-not-called
self, *args: Any, **kw: Any
) -> None:
"""Create the pool."""
kw["pool_size"] = POOL_SIZE
SingletonThreadPool.__init__(self, *args, **kw)
@ -30,22 +33,24 @@ class RecorderPool(SingletonThreadPool, NullPool):
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
)
def _do_return_conn(self, conn):
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
def _do_return_conn(self, conn: Any) -> Any:
if self.recorder_or_dbworker:
return super()._do_return_conn(conn)
conn.close()
def shutdown(self):
def shutdown(self) -> None:
"""Close the connection."""
if self.recorder_or_dbworker and self._conn and (conn := self._conn.current()):
conn.close()
def dispose(self):
def dispose(self) -> None:
"""Dispose of the connection."""
if self.recorder_or_dbworker:
return super().dispose()
super().dispose()
def _do_get(self):
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
def _do_get(self) -> Any:
if self.recorder_or_dbworker:
return super()._do_get()
report(

View File

@ -1606,6 +1606,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.history]
check_untyped_defs = true
disallow_incomplete_defs = true
@ -1617,6 +1628,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.pool]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.purge]
check_untyped_defs = true
disallow_incomplete_defs = true
@ -1661,6 +1683,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.websocket_api]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.remote.*]
check_untyped_defs = true
disallow_incomplete_defs = true