Reduce branching in generated lambda_stmts (#73042)

pull/73063/head^2
J. Nick Koston 2022-06-05 18:13:31 -10:00 committed by GitHub
parent 3744edc512
commit c66b000d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 153 additions and 45 deletions

View File

@ -15,6 +15,7 @@ from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Subquery
from homeassistant.components import recorder
from homeassistant.components.websocket_api.const import (
@ -485,6 +486,25 @@ def _get_states_for_entites_stmt(
return stmt
def _generate_most_recent_states_by_date(
run_start: datetime,
utc_point_in_time: datetime,
) -> Subquery:
"""Generate the sub query for the most recent states by data."""
return (
select(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
)
def _get_states_for_all_stmt(
schema_version: int,
run_start: datetime,
@ -500,17 +520,8 @@ def _get_states_for_all_stmt(
# query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = (
select(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
most_recent_states_by_date = _generate_most_recent_states_by_date(
run_start, utc_point_in_time
)
stmt += lambda q: q.where(
States.state_id

View File

@ -20,6 +20,7 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal_column, true
from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Subquery
import voluptuous as vol
from homeassistant.const import (
@ -484,14 +485,13 @@ def _compile_hourly_statistics_summary_mean_stmt(
start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
"""Generate the summary mean statement for hourly statistics."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN))
stmt += (
lambda q: q.filter(StatisticsShortTerm.start >= start_time)
return lambda_stmt(
lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
.filter(StatisticsShortTerm.start >= start_time)
.filter(StatisticsShortTerm.start < end_time)
.group_by(StatisticsShortTerm.metadata_id)
.order_by(StatisticsShortTerm.metadata_id)
)
return stmt
def compile_hourly_statistics(
@ -985,26 +985,43 @@ def _statistics_during_period_stmt(
start_time: datetime,
end_time: datetime | None,
metadata_ids: list[int] | None,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement:
"""Prepare a database query for statistics during a given period.
This prepares a lambda_stmt query, so we don't insert the parameters yet.
"""
if table == StatisticsShortTerm:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
else:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
stmt += lambda q: q.filter(table.start >= start_time)
stmt = lambda_stmt(
lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
)
if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time)
stmt += lambda q: q.filter(Statistics.start < end_time)
if metadata_ids:
stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids))
stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start)
return stmt
stmt += lambda q: q.order_by(table.metadata_id, table.start)
def _statistics_during_period_stmt_short_term(
start_time: datetime,
end_time: datetime | None,
metadata_ids: list[int] | None,
) -> StatementLambdaElement:
"""Prepare a database query for short term statistics during a given period.
This prepares a lambda_stmt query, so we don't insert the parameters yet.
"""
stmt = lambda_stmt(
lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
StatisticsShortTerm.start >= start_time
)
)
if end_time is not None:
stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time)
if metadata_ids:
stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
stmt += lambda q: q.order_by(
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start
)
return stmt
@ -1034,10 +1051,12 @@ def statistics_during_period(
if period == "5minute":
table = StatisticsShortTerm
stmt = _statistics_during_period_stmt_short_term(
start_time, end_time, metadata_ids
)
else:
table = Statistics
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids, table)
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids)
stats = execute_stmt_lambda_element(session, stmt)
if not stats:
@ -1069,19 +1088,27 @@ def statistics_during_period(
def _get_last_statistics_stmt(
metadata_id: int,
number_of_stats: int,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement:
"""Generate a statement for number_of_stats statistics for a given statistic_id."""
if table == StatisticsShortTerm:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
else:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
stmt += (
lambda q: q.filter_by(metadata_id=metadata_id)
.order_by(table.metadata_id, table.start.desc())
return lambda_stmt(
lambda: select(*QUERY_STATISTICS)
.filter_by(metadata_id=metadata_id)
.order_by(Statistics.metadata_id, Statistics.start.desc())
.limit(number_of_stats)
)
def _get_last_statistics_short_term_stmt(
metadata_id: int,
number_of_stats: int,
) -> StatementLambdaElement:
"""Generate a statement for number_of_stats short term statistics for a given statistic_id."""
return lambda_stmt(
lambda: select(*QUERY_STATISTICS_SHORT_TERM)
.filter_by(metadata_id=metadata_id)
.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
.limit(number_of_stats)
)
return stmt
def _get_last_statistics(
@ -1099,7 +1126,10 @@ def _get_last_statistics(
if not metadata:
return {}
metadata_id = metadata[statistic_id][0]
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table)
if table == Statistics:
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats)
else:
stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats)
stats = execute_stmt_lambda_element(session, stmt)
if not stats:
@ -1136,12 +1166,9 @@ def get_last_short_term_statistics(
)
def _latest_short_term_statistics_stmt(
metadata_ids: list[int],
) -> StatementLambdaElement:
"""Create the statement for finding the latest short term stat rows."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
most_recent_statistic_row = (
def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
"""Generate the subquery to find the most recent statistic row."""
return (
select(
StatisticsShortTerm.metadata_id,
func.max(StatisticsShortTerm.start).label("start_max"),
@ -1149,6 +1176,14 @@ def _latest_short_term_statistics_stmt(
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
.group_by(StatisticsShortTerm.metadata_id)
).subquery()
def _latest_short_term_statistics_stmt(
metadata_ids: list[int],
) -> StatementLambdaElement:
"""Create the statement for finding the latest short term stat rows."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
stmt += lambda s: s.join(
most_recent_statistic_row,
(

View File

@ -100,6 +100,15 @@ def test_compile_hourly_statistics(hass_recorder):
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
# Test statistics_during_period with a far future start and end date
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
stats = statistics_during_period(hass, future, end_time=future, period="5minute")
assert stats == {}
# Test statistics_during_period with a far future end date
stats = statistics_during_period(hass, zero, end_time=future, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
stats = statistics_during_period(
hass, zero, statistic_ids=["sensor.test2"], period="5minute"
)
@ -814,6 +823,59 @@ def test_monthly_statistics(hass_recorder, caplog, timezone):
]
}
stats = statistics_during_period(
hass,
start_time=zero,
statistic_ids=["not", "the", "same", "test:total_energy_import"],
period="month",
)
sep_start = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00"))
sep_end = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
oct_start = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
oct_end = dt_util.as_utc(dt_util.parse_datetime("2021-11-01 00:00:00"))
assert stats == {
"test:total_energy_import": [
{
"statistic_id": "test:total_energy_import",
"start": sep_start.isoformat(),
"end": sep_end.isoformat(),
"max": None,
"mean": None,
"min": None,
"last_reset": None,
"state": approx(1.0),
"sum": approx(3.0),
},
{
"statistic_id": "test:total_energy_import",
"start": oct_start.isoformat(),
"end": oct_end.isoformat(),
"max": None,
"mean": None,
"min": None,
"last_reset": None,
"state": approx(3.0),
"sum": approx(5.0),
},
]
}
# Use 5minute to ensure table switch works
stats = statistics_during_period(
hass,
start_time=zero,
statistic_ids=["test:total_energy_import", "with_other"],
period="5minute",
)
assert stats == {}
# Ensure future date has not data
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
stats = statistics_during_period(
hass, start_time=future, end_time=future, period="month"
)
assert stats == {}
dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))