Reduce branching in generated lambda_stmts (#73042)
parent 3744edc512
commit c66b000d34
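The hunks that follow rework the recorder's lambda statement builders so that Python-level branching (choosing a table class, optional filters on a shared `table` variable) happens outside the cached lambdas. SQLAlchemy's `lambda_stmt` caches the SQL it renders per lambda source location and tracks only the closure values as bound parameters, so keeping each lambda's closure small and branch-free keeps that cache stable. Below is a minimal sketch of the general pattern, using an illustrative `users` table rather than the recorder schema:

from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, MetaData, String, Table
from sqlalchemy import lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata = MetaData()
# Illustrative table, not part of the recorder schema.
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("created", DateTime),
)


def build_stmt(start_time: datetime, name: str | None) -> StatementLambdaElement:
    """Build a cacheable statement; only bound values live in the lambda closures."""
    stmt = lambda_stmt(lambda: select(users).where(users.c.created >= start_time))
    if name is not None:
        # The branch runs in plain Python, outside the cached lambdas, so each
        # lambda stays a single variant keyed on its source location.
        stmt += lambda s: s.where(users.c.name == name)
    stmt += lambda s: s.order_by(users.c.id)
    return stmt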
@@ -15,6 +15,7 @@ from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Subquery
 
 from homeassistant.components import recorder
 from homeassistant.components.websocket_api.const import (
@@ -485,6 +486,25 @@ def _get_states_for_entites_stmt(
     return stmt
 
 
+def _generate_most_recent_states_by_date(
+    run_start: datetime,
+    utc_point_in_time: datetime,
+) -> Subquery:
+    """Generate the sub query for the most recent states by data."""
+    return (
+        select(
+            States.entity_id.label("max_entity_id"),
+            func.max(States.last_updated).label("max_last_updated"),
+        )
+        .filter(
+            (States.last_updated >= run_start)
+            & (States.last_updated < utc_point_in_time)
+        )
+        .group_by(States.entity_id)
+        .subquery()
+    )
+
+
 def _get_states_for_all_stmt(
     schema_version: int,
     run_start: datetime,
@@ -500,17 +520,8 @@ def _get_states_for_all_stmt(
     # query, then filter out unwanted domains as well as applying the custom filter.
     # This filtering can't be done in the inner query because the domain column is
     # not indexed and we can't control what's in the custom filter.
-    most_recent_states_by_date = (
-        select(
-            States.entity_id.label("max_entity_id"),
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .group_by(States.entity_id)
-        .subquery()
+    most_recent_states_by_date = _generate_most_recent_states_by_date(
+        run_start, utc_point_in_time
     )
     stmt += lambda q: q.where(
         States.state_id
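In the history-query hunks above, the inner "most recent state per entity" subquery is now built eagerly by a small helper and only referenced from the cached lambda; the hunks below apply the same treatment to the statistics queries. A rough sketch of that shape, with a made-up `events` table instead of the real `States` model:

from sqlalchemy import Column, DateTime, Integer, MetaData, Table
from sqlalchemy import func, lambda_stmt, select

metadata = MetaData()
# Illustrative table standing in for the recorder's States model.
events = Table(
    "events",
    metadata,
    Column("state_id", Integer, primary_key=True),
    Column("entity_id", Integer),
    Column("last_updated", DateTime),
)

# Built once, with plain constructs, outside any lambda.
most_recent = (
    select(
        events.c.entity_id.label("max_entity_id"),
        func.max(events.c.last_updated).label("max_last_updated"),
    )
    .group_by(events.c.entity_id)
    .subquery()
)

stmt = lambda_stmt(lambda: select(events.c.state_id))
# The pre-built subquery is only referenced by the lambda, never rebuilt inside it.
stmt += lambda q: q.join(
    most_recent, events.c.entity_id == most_recent.c.max_entity_id
)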
@@ -20,6 +20,7 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Subquery
 import voluptuous as vol
 
 from homeassistant.const import (
@@ -484,14 +485,13 @@ def _compile_hourly_statistics_summary_mean_stmt(
     start_time: datetime, end_time: datetime
 ) -> StatementLambdaElement:
     """Generate the summary mean statement for hourly statistics."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN))
-    stmt += (
-        lambda q: q.filter(StatisticsShortTerm.start >= start_time)
+    return lambda_stmt(
+        lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
+        .filter(StatisticsShortTerm.start >= start_time)
         .filter(StatisticsShortTerm.start < end_time)
         .group_by(StatisticsShortTerm.metadata_id)
         .order_by(StatisticsShortTerm.metadata_id)
     )
-    return stmt
 
 
 def compile_hourly_statistics(
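The summary-mean builder above had no conditional parts, so its incremental `stmt += lambda ...` composition collapses into a single `lambda_stmt` over the whole chain. A small sketch of the two equivalent shapes, with an illustrative `stats` table rather than the real statistics models:

from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, MetaData, Table
from sqlalchemy import lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata = MetaData()
stats = Table(
    "stats", metadata, Column("metadata_id", Integer), Column("start", DateTime)
)


def summary_stmt_incremental(
    start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
    """Compose the statement step by step; useful when steps are conditional."""
    stmt = lambda_stmt(lambda: select(stats.c.metadata_id))
    stmt += lambda q: q.where(stats.c.start >= start_time)
    stmt += lambda q: q.where(stats.c.start < end_time)
    return stmt


def summary_stmt_single(
    start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
    """One cached lambda, one code path; equivalent SQL for this fixed chain."""
    return lambda_stmt(
        lambda: select(stats.c.metadata_id)
        .where(stats.c.start >= start_time)
        .where(stats.c.start < end_time)
    )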
@@ -985,26 +985,43 @@ def _statistics_during_period_stmt(
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
-    table: type[Statistics | StatisticsShortTerm],
 ) -> StatementLambdaElement:
     """Prepare a database query for statistics during a given period.
 
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
-    if table == StatisticsShortTerm:
-        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    else:
-        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
-
-    stmt += lambda q: q.filter(table.start >= start_time)
-
+    stmt = lambda_stmt(
+        lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
+    )
     if end_time is not None:
-        stmt += lambda q: q.filter(table.start < end_time)
-
+        stmt += lambda q: q.filter(Statistics.start < end_time)
     if metadata_ids:
-        stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
+        stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids))
+    stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start)
+    return stmt
 
-    stmt += lambda q: q.order_by(table.metadata_id, table.start)
+
+def _statistics_during_period_stmt_short_term(
+    start_time: datetime,
+    end_time: datetime | None,
+    metadata_ids: list[int] | None,
+) -> StatementLambdaElement:
+    """Prepare a database query for short term statistics during a given period.
+
+    This prepares a lambda_stmt query, so we don't insert the parameters yet.
+    """
+    stmt = lambda_stmt(
+        lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
+            StatisticsShortTerm.start >= start_time
+        )
+    )
+    if end_time is not None:
+        stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time)
+    if metadata_ids:
+        stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
+    stmt += lambda q: q.order_by(
+        StatisticsShortTerm.metadata_id, StatisticsShortTerm.start
+    )
     return stmt
 
 
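This is the central change: instead of one builder that takes a `table` class and branches on it inside, each table gets its own builder with fixed lambdas, and the caller picks the builder once. A sketch of that calling pattern under assumed `metrics` / `metrics_short_term` tables, not the recorder's real schema:

from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, MetaData, Table
from sqlalchemy import lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata = MetaData()
metrics = Table("metrics", metadata, Column("id", Integer), Column("start", DateTime))
metrics_short = Table(
    "metrics_short_term", metadata, Column("id", Integer), Column("start", DateTime)
)


def _metrics_stmt(start_time: datetime) -> StatementLambdaElement:
    """Builder dedicated to the long-term table; its lambda never varies."""
    return lambda_stmt(lambda: select(metrics).where(metrics.c.start >= start_time))


def _metrics_short_term_stmt(start_time: datetime) -> StatementLambdaElement:
    """Builder dedicated to the short-term table."""
    return lambda_stmt(
        lambda: select(metrics_short).where(metrics_short.c.start >= start_time)
    )


def metrics_during_period(start_time: datetime, period: str) -> StatementLambdaElement:
    # The branch runs once here, in plain Python, instead of inside every cached lambda.
    if period == "5minute":
        return _metrics_short_term_stmt(start_time)
    return _metrics_stmt(start_time)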
@@ -1034,10 +1051,12 @@ def statistics_during_period(
 
     if period == "5minute":
         table = StatisticsShortTerm
+        stmt = _statistics_during_period_stmt_short_term(
+            start_time, end_time, metadata_ids
+        )
     else:
         table = Statistics
-
-    stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids, table)
+        stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids)
     stats = execute_stmt_lambda_element(session, stmt)
 
     if not stats:
@@ -1069,19 +1088,27 @@ def statistics_during_period(
 def _get_last_statistics_stmt(
     metadata_id: int,
     number_of_stats: int,
-    table: type[Statistics | StatisticsShortTerm],
 ) -> StatementLambdaElement:
     """Generate a statement for number_of_stats statistics for a given statistic_id."""
-    if table == StatisticsShortTerm:
-        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    else:
-        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
-    stmt += (
-        lambda q: q.filter_by(metadata_id=metadata_id)
-        .order_by(table.metadata_id, table.start.desc())
+    return lambda_stmt(
+        lambda: select(*QUERY_STATISTICS)
+        .filter_by(metadata_id=metadata_id)
+        .order_by(Statistics.metadata_id, Statistics.start.desc())
         .limit(number_of_stats)
     )
-    return stmt
+
+
+def _get_last_statistics_short_term_stmt(
+    metadata_id: int,
+    number_of_stats: int,
+) -> StatementLambdaElement:
+    """Generate a statement for number_of_stats short term statistics for a given statistic_id."""
+    return lambda_stmt(
+        lambda: select(*QUERY_STATISTICS_SHORT_TERM)
+        .filter_by(metadata_id=metadata_id)
+        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
+        .limit(number_of_stats)
+    )
 
 
 def _get_last_statistics(
@@ -1099,7 +1126,10 @@ def _get_last_statistics(
     if not metadata:
         return {}
     metadata_id = metadata[statistic_id][0]
-    stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table)
+    if table == Statistics:
+        stmt = _get_last_statistics_stmt(metadata_id, number_of_stats)
+    else:
+        stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats)
     stats = execute_stmt_lambda_element(session, stmt)
 
     if not stats:
@@ -1136,12 +1166,9 @@ def get_last_short_term_statistics(
     )
 
 
-def _latest_short_term_statistics_stmt(
-    metadata_ids: list[int],
-) -> StatementLambdaElement:
-    """Create the statement for finding the latest short term stat rows."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    most_recent_statistic_row = (
+def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
+    """Generate the subquery to find the most recent statistic row."""
+    return (
         select(
             StatisticsShortTerm.metadata_id,
             func.max(StatisticsShortTerm.start).label("start_max"),
@@ -1149,6 +1176,14 @@ def _latest_short_term_statistics_stmt(
         .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
         .group_by(StatisticsShortTerm.metadata_id)
     ).subquery()
+
+
+def _latest_short_term_statistics_stmt(
+    metadata_ids: list[int],
+) -> StatementLambdaElement:
+    """Create the statement for finding the latest short term stat rows."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
     stmt += lambda s: s.join(
         most_recent_statistic_row,
         (
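The extracted `_generate_most_recent_statistic_row` helper produces the classic "latest row per group" subquery that the statement then joins against. A standalone sketch of that pattern with illustrative table and column names, outside any lambda machinery:

from sqlalchemy import Column, DateTime, Integer, MetaData, Table, func, select

metadata = MetaData()
short_term = Table(
    "metrics_short_term",
    metadata,
    Column("metadata_id", Integer),
    Column("start", DateTime),
    Column("value", Integer),
)

# One row per metadata_id: the maximum start time seen for that group.
most_recent = (
    select(short_term.c.metadata_id, func.max(short_term.c.start).label("start_max"))
    .group_by(short_term.c.metadata_id)
    .subquery()
)

# Join the main table back on both columns to recover the full latest rows.
latest_rows = select(short_term).join(
    most_recent,
    (short_term.c.metadata_id == most_recent.c.metadata_id)
    & (short_term.c.start == most_recent.c.start_max),
)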
@@ -100,6 +100,15 @@ def test_compile_hourly_statistics(hass_recorder):
     stats = statistics_during_period(hass, zero, period="5minute")
     assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
 
+    # Test statistics_during_period with a far future start and end date
+    future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
+    stats = statistics_during_period(hass, future, end_time=future, period="5minute")
+    assert stats == {}
+
+    # Test statistics_during_period with a far future end date
+    stats = statistics_during_period(hass, zero, end_time=future, period="5minute")
+    assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
+
     stats = statistics_during_period(
         hass, zero, statistic_ids=["sensor.test2"], period="5minute"
     )
@@ -814,6 +823,59 @@ def test_monthly_statistics(hass_recorder, caplog, timezone):
         ]
     }
 
+    stats = statistics_during_period(
+        hass,
+        start_time=zero,
+        statistic_ids=["not", "the", "same", "test:total_energy_import"],
+        period="month",
+    )
+    sep_start = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00"))
+    sep_end = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
+    oct_start = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
+    oct_end = dt_util.as_utc(dt_util.parse_datetime("2021-11-01 00:00:00"))
+    assert stats == {
+        "test:total_energy_import": [
+            {
+                "statistic_id": "test:total_energy_import",
+                "start": sep_start.isoformat(),
+                "end": sep_end.isoformat(),
+                "max": None,
+                "mean": None,
+                "min": None,
+                "last_reset": None,
+                "state": approx(1.0),
+                "sum": approx(3.0),
+            },
+            {
+                "statistic_id": "test:total_energy_import",
+                "start": oct_start.isoformat(),
+                "end": oct_end.isoformat(),
+                "max": None,
+                "mean": None,
+                "min": None,
+                "last_reset": None,
+                "state": approx(3.0),
+                "sum": approx(5.0),
+            },
+        ]
+    }
+
+    # Use 5minute to ensure table switch works
+    stats = statistics_during_period(
+        hass,
+        start_time=zero,
+        statistic_ids=["test:total_energy_import", "with_other"],
+        period="5minute",
+    )
+    assert stats == {}
+
+    # Ensure future date has not data
+    future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
+    stats = statistics_during_period(
+        hass, start_time=future, end_time=future, period="month"
+    )
+    assert stats == {}
+
     dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))