Reduce branching in generated lambda_stmts (#73042)
parent
3744edc512
commit
c66b000d34
|
@ -15,6 +15,7 @@ from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
from sqlalchemy.sql.expression import literal
|
from sqlalchemy.sql.expression import literal
|
||||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||||
|
from sqlalchemy.sql.selectable import Subquery
|
||||||
|
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import recorder
|
||||||
from homeassistant.components.websocket_api.const import (
|
from homeassistant.components.websocket_api.const import (
|
||||||
|
@ -485,6 +486,25 @@ def _get_states_for_entites_stmt(
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_most_recent_states_by_date(
|
||||||
|
run_start: datetime,
|
||||||
|
utc_point_in_time: datetime,
|
||||||
|
) -> Subquery:
|
||||||
|
"""Generate the sub query for the most recent states by data."""
|
||||||
|
return (
|
||||||
|
select(
|
||||||
|
States.entity_id.label("max_entity_id"),
|
||||||
|
func.max(States.last_updated).label("max_last_updated"),
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
(States.last_updated >= run_start)
|
||||||
|
& (States.last_updated < utc_point_in_time)
|
||||||
|
)
|
||||||
|
.group_by(States.entity_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_states_for_all_stmt(
|
def _get_states_for_all_stmt(
|
||||||
schema_version: int,
|
schema_version: int,
|
||||||
run_start: datetime,
|
run_start: datetime,
|
||||||
|
@ -500,17 +520,8 @@ def _get_states_for_all_stmt(
|
||||||
# query, then filter out unwanted domains as well as applying the custom filter.
|
# query, then filter out unwanted domains as well as applying the custom filter.
|
||||||
# This filtering can't be done in the inner query because the domain column is
|
# This filtering can't be done in the inner query because the domain column is
|
||||||
# not indexed and we can't control what's in the custom filter.
|
# not indexed and we can't control what's in the custom filter.
|
||||||
most_recent_states_by_date = (
|
most_recent_states_by_date = _generate_most_recent_states_by_date(
|
||||||
select(
|
run_start, utc_point_in_time
|
||||||
States.entity_id.label("max_entity_id"),
|
|
||||||
func.max(States.last_updated).label("max_last_updated"),
|
|
||||||
)
|
|
||||||
.filter(
|
|
||||||
(States.last_updated >= run_start)
|
|
||||||
& (States.last_updated < utc_point_in_time)
|
|
||||||
)
|
|
||||||
.group_by(States.entity_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
)
|
||||||
stmt += lambda q: q.where(
|
stmt += lambda q: q.where(
|
||||||
States.state_id
|
States.state_id
|
||||||
|
|
|
@ -20,6 +20,7 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
from sqlalchemy.sql.expression import literal_column, true
|
from sqlalchemy.sql.expression import literal_column, true
|
||||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||||
|
from sqlalchemy.sql.selectable import Subquery
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -484,14 +485,13 @@ def _compile_hourly_statistics_summary_mean_stmt(
|
||||||
start_time: datetime, end_time: datetime
|
start_time: datetime, end_time: datetime
|
||||||
) -> StatementLambdaElement:
|
) -> StatementLambdaElement:
|
||||||
"""Generate the summary mean statement for hourly statistics."""
|
"""Generate the summary mean statement for hourly statistics."""
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN))
|
return lambda_stmt(
|
||||||
stmt += (
|
lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
|
||||||
lambda q: q.filter(StatisticsShortTerm.start >= start_time)
|
.filter(StatisticsShortTerm.start >= start_time)
|
||||||
.filter(StatisticsShortTerm.start < end_time)
|
.filter(StatisticsShortTerm.start < end_time)
|
||||||
.group_by(StatisticsShortTerm.metadata_id)
|
.group_by(StatisticsShortTerm.metadata_id)
|
||||||
.order_by(StatisticsShortTerm.metadata_id)
|
.order_by(StatisticsShortTerm.metadata_id)
|
||||||
)
|
)
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def compile_hourly_statistics(
|
def compile_hourly_statistics(
|
||||||
|
@ -985,26 +985,43 @@ def _statistics_during_period_stmt(
|
||||||
start_time: datetime,
|
start_time: datetime,
|
||||||
end_time: datetime | None,
|
end_time: datetime | None,
|
||||||
metadata_ids: list[int] | None,
|
metadata_ids: list[int] | None,
|
||||||
table: type[Statistics | StatisticsShortTerm],
|
|
||||||
) -> StatementLambdaElement:
|
) -> StatementLambdaElement:
|
||||||
"""Prepare a database query for statistics during a given period.
|
"""Prepare a database query for statistics during a given period.
|
||||||
|
|
||||||
This prepares a lambda_stmt query, so we don't insert the parameters yet.
|
This prepares a lambda_stmt query, so we don't insert the parameters yet.
|
||||||
"""
|
"""
|
||||||
if table == StatisticsShortTerm:
|
stmt = lambda_stmt(
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
|
lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
|
||||||
else:
|
)
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
|
|
||||||
|
|
||||||
stmt += lambda q: q.filter(table.start >= start_time)
|
|
||||||
|
|
||||||
if end_time is not None:
|
if end_time is not None:
|
||||||
stmt += lambda q: q.filter(table.start < end_time)
|
stmt += lambda q: q.filter(Statistics.start < end_time)
|
||||||
|
|
||||||
if metadata_ids:
|
if metadata_ids:
|
||||||
stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
|
stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids))
|
||||||
|
stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start)
|
||||||
|
return stmt
|
||||||
|
|
||||||
stmt += lambda q: q.order_by(table.metadata_id, table.start)
|
|
||||||
|
def _statistics_during_period_stmt_short_term(
|
||||||
|
start_time: datetime,
|
||||||
|
end_time: datetime | None,
|
||||||
|
metadata_ids: list[int] | None,
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
"""Prepare a database query for short term statistics during a given period.
|
||||||
|
|
||||||
|
This prepares a lambda_stmt query, so we don't insert the parameters yet.
|
||||||
|
"""
|
||||||
|
stmt = lambda_stmt(
|
||||||
|
lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
|
||||||
|
StatisticsShortTerm.start >= start_time
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if end_time is not None:
|
||||||
|
stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time)
|
||||||
|
if metadata_ids:
|
||||||
|
stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
|
||||||
|
stmt += lambda q: q.order_by(
|
||||||
|
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start
|
||||||
|
)
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
@ -1034,10 +1051,12 @@ def statistics_during_period(
|
||||||
|
|
||||||
if period == "5minute":
|
if period == "5minute":
|
||||||
table = StatisticsShortTerm
|
table = StatisticsShortTerm
|
||||||
|
stmt = _statistics_during_period_stmt_short_term(
|
||||||
|
start_time, end_time, metadata_ids
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
table = Statistics
|
table = Statistics
|
||||||
|
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids)
|
||||||
stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids, table)
|
|
||||||
stats = execute_stmt_lambda_element(session, stmt)
|
stats = execute_stmt_lambda_element(session, stmt)
|
||||||
|
|
||||||
if not stats:
|
if not stats:
|
||||||
|
@ -1069,19 +1088,27 @@ def statistics_during_period(
|
||||||
def _get_last_statistics_stmt(
|
def _get_last_statistics_stmt(
|
||||||
metadata_id: int,
|
metadata_id: int,
|
||||||
number_of_stats: int,
|
number_of_stats: int,
|
||||||
table: type[Statistics | StatisticsShortTerm],
|
|
||||||
) -> StatementLambdaElement:
|
) -> StatementLambdaElement:
|
||||||
"""Generate a statement for number_of_stats statistics for a given statistic_id."""
|
"""Generate a statement for number_of_stats statistics for a given statistic_id."""
|
||||||
if table == StatisticsShortTerm:
|
return lambda_stmt(
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
|
lambda: select(*QUERY_STATISTICS)
|
||||||
else:
|
.filter_by(metadata_id=metadata_id)
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
|
.order_by(Statistics.metadata_id, Statistics.start.desc())
|
||||||
stmt += (
|
.limit(number_of_stats)
|
||||||
lambda q: q.filter_by(metadata_id=metadata_id)
|
)
|
||||||
.order_by(table.metadata_id, table.start.desc())
|
|
||||||
|
|
||||||
|
def _get_last_statistics_short_term_stmt(
|
||||||
|
metadata_id: int,
|
||||||
|
number_of_stats: int,
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
"""Generate a statement for number_of_stats short term statistics for a given statistic_id."""
|
||||||
|
return lambda_stmt(
|
||||||
|
lambda: select(*QUERY_STATISTICS_SHORT_TERM)
|
||||||
|
.filter_by(metadata_id=metadata_id)
|
||||||
|
.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
|
||||||
.limit(number_of_stats)
|
.limit(number_of_stats)
|
||||||
)
|
)
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def _get_last_statistics(
|
def _get_last_statistics(
|
||||||
|
@ -1099,7 +1126,10 @@ def _get_last_statistics(
|
||||||
if not metadata:
|
if not metadata:
|
||||||
return {}
|
return {}
|
||||||
metadata_id = metadata[statistic_id][0]
|
metadata_id = metadata[statistic_id][0]
|
||||||
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table)
|
if table == Statistics:
|
||||||
|
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats)
|
||||||
|
else:
|
||||||
|
stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats)
|
||||||
stats = execute_stmt_lambda_element(session, stmt)
|
stats = execute_stmt_lambda_element(session, stmt)
|
||||||
|
|
||||||
if not stats:
|
if not stats:
|
||||||
|
@ -1136,12 +1166,9 @@ def get_last_short_term_statistics(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _latest_short_term_statistics_stmt(
|
def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
|
||||||
metadata_ids: list[int],
|
"""Generate the subquery to find the most recent statistic row."""
|
||||||
) -> StatementLambdaElement:
|
return (
|
||||||
"""Create the statement for finding the latest short term stat rows."""
|
|
||||||
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
|
|
||||||
most_recent_statistic_row = (
|
|
||||||
select(
|
select(
|
||||||
StatisticsShortTerm.metadata_id,
|
StatisticsShortTerm.metadata_id,
|
||||||
func.max(StatisticsShortTerm.start).label("start_max"),
|
func.max(StatisticsShortTerm.start).label("start_max"),
|
||||||
|
@ -1149,6 +1176,14 @@ def _latest_short_term_statistics_stmt(
|
||||||
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
|
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
|
||||||
.group_by(StatisticsShortTerm.metadata_id)
|
.group_by(StatisticsShortTerm.metadata_id)
|
||||||
).subquery()
|
).subquery()
|
||||||
|
|
||||||
|
|
||||||
|
def _latest_short_term_statistics_stmt(
|
||||||
|
metadata_ids: list[int],
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
"""Create the statement for finding the latest short term stat rows."""
|
||||||
|
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
|
||||||
|
most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
|
||||||
stmt += lambda s: s.join(
|
stmt += lambda s: s.join(
|
||||||
most_recent_statistic_row,
|
most_recent_statistic_row,
|
||||||
(
|
(
|
||||||
|
|
|
@ -100,6 +100,15 @@ def test_compile_hourly_statistics(hass_recorder):
|
||||||
stats = statistics_during_period(hass, zero, period="5minute")
|
stats = statistics_during_period(hass, zero, period="5minute")
|
||||||
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
|
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
|
||||||
|
|
||||||
|
# Test statistics_during_period with a far future start and end date
|
||||||
|
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
|
||||||
|
stats = statistics_during_period(hass, future, end_time=future, period="5minute")
|
||||||
|
assert stats == {}
|
||||||
|
|
||||||
|
# Test statistics_during_period with a far future end date
|
||||||
|
stats = statistics_during_period(hass, zero, end_time=future, period="5minute")
|
||||||
|
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
|
||||||
|
|
||||||
stats = statistics_during_period(
|
stats = statistics_during_period(
|
||||||
hass, zero, statistic_ids=["sensor.test2"], period="5minute"
|
hass, zero, statistic_ids=["sensor.test2"], period="5minute"
|
||||||
)
|
)
|
||||||
|
@ -814,6 +823,59 @@ def test_monthly_statistics(hass_recorder, caplog, timezone):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stats = statistics_during_period(
|
||||||
|
hass,
|
||||||
|
start_time=zero,
|
||||||
|
statistic_ids=["not", "the", "same", "test:total_energy_import"],
|
||||||
|
period="month",
|
||||||
|
)
|
||||||
|
sep_start = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00"))
|
||||||
|
sep_end = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
|
||||||
|
oct_start = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00"))
|
||||||
|
oct_end = dt_util.as_utc(dt_util.parse_datetime("2021-11-01 00:00:00"))
|
||||||
|
assert stats == {
|
||||||
|
"test:total_energy_import": [
|
||||||
|
{
|
||||||
|
"statistic_id": "test:total_energy_import",
|
||||||
|
"start": sep_start.isoformat(),
|
||||||
|
"end": sep_end.isoformat(),
|
||||||
|
"max": None,
|
||||||
|
"mean": None,
|
||||||
|
"min": None,
|
||||||
|
"last_reset": None,
|
||||||
|
"state": approx(1.0),
|
||||||
|
"sum": approx(3.0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"statistic_id": "test:total_energy_import",
|
||||||
|
"start": oct_start.isoformat(),
|
||||||
|
"end": oct_end.isoformat(),
|
||||||
|
"max": None,
|
||||||
|
"mean": None,
|
||||||
|
"min": None,
|
||||||
|
"last_reset": None,
|
||||||
|
"state": approx(3.0),
|
||||||
|
"sum": approx(5.0),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use 5minute to ensure table switch works
|
||||||
|
stats = statistics_during_period(
|
||||||
|
hass,
|
||||||
|
start_time=zero,
|
||||||
|
statistic_ids=["test:total_energy_import", "with_other"],
|
||||||
|
period="5minute",
|
||||||
|
)
|
||||||
|
assert stats == {}
|
||||||
|
|
||||||
|
# Ensure future date has not data
|
||||||
|
future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00"))
|
||||||
|
stats = statistics_during_period(
|
||||||
|
hass, start_time=future, end_time=future, period="month"
|
||||||
|
)
|
||||||
|
assert stats == {}
|
||||||
|
|
||||||
dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))
|
dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue