Reduce branching in generated lambda_stmts (#73042)

pull/73063/head^2
J. Nick Koston 2022-06-05 18:13:31 -10:00 committed by GitHub
parent 3744edc512
commit c66b000d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 153 additions and 45 deletions

View File

@ -15,6 +15,7 @@ from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Subquery
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.websocket_api.const import ( from homeassistant.components.websocket_api.const import (
@ -485,6 +486,25 @@ def _get_states_for_entites_stmt(
return stmt return stmt
def _generate_most_recent_states_by_date(
    run_start: datetime,
    utc_point_in_time: datetime,
) -> Subquery:
    """Build the subquery holding each entity's newest last_updated value.

    Only rows with run_start <= last_updated < utc_point_in_time are
    considered. The rows are grouped by entity_id and the subquery exposes
    two labeled columns: max_entity_id and max_last_updated.
    """
    newest_per_entity = select(
        States.entity_id.label("max_entity_id"),
        func.max(States.last_updated).label("max_last_updated"),
    )
    # Lower bound is inclusive, upper bound is exclusive.
    in_window = (States.last_updated >= run_start) & (
        States.last_updated < utc_point_in_time
    )
    return newest_per_entity.filter(in_window).group_by(States.entity_id).subquery()
def _get_states_for_all_stmt( def _get_states_for_all_stmt(
schema_version: int, schema_version: int,
run_start: datetime, run_start: datetime,
@ -500,17 +520,8 @@ def _get_states_for_all_stmt(
# query, then filter out unwanted domains as well as applying the custom filter. # query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is # This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter. # not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = ( most_recent_states_by_date = _generate_most_recent_states_by_date(
select( run_start, utc_point_in_time
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
) )
stmt += lambda q: q.where( stmt += lambda q: q.where(
States.state_id States.state_id

View File

@ -20,6 +20,7 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal_column, true from sqlalchemy.sql.expression import literal_column, true
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Subquery
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
@ -484,14 +485,13 @@ def _compile_hourly_statistics_summary_mean_stmt(
start_time: datetime, end_time: datetime start_time: datetime, end_time: datetime
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Generate the summary mean statement for hourly statistics.""" """Generate the summary mean statement for hourly statistics."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)) return lambda_stmt(
stmt += ( lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
lambda q: q.filter(StatisticsShortTerm.start >= start_time) .filter(StatisticsShortTerm.start >= start_time)
.filter(StatisticsShortTerm.start < end_time) .filter(StatisticsShortTerm.start < end_time)
.group_by(StatisticsShortTerm.metadata_id) .group_by(StatisticsShortTerm.metadata_id)
.order_by(StatisticsShortTerm.metadata_id) .order_by(StatisticsShortTerm.metadata_id)
) )
return stmt
def compile_hourly_statistics( def compile_hourly_statistics(
@ -985,26 +985,43 @@ def _statistics_during_period_stmt(
start_time: datetime, start_time: datetime,
end_time: datetime | None, end_time: datetime | None,
metadata_ids: list[int] | None, metadata_ids: list[int] | None,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Prepare a database query for statistics during a given period. """Prepare a database query for statistics during a given period.
This prepares a lambda_stmt query, so we don't insert the parameters yet. This prepares a lambda_stmt query, so we don't insert the parameters yet.
""" """
if table == StatisticsShortTerm: stmt = lambda_stmt(
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
else: )
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
stmt += lambda q: q.filter(table.start >= start_time)
if end_time is not None: if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time) stmt += lambda q: q.filter(Statistics.start < end_time)
if metadata_ids: if metadata_ids:
stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids)) stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids))
stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start)
return stmt
stmt += lambda q: q.order_by(table.metadata_id, table.start)
def _statistics_during_period_stmt_short_term(
    start_time: datetime,
    end_time: datetime | None,
    metadata_ids: list[int] | None,
) -> StatementLambdaElement:
    """Build the lambda statement for short term statistics in a period.

    Rows are selected from start_time (inclusive). The exclusive upper bound
    is only added when end_time is given, and the metadata_id restriction is
    only added when metadata_ids is non-empty. Results are ordered by
    metadata_id and then by start.

    Only the lambda_stmt is assembled here; the query is not executed.
    """
    stmt = lambda_stmt(
        lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
            StatisticsShortTerm.start >= start_time
        )
    )

    # Each optional criterion is appended as its own lambda, in a fixed order.
    if end_time is not None:
        stmt += lambda query: query.filter(StatisticsShortTerm.start < end_time)

    if metadata_ids:
        stmt += lambda query: query.filter(
            StatisticsShortTerm.metadata_id.in_(metadata_ids)
        )

    stmt += lambda query: query.order_by(
        StatisticsShortTerm.metadata_id, StatisticsShortTerm.start
    )
    return stmt
@ -1034,10 +1051,12 @@ def statistics_during_period(
if period == "5minute": if period == "5minute":
table = StatisticsShortTerm table = StatisticsShortTerm
stmt = _statistics_during_period_stmt_short_term(
start_time, end_time, metadata_ids
)
else: else:
table = Statistics table = Statistics
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids)
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids, table)
stats = execute_stmt_lambda_element(session, stmt) stats = execute_stmt_lambda_element(session, stmt)
if not stats: if not stats:
@ -1069,19 +1088,27 @@ def statistics_during_period(
def _get_last_statistics_stmt( def _get_last_statistics_stmt(
metadata_id: int, metadata_id: int,
number_of_stats: int, number_of_stats: int,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Generate a statement for number_of_stats statistics for a given statistic_id.""" """Generate a statement for number_of_stats statistics for a given statistic_id."""
if table == StatisticsShortTerm: return lambda_stmt(
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) lambda: select(*QUERY_STATISTICS)
else: .filter_by(metadata_id=metadata_id)
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS)) .order_by(Statistics.metadata_id, Statistics.start.desc())
stmt += ( .limit(number_of_stats)
lambda q: q.filter_by(metadata_id=metadata_id) )
.order_by(table.metadata_id, table.start.desc())
def _get_last_statistics_short_term_stmt(
    metadata_id: int,
    number_of_stats: int,
) -> StatementLambdaElement:
    """Build the lambda statement for the newest short term rows of one id.

    Selects the rows matching metadata_id, ordered by metadata_id and then by
    start descending, limited to number_of_stats rows.
    """
    stmt = lambda_stmt(
        lambda: select(*QUERY_STATISTICS_SHORT_TERM)
        .filter_by(metadata_id=metadata_id)
        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
        .limit(number_of_stats)
    )
    return stmt
return stmt
def _get_last_statistics( def _get_last_statistics(
@ -1099,7 +1126,10 @@ def _get_last_statistics(
if not metadata: if not metadata:
return {} return {}
metadata_id = metadata[statistic_id][0] metadata_id = metadata[statistic_id][0]
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table) if table == Statistics:
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats)
else:
stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats)
stats = execute_stmt_lambda_element(session, stmt) stats = execute_stmt_lambda_element(session, stmt)
if not stats: if not stats:
@ -1136,12 +1166,9 @@ def get_last_short_term_statistics(
) )
def _latest_short_term_statistics_stmt( def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
metadata_ids: list[int], """Generate the subquery to find the most recent statistic row."""
) -> StatementLambdaElement: return (
"""Create the statement for finding the latest short term stat rows."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
most_recent_statistic_row = (
select( select(
StatisticsShortTerm.metadata_id, StatisticsShortTerm.metadata_id,
func.max(StatisticsShortTerm.start).label("start_max"), func.max(StatisticsShortTerm.start).label("start_max"),
@ -1149,6 +1176,14 @@ def _latest_short_term_statistics_stmt(
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
.group_by(StatisticsShortTerm.metadata_id) .group_by(StatisticsShortTerm.metadata_id)
).subquery() ).subquery()
def _latest_short_term_statistics_stmt(
metadata_ids: list[int],
) -> StatementLambdaElement:
"""Create the statement for finding the latest short term stat rows."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
stmt += lambda s: s.join( stmt += lambda s: s.join(
most_recent_statistic_row, most_recent_statistic_row,
( (

View File

@ -100,6 +100,15 @@ def test_compile_hourly_statistics(hass_recorder):
stats = statistics_during_period(hass, zero, period="5minute") stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2} assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
# Test statistics_during_period with a far future start and end date
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
stats = statistics_during_period(hass, future, end_time=future, period="5minute")
assert stats == {}
# Test statistics_during_period with a far future end date
stats = statistics_during_period(hass, zero, end_time=future, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
stats = statistics_during_period( stats = statistics_during_period(
hass, zero, statistic_ids=["sensor.test2"], period="5minute" hass, zero, statistic_ids=["sensor.test2"], period="5minute"
) )
@ -814,6 +823,59 @@ def test_monthly_statistics(hass_recorder, caplog, timezone):
] ]
} }
stats = statistics_during_period(
hass,
start_time=zero,
statistic_ids=["not", "the", "same", "test:total_energy_import"],
period="month",
)
sep_start = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00"))
sep_end = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
oct_start = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
oct_end = dt_util.as_utc(dt_util.parse_datetime("2021-11-01 00:00:00"))
assert stats == {
"test:total_energy_import": [
{
"statistic_id": "test:total_energy_import",
"start": sep_start.isoformat(),
"end": sep_end.isoformat(),
"max": None,
"mean": None,
"min": None,
"last_reset": None,
"state": approx(1.0),
"sum": approx(3.0),
},
{
"statistic_id": "test:total_energy_import",
"start": oct_start.isoformat(),
"end": oct_end.isoformat(),
"max": None,
"mean": None,
"min": None,
"last_reset": None,
"state": approx(3.0),
"sum": approx(5.0),
},
]
}
# Use 5minute to ensure table switch works
stats = statistics_during_period(
hass,
start_time=zero,
statistic_ids=["test:total_energy_import", "with_other"],
period="5minute",
)
assert stats == {}
# Ensure future date has no data
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
stats = statistics_during_period(
hass, start_time=future, end_time=future, period="month"
)
assert stats == {}
dt_util.set_default_time_zone(dt_util.get_time_zone("UTC")) dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))