From c66b000d34b2a22fbaef3dfd4f6ea4b7a807bacc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 5 Jun 2022 18:13:31 -1000 Subject: [PATCH] Reduce branching in generated lambda_stmts (#73042) --- homeassistant/components/recorder/history.py | 33 ++++-- .../components/recorder/statistics.py | 103 ++++++++++++------ tests/components/recorder/test_statistics.py | 62 +++++++++++ 3 files changed, 153 insertions(+), 45 deletions(-) diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index 49796bd0158..5dd5c0d3040 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -15,6 +15,7 @@ from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal from sqlalchemy.sql.lambdas import StatementLambdaElement +from sqlalchemy.sql.selectable import Subquery from homeassistant.components import recorder from homeassistant.components.websocket_api.const import ( @@ -485,6 +486,25 @@ def _get_states_for_entites_stmt( return stmt +def _generate_most_recent_states_by_date( + run_start: datetime, + utc_point_in_time: datetime, +) -> Subquery: + """Generate the sub query for the most recent states by data.""" + return ( + select( + States.entity_id.label("max_entity_id"), + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .group_by(States.entity_id) + .subquery() + ) + + def _get_states_for_all_stmt( schema_version: int, run_start: datetime, @@ -500,17 +520,8 @@ def _get_states_for_all_stmt( # query, then filter out unwanted domains as well as applying the custom filter. # This filtering can't be done in the inner query because the domain column is # not indexed and we can't control what's in the custom filter. 
- most_recent_states_by_date = ( - select( - States.entity_id.label("max_entity_id"), - func.max(States.last_updated).label("max_last_updated"), - ) - .filter( - (States.last_updated >= run_start) - & (States.last_updated < utc_point_in_time) - ) - .group_by(States.entity_id) - .subquery() + most_recent_states_by_date = _generate_most_recent_states_by_date( + run_start, utc_point_in_time ) stmt += lambda q: q.where( States.state_id diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 39fcb954ee9..012b34ec0ef 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -20,6 +20,7 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal_column, true from sqlalchemy.sql.lambdas import StatementLambdaElement +from sqlalchemy.sql.selectable import Subquery import voluptuous as vol from homeassistant.const import ( @@ -484,14 +485,13 @@ def _compile_hourly_statistics_summary_mean_stmt( start_time: datetime, end_time: datetime ) -> StatementLambdaElement: """Generate the summary mean statement for hourly statistics.""" - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)) - stmt += ( - lambda q: q.filter(StatisticsShortTerm.start >= start_time) + return lambda_stmt( + lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN) + .filter(StatisticsShortTerm.start >= start_time) .filter(StatisticsShortTerm.start < end_time) .group_by(StatisticsShortTerm.metadata_id) .order_by(StatisticsShortTerm.metadata_id) ) - return stmt def compile_hourly_statistics( @@ -985,26 +985,43 @@ def _statistics_during_period_stmt( start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, - table: type[Statistics | StatisticsShortTerm], ) -> StatementLambdaElement: """Prepare a database query for statistics during a given period. 
This prepares a lambda_stmt query, so we don't insert the parameters yet. """ - if table == StatisticsShortTerm: - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) - else: - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS)) - - stmt += lambda q: q.filter(table.start >= start_time) - + stmt = lambda_stmt( + lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time) + ) if end_time is not None: - stmt += lambda q: q.filter(table.start < end_time) - + stmt += lambda q: q.filter(Statistics.start < end_time) if metadata_ids: - stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids)) + stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids)) + stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start) + return stmt - stmt += lambda q: q.order_by(table.metadata_id, table.start) + +def _statistics_during_period_stmt_short_term( + start_time: datetime, + end_time: datetime | None, + metadata_ids: list[int] | None, +) -> StatementLambdaElement: + """Prepare a database query for short term statistics during a given period. + + This prepares a lambda_stmt query, so we don't insert the parameters yet. 
+ """ + stmt = lambda_stmt( + lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter( + StatisticsShortTerm.start >= start_time + ) + ) + if end_time is not None: + stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time) + if metadata_ids: + stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids)) + stmt += lambda q: q.order_by( + StatisticsShortTerm.metadata_id, StatisticsShortTerm.start + ) return stmt @@ -1034,10 +1051,12 @@ def statistics_during_period( if period == "5minute": table = StatisticsShortTerm + stmt = _statistics_during_period_stmt_short_term( + start_time, end_time, metadata_ids + ) else: table = Statistics - - stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids, table) + stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids) stats = execute_stmt_lambda_element(session, stmt) if not stats: @@ -1069,19 +1088,27 @@ def statistics_during_period( def _get_last_statistics_stmt( metadata_id: int, number_of_stats: int, - table: type[Statistics | StatisticsShortTerm], ) -> StatementLambdaElement: """Generate a statement for number_of_stats statistics for a given statistic_id.""" - if table == StatisticsShortTerm: - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) - else: - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS)) - stmt += ( - lambda q: q.filter_by(metadata_id=metadata_id) - .order_by(table.metadata_id, table.start.desc()) + return lambda_stmt( + lambda: select(*QUERY_STATISTICS) + .filter_by(metadata_id=metadata_id) + .order_by(Statistics.metadata_id, Statistics.start.desc()) + .limit(number_of_stats) + ) + + +def _get_last_statistics_short_term_stmt( + metadata_id: int, + number_of_stats: int, +) -> StatementLambdaElement: + """Generate a statement for number_of_stats short term statistics for a given statistic_id.""" + return lambda_stmt( + lambda: select(*QUERY_STATISTICS_SHORT_TERM) + .filter_by(metadata_id=metadata_id) + 
.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()) .limit(number_of_stats) ) - return stmt def _get_last_statistics( @@ -1099,7 +1126,10 @@ def _get_last_statistics( if not metadata: return {} metadata_id = metadata[statistic_id][0] - stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table) + if table == Statistics: + stmt = _get_last_statistics_stmt(metadata_id, number_of_stats) + else: + stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats) stats = execute_stmt_lambda_element(session, stmt) if not stats: @@ -1136,12 +1166,9 @@ def get_last_short_term_statistics( ) -def _latest_short_term_statistics_stmt( - metadata_ids: list[int], -) -> StatementLambdaElement: - """Create the statement for finding the latest short term stat rows.""" - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) - most_recent_statistic_row = ( +def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery: + """Generate the subquery to find the most recent statistic row.""" + return ( select( StatisticsShortTerm.metadata_id, func.max(StatisticsShortTerm.start).label("start_max"), @@ -1149,6 +1176,14 @@ def _latest_short_term_statistics_stmt( .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) .group_by(StatisticsShortTerm.metadata_id) ).subquery() + + +def _latest_short_term_statistics_stmt( + metadata_ids: list[int], +) -> StatementLambdaElement: + """Create the statement for finding the latest short term stat rows.""" + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) + most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids) stmt += lambda s: s.join( most_recent_statistic_row, ( diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 882f00d2940..97e64716f49 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -100,6 +100,15 @@ def 
test_compile_hourly_statistics(hass_recorder): stats = statistics_during_period(hass, zero, period="5minute") assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2} + # Test statistics_during_period with a far future start and end date + future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00")) + stats = statistics_during_period(hass, future, end_time=future, period="5minute") + assert stats == {} + + # Test statistics_during_period with a far future end date + stats = statistics_during_period(hass, zero, end_time=future, period="5minute") + assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2} + stats = statistics_during_period( hass, zero, statistic_ids=["sensor.test2"], period="5minute" ) @@ -814,6 +823,59 @@ def test_monthly_statistics(hass_recorder, caplog, timezone): ] } + stats = statistics_during_period( + hass, + start_time=zero, + statistic_ids=["not", "the", "same", "test:total_energy_import"], + period="month", + ) + sep_start = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00")) + sep_end = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00")) + oct_start = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00")) + oct_end = dt_util.as_utc(dt_util.parse_datetime("2021-11-01 00:00:00")) + assert stats == { + "test:total_energy_import": [ + { + "statistic_id": "test:total_energy_import", + "start": sep_start.isoformat(), + "end": sep_end.isoformat(), + "max": None, + "mean": None, + "min": None, + "last_reset": None, + "state": approx(1.0), + "sum": approx(3.0), + }, + { + "statistic_id": "test:total_energy_import", + "start": oct_start.isoformat(), + "end": oct_end.isoformat(), + "max": None, + "mean": None, + "min": None, + "last_reset": None, + "state": approx(3.0), + "sum": approx(5.0), + }, + ] + } + + # Use 5minute to ensure table switch works + stats = statistics_during_period( + hass, + start_time=zero, + statistic_ids=["test:total_energy_import", 
"with_other"], + period="5minute", + ) + assert stats == {} + + # Ensure future date has no data + future = dt_util.as_utc(dt_util.parse_datetime("2221-11-01 00:00:00")) + stats = statistics_during_period( + hass, start_time=future, end_time=future, period="month" + ) + assert stats == {} + dt_util.set_default_time_zone(dt_util.get_time_zone("UTC"))