Cleanup recorder history typing (#69580)

Branch: pull/69580/head
J. Nick Koston, 2022-04-07 00:09:05 -10:00 (committed by GitHub)
parent 97aa65d9a4
commit 5c7c09726a
4 changed files with 160 additions and 45 deletions:
homeassistant/components/recorder/history.py
homeassistant/components/sensor/recorder.py
homeassistant/components/statistics/sensor.py
tests/components/recorder/test_history.py

homeassistant/components/recorder/history.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Iterable, MutableMapping
+from collections.abc import Iterable, Iterator, MutableMapping
 from datetime import datetime
 from itertools import groupby
 import logging
@@ -141,7 +141,7 @@ def get_significant_states(
     significant_changes_only: bool = True,
     minimal_response: bool = False,
     no_attributes: bool = False,
-) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]:
+) -> MutableMapping[str, list[State | dict[str, Any]]]:
     """Wrap get_significant_states_with_session with an sql session."""
     with session_scope(hass=hass) as session:
         return get_significant_states_with_session(
@@ -158,31 +158,20 @@ def get_significant_states(
         )


-def get_significant_states_with_session(
+def _query_significant_states_with_session(
     hass: HomeAssistant,
     session: Session,
     start_time: datetime,
     end_time: datetime | None = None,
     entity_ids: list[str] | None = None,
     filters: Any = None,
-    include_start_time_state: bool = True,
     significant_changes_only: bool = True,
-    minimal_response: bool = False,
     no_attributes: bool = False,
-) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]:
-    """
-    Return states changes during UTC period start_time - end_time.
-
-    entity_ids is an optional iterable of entities to include in the results.
-
-    filters is an optional SQLAlchemy filter which will be applied to the database
-    queries unless entity_ids is given, in which case its ignored.
-
-    Significant states are all states where there is a state change,
-    as well as all states from certain domains (for instance
-    thermostat so that we get current temperature in our graphs).
-    """
-    timer_start = time.perf_counter()
+) -> list[States]:
+    """Query the database for significant state changes."""
+    if _LOGGER.isEnabledFor(logging.DEBUG):
+        timer_start = time.perf_counter()

     baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes)
     if entity_ids is not None and len(entity_ids) == 1:
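This hunk splits the old monolith in two: _query_significant_states_with_session now only talks to the database, and reading the perf counter is gated behind the same isEnabledFor(logging.DEBUG) check that already guarded the log call, so the hot path skips it. A standalone sketch of that gating pattern (plain Python; the function name and row placeholder are made up):

import logging
import time

_LOGGER = logging.getLogger(__name__)


def run_query() -> list[str]:
    # Only pay for perf_counter() when the timing will actually be logged.
    if _LOGGER.isEnabledFor(logging.DEBUG):
        timer_start = time.perf_counter()
    rows = ["row1", "row2"]  # stand-in for the real database work
    if _LOGGER.isEnabledFor(logging.DEBUG):
        elapsed = time.perf_counter() - timer_start
        _LOGGER.debug("query took %fs", elapsed)
    return rows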
@@ -240,6 +229,43 @@ def get_significant_states_with_session(
         elapsed = time.perf_counter() - timer_start
         _LOGGER.debug("get_significant_states took %fs", elapsed)

+    return states
+
+
+def get_significant_states_with_session(
+    hass: HomeAssistant,
+    session: Session,
+    start_time: datetime,
+    end_time: datetime | None = None,
+    entity_ids: list[str] | None = None,
+    filters: Any = None,
+    include_start_time_state: bool = True,
+    significant_changes_only: bool = True,
+    minimal_response: bool = False,
+    no_attributes: bool = False,
+) -> MutableMapping[str, list[State | dict[str, Any]]]:
+    """
+    Return state changes during UTC period start_time - end_time.
+
+    entity_ids is an optional iterable of entities to include in the results.
+
+    filters is an optional SQLAlchemy filter which will be applied to the database
+    queries unless entity_ids is given, in which case it's ignored.
+
+    Significant states are all states where there is a state change,
+    as well as all states from certain domains (for instance
+    thermostat so that we get current temperature in our graphs).
+    """
+    states = _query_significant_states_with_session(
+        hass,
+        session,
+        start_time,
+        end_time,
+        entity_ids,
+        filters,
+        significant_changes_only,
+        no_attributes,
+    )
     return _sorted_states_to_dict(
         hass,
         session,
@@ -253,6 +279,35 @@ def get_significant_states_with_session(
     )


+def get_full_significant_states_with_session(
+    hass: HomeAssistant,
+    session: Session,
+    start_time: datetime,
+    end_time: datetime | None = None,
+    entity_ids: list[str] | None = None,
+    filters: Any = None,
+    include_start_time_state: bool = True,
+    significant_changes_only: bool = True,
+    no_attributes: bool = False,
+) -> MutableMapping[str, list[State]]:
+    """Variant of get_significant_states_with_session that does not return minimal responses."""
+    return cast(
+        MutableMapping[str, list[State]],
+        get_significant_states_with_session(
+            hass=hass,
+            session=session,
+            start_time=start_time,
+            end_time=end_time,
+            entity_ids=entity_ids,
+            filters=filters,
+            include_start_time_state=include_start_time_state,
+            significant_changes_only=significant_changes_only,
+            minimal_response=False,
+            no_attributes=no_attributes,
+        ),
+    )
+
+
 def state_changes_during_period(
     hass: HomeAssistant,
     start_time: datetime,
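get_full_significant_states_with_session is a thin typed facade: it pins minimal_response=False and uses cast to narrow the union-valued return type, so callers get MutableMapping[str, list[State]] without casting at every call site. A hedged usage sketch (the entity id and time window are invented; the imports mirror the test file below, and hass is assumed to be a running instance):

from datetime import timedelta

from homeassistant.components import recorder
from homeassistant.components.recorder import history
import homeassistant.util.dt as dt_util

start_time = dt_util.utcnow() - timedelta(hours=1)
with recorder.session_scope(hass=hass) as session:
    full_states = history.get_full_significant_states_with_session(
        hass, session, start_time, entity_ids=["sensor.outdoor_temperature"]
    )
# Every value is list[State]; no minimal-response dicts, so callers can
# drop the cast(list[State], ...) workaround removed later in this commit.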
@@ -262,7 +317,7 @@ def state_changes_during_period(
     descending: bool = False,
     limit: int | None = None,
     include_start_time_state: bool = True,
-) -> MutableMapping[str, Iterable[LazyState]]:
+) -> MutableMapping[str, list[State]]:
     """Return states changes during UTC period start_time - end_time."""
     with session_scope(hass=hass) as session:
         baked_query, join_attributes = bake_query_and_join_attributes(
@@ -303,7 +358,7 @@ def state_changes_during_period(
         entity_ids = [entity_id] if entity_id is not None else None

         return cast(
-            MutableMapping[str, Iterable[LazyState]],
+            MutableMapping[str, list[State]],
             _sorted_states_to_dict(
                 hass,
                 session,
@@ -317,7 +372,7 @@ def state_changes_during_period(

 def get_last_state_changes(
     hass: HomeAssistant, number_of_states: int, entity_id: str
-) -> MutableMapping[str, Iterable[LazyState]]:
+) -> MutableMapping[str, list[State]]:
     """Return the last number_of_states."""
     start_time = dt_util.utcnow()
@@ -349,7 +404,7 @@ def get_last_state_changes(
         entity_ids = [entity_id] if entity_id is not None else None

         return cast(
-            MutableMapping[str, Iterable[LazyState]],
+            MutableMapping[str, list[State]],
             _sorted_states_to_dict(
                 hass,
                 session,
@@ -368,7 +423,7 @@ def get_states(
     run: RecorderRuns | None = None,
     filters: Any = None,
     no_attributes: bool = False,
-) -> list[LazyState]:
+) -> list[State]:
     """Return the states at a specific point in time."""
     if (
         run is None
@@ -392,7 +447,7 @@ def _get_states_with_session(
     run: RecorderRuns | None = None,
     filters: Any | None = None,
     no_attributes: bool = False,
-) -> list[LazyState]:
+) -> list[State]:
     """Return the states at a specific point in time."""
     if entity_ids and len(entity_ids) == 1:
         return _get_single_entity_states_with_session(
@@ -488,7 +543,7 @@ def _get_single_entity_states_with_session(
     utc_point_in_time: datetime,
     entity_id: str,
     no_attributes: bool = False,
-) -> list[LazyState]:
+) -> list[State]:
     # Use an entirely different (and extremely fast) query if we only
     # have a single entity id
     baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes)
@@ -520,7 +575,7 @@ def _sorted_states_to_dict(
     include_start_time_state: bool = True,
     minimal_response: bool = False,
     no_attributes: bool = False,
-) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]:
+) -> MutableMapping[str, list[State | dict[str, Any]]]:
     """Convert SQL results into JSON friendly data structure.

     This takes our state list and turns it into a JSON friendly data
@@ -532,7 +587,7 @@ def _sorted_states_to_dict(
     each list of states, otherwise our graphs won't start on the Y
     axis correctly.
     """
-    result: dict[str, list[LazyState | dict[str, Any]]] = defaultdict(list)
+    result: dict[str, list[State | dict[str, Any]]] = defaultdict(list)
     # Set all entity IDs to empty lists in result set to maintain the order
     if entity_ids is not None:
         for ent_id in entity_ids:
@@ -563,21 +618,30 @@ def _sorted_states_to_dict(
     # here
     _process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat

+    if entity_ids and len(entity_ids) == 1:
+        states_iter: Iterable[tuple[str | Column, Iterator[States]]] = (
+            (entity_ids[0], iter(states)),
+        )
+    else:
+        states_iter = groupby(states, lambda state: state.entity_id)
+
     # Append all changes to it
-    for ent_id, group in groupby(states, lambda state: state.entity_id):  # type: ignore[no-any-return]
-        domain = split_entity_id(ent_id)[0]
+    for ent_id, group in states_iter:
         ent_results = result[ent_id]
         attr_cache: dict[str, dict[str, Any]] = {}
-        if not minimal_response or domain in NEED_ATTRIBUTE_DOMAINS:
+        if not minimal_response or split_entity_id(ent_id)[0] in NEED_ATTRIBUTE_DOMAINS:
             ent_results.extend(LazyState(db_state, attr_cache) for db_state in group)
             continue

         # With minimal response we only provide a native
         # State for the first and last response. All the states
         # in-between only provide the "state" and the
         # "last_changed".
         if not ent_results:
-            ent_results.append(LazyState(next(group), attr_cache))
+            if (first_state := next(group, None)) is None:
+                continue
+            ent_results.append(LazyState(first_state, attr_cache))

         prev_state = ent_results[-1]
         assert isinstance(prev_state, LazyState)
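The single-entity fast path above leans on groupby semantics: when only one entity id was queried, every returned row already belongs to the same group, so a one-element (key, iterator) tuple is equivalent and skips the per-row key function entirely. A self-contained illustration with toy rows instead of recorder models:

from itertools import groupby

rows = [("sensor.a", 20.5), ("sensor.a", 21.0)]
entity_ids = ["sensor.a"]

if entity_ids and len(entity_ids) == 1:
    # Pre-grouped: one (key, iterator) pair, no key calls per row.
    states_iter = ((entity_ids[0], iter(rows)),)
else:
    states_iter = groupby(rows, lambda row: row[0])

for ent_id, group in states_iter:
    print(ent_id, list(group))  # sensor.a [('sensor.a', 20.5), ('sensor.a', 21.0)]

The next(group, None) walrus guard matters precisely for the pre-grouped branch: an entity with zero rows yields an empty iterator there, where the old bare next(group) would have raised StopIteration.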
@@ -615,7 +679,7 @@ def get_state(
     entity_id: str,
     run: RecorderRuns | None = None,
     no_attributes: bool = False,
-) -> LazyState | None:
+) -> State | None:
     """Return a state at a specific point in time."""
     states = get_states(hass, utc_point_in_time, [entity_id], run, None, no_attributes)
     return states[0] if states else None
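The LazyState-to-State annotation swap repeated through this file is sound because LazyState subclasses State in the recorder's models module, so the recorder can keep returning its lazily-parsed rows while advertising only the public type. A toy model of the relationship (stand-in classes, not the real ones):

class State:
    """Stand-in for homeassistant.core.State."""


class LazyState(State):
    """Stand-in for the recorder's lazy row wrapper."""


def get_states() -> list[State]:
    # Returning the subclass is fine; callers only rely on the State API.
    return [LazyState()]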

homeassistant/components/sensor/recorder.py

@@ -7,7 +7,7 @@ import datetime
 import itertools
 import logging
 import math
-from typing import Any, cast
+from typing import Any

 from sqlalchemy.orm.session import Session
@@ -19,7 +19,6 @@ from homeassistant.components.recorder import (
 )
 from homeassistant.components.recorder.const import DOMAIN as RECORDER_DOMAIN
 from homeassistant.components.recorder.models import (
-    LazyState,
     StatisticData,
     StatisticMetaData,
     StatisticResult,
@@ -417,9 +416,9 @@ def _compile_statistics(  # noqa: C901
     entities_full_history = [
         i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id]
     ]
-    history_list: MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]] = {}
+    history_list: MutableMapping[str, list[State]] = {}
     if entities_full_history:
-        history_list = history.get_significant_states_with_session(
+        history_list = history.get_full_significant_states_with_session(
             hass,
             session,
             start - datetime.timedelta.resolution,
@@ -433,7 +432,7 @@ def _compile_statistics(  # noqa: C901
         if "sum" not in wanted_statistics[i.entity_id]
     ]
     if entities_significant_history:
-        _history_list = history.get_significant_states_with_session(
+        _history_list = history.get_full_significant_states_with_session(
             hass,
             session,
             start - datetime.timedelta.resolution,
@@ -445,7 +444,7 @@ def _compile_statistics(  # noqa: C901
     # from the recorder. Get the state from the state machine instead.
     for _state in sensor_states:
         if _state.entity_id not in history_list:
-            history_list[_state.entity_id] = (_state,)
+            history_list[_state.entity_id] = [_state]

     for _state in sensor_states:  # pylint: disable=too-many-nested-blocks
         entity_id = _state.entity_id
@@ -459,9 +458,7 @@ def _compile_statistics(  # noqa: C901
             hass,
             session,
             old_metadatas,
-            # entity_history does not contain minimal responses
-            # so we must cast here
-            cast(list[State], entity_history),
+            entity_history,
             device_class,
             entity_id,
         )
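Two things change in _compile_statistics: it now calls the non-minimal variant, which removes the need for cast, and the state-machine fallback becomes [_state] because a one-element tuple no longer satisfies the MutableMapping[str, list[State]] annotation. A minimal sketch of that second point, using str in place of State:

from collections.abc import MutableMapping

history_list: MutableMapping[str, list[str]] = {}
history_list["sensor.a"] = ["on"]  # ok: a list matches list[str]
history_list["sensor.b"] = ("on",)  # runs, but mypy rejects: tuple is not list[str]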

homeassistant/components/statistics/sensor.py

@@ -485,16 +485,14 @@ class StatisticsSensor(SensorEntity):
         else:
             start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
             _LOGGER.debug("%s: retrieving all records", self.entity_id)
-        entity_states = history.state_changes_during_period(
+        return history.state_changes_during_period(
             self.hass,
             start_date,
             entity_id=lower_entity_id,
             descending=True,
             limit=self._samples_max_buffer_size,
             include_start_time_state=False,
-        )
-        # Need to cast since minimal responses is not passed in
-        return cast(list[State], entity_states.get(lower_entity_id, []))
+        ).get(lower_entity_id, [])

     async def _initialize_from_database(self) -> None:
         """Initialize the list of states from the database.
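Because state_changes_during_period is now typed as MutableMapping[str, list[State]], the statistics sensor can return the mapping lookup directly instead of casting. A sketch of the collapsed call shape as used inside the sensor method (the entity id and limit are hypothetical):

newest_first = history.state_changes_during_period(
    self.hass,
    start_date,
    entity_id="sensor.cpu_temp",  # hypothetical entity
    descending=True,  # newest rows first, matching the sample buffer
    limit=20,  # e.g. self._samples_max_buffer_size
    include_start_time_state=False,
).get("sensor.cpu_temp", [])  # [] when nothing was recorded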
async def _initialize_from_database(self) -> None:
"""Initialize the list of states from the database.

tests/components/recorder/test_history.py

@@ -124,6 +124,62 @@ def test_get_states(hass_recorder):
     assert history.get_state(hass, time_before_recorder_ran, "demo.id") is None


+def test_get_full_significant_states_with_session_entity_no_matches(hass_recorder):
+    """Test getting states at a specific point in time for entities that have never been recorded."""
+    hass = hass_recorder()
+    now = dt_util.utcnow()
+    time_before_recorder_ran = now - timedelta(days=1000)
+    with recorder.session_scope(hass=hass) as session:
+        assert (
+            history.get_full_significant_states_with_session(
+                hass, session, time_before_recorder_ran, now, entity_ids=["demo.id"]
+            )
+            == {}
+        )
+        assert (
+            history.get_full_significant_states_with_session(
+                hass,
+                session,
+                time_before_recorder_ran,
+                now,
+                entity_ids=["demo.id", "demo.id2"],
+            )
+            == {}
+        )
+
+
+def test_significant_states_with_session_entity_minimal_response_no_matches(
+    hass_recorder,
+):
+    """Test getting states at a specific point in time for entities that have never been recorded."""
+    hass = hass_recorder()
+    now = dt_util.utcnow()
+    time_before_recorder_ran = now - timedelta(days=1000)
+    with recorder.session_scope(hass=hass) as session:
+        assert (
+            history.get_significant_states_with_session(
+                hass,
+                session,
+                time_before_recorder_ran,
+                now,
+                entity_ids=["demo.id"],
+                minimal_response=True,
+            )
+            == {}
+        )
+        assert (
+            history.get_significant_states_with_session(
+                hass,
+                session,
+                time_before_recorder_ran,
+                now,
+                entity_ids=["demo.id", "demo.id2"],
+                minimal_response=True,
+            )
+            == {}
+        )
+
+
 def test_get_states_no_attributes(hass_recorder):
     """Test getting states without attributes at a specific point in time."""
     hass = hass_recorder()
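Both new tests pin down the zero-row path that the next(group, None) guard in history.py protects: entities that were never recorded must produce an empty mapping rather than a StopIteration crash. Outside the recorder, the same shape reduces to groupby over an empty list (standalone sketch):

from itertools import groupby

rows: list[tuple[str, str]] = []  # the requested entities were never recorded

result: dict[str, list[tuple[str, str]]] = {}
for ent_id, group in groupby(rows, lambda row: row[0]):
    result[ent_id] = list(group)

assert result == {}  # mirrors the == {} assertions above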