"""Statistics helper.""" from __future__ import annotations from collections import defaultdict from collections.abc import Callable, Iterable, Sequence import contextlib import dataclasses from datetime import datetime, timedelta from functools import lru_cache, partial from itertools import chain, groupby import logging from operator import itemgetter import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text from sqlalchemy.engine.row import Row from sqlalchemy.exc import SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session from sqlalchemy.sql.lambdas import StatementLambdaElement import voluptuous as vol from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT from homeassistant.core import HomeAssistant, callback, valid_entity_id from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.typing import UNDEFINED, UndefinedType from homeassistant.util import dt as dt_util from homeassistant.util.unit_conversion import ( BaseUnitConverter, DataRateConverter, DistanceConverter, ElectricCurrentConverter, ElectricPotentialConverter, EnergyConverter, InformationConverter, MassConverter, PowerConverter, PressureConverter, SpeedConverter, TemperatureConverter, UnitlessRatioConverter, VolumeConverter, ) from .const import ( DOMAIN, EVENT_RECORDER_5MIN_STATISTICS_GENERATED, EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, INTEGRATION_PLATFORM_COMPILE_STATISTICS, INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, INTEGRATION_PLATFORM_VALIDATE_STATISTICS, SupportedDialect, ) from .db_schema import ( STATISTICS_TABLES, Statistics, StatisticsBase, StatisticsRuns, StatisticsShortTerm, ) from .models import ( StatisticData, StatisticDataTimestamp, StatisticMetaData, StatisticResult, datetime_to_timestamp_or_none, process_timestamp, ) from .util import ( execute, execute_stmt_lambda_element, get_instance, retryable_database_job, session_scope, ) if TYPE_CHECKING: from . import Recorder QUERY_STATISTICS = ( Statistics.metadata_id, Statistics.start_ts, Statistics.mean, Statistics.min, Statistics.max, Statistics.last_reset_ts, Statistics.state, Statistics.sum, ) QUERY_STATISTICS_SHORT_TERM = ( StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts, StatisticsShortTerm.mean, StatisticsShortTerm.min, StatisticsShortTerm.max, StatisticsShortTerm.last_reset_ts, StatisticsShortTerm.state, StatisticsShortTerm.sum, ) QUERY_STATISTICS_SUMMARY_MEAN = ( StatisticsShortTerm.metadata_id, func.avg(StatisticsShortTerm.mean), # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable func.min(StatisticsShortTerm.min), # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable func.max(StatisticsShortTerm.max), ) QUERY_STATISTICS_SUMMARY_SUM = ( StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts, StatisticsShortTerm.last_reset_ts, StatisticsShortTerm.state, StatisticsShortTerm.sum, func.row_number() .over( # type: ignore[no-untyped-call] partition_by=StatisticsShortTerm.metadata_id, order_by=StatisticsShortTerm.start_ts.desc(), ) .label("rownum"), ) STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = { **{unit: DataRateConverter for unit in DataRateConverter.VALID_UNITS}, **{unit: DistanceConverter for unit in DistanceConverter.VALID_UNITS}, **{unit: ElectricCurrentConverter for unit in ElectricCurrentConverter.VALID_UNITS}, **{ unit: ElectricPotentialConverter for unit in ElectricPotentialConverter.VALID_UNITS }, **{unit: EnergyConverter for unit in EnergyConverter.VALID_UNITS}, **{unit: InformationConverter for unit in InformationConverter.VALID_UNITS}, **{unit: MassConverter for unit in MassConverter.VALID_UNITS}, **{unit: PowerConverter for unit in PowerConverter.VALID_UNITS}, **{unit: PressureConverter for unit in PressureConverter.VALID_UNITS}, **{unit: SpeedConverter for unit in SpeedConverter.VALID_UNITS}, **{unit: TemperatureConverter for unit in TemperatureConverter.VALID_UNITS}, **{unit: UnitlessRatioConverter for unit in UnitlessRatioConverter.VALID_UNITS}, **{unit: VolumeConverter for unit in VolumeConverter.VALID_UNITS}, } _LOGGER = logging.getLogger(__name__) class BaseStatisticsRow(TypedDict, total=False): """A processed row of statistic data.""" start: float class StatisticsRow(BaseStatisticsRow, total=False): """A processed row of statistic data.""" end: float last_reset: float | None state: float | None sum: float | None min: float | None max: float | None mean: float | None change: float | None def _get_unit_class(unit: str | None) -> str | None: """Get corresponding unit class from from the statistics unit.""" if converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(unit): return converter.UNIT_CLASS return None def get_display_unit( hass: HomeAssistant, statistic_id: str, statistic_unit: str | None, ) -> str | None: """Return the unit which the statistic will be displayed in.""" if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: return statistic_unit state_unit: str | None = statistic_unit if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state_unit == statistic_unit or state_unit not in converter.VALID_UNITS: # Guard against invalid state unit in the DB return statistic_unit return state_unit def _get_statistic_to_display_unit_converter( statistic_unit: str | None, state_unit: str | None, requested_units: dict[str, str] | None, ) -> Callable[[float | None], float | None] | None: """Prepare a converter from the statistics unit to display unit.""" if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: return None display_unit: str | None unit_class = converter.UNIT_CLASS if requested_units and unit_class in requested_units: display_unit = requested_units[unit_class] else: display_unit = state_unit if display_unit not in converter.VALID_UNITS: # Guard against invalid state unit in the DB return None if display_unit == statistic_unit: return None convert = converter.convert def _from_normalized_unit(val: float | None) -> float | None: """Return val.""" if val is None: return val return convert(val, statistic_unit, display_unit) return _from_normalized_unit def _get_display_to_statistic_unit_converter( display_unit: str | None, statistic_unit: str | None, ) -> Callable[[float], float]: """Prepare a converter from the display unit to the statistics unit.""" def no_conversion(val: float) -> float: """Return val.""" return val if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: return no_conversion return partial(converter.convert, from_unit=display_unit, to_unit=statistic_unit) def _get_unit_converter( from_unit: str, to_unit: str ) -> Callable[[float | None], float | None]: """Prepare a converter from a unit to another unit.""" def convert_units( val: float | None, conv: type[BaseUnitConverter], from_unit: str, to_unit: str ) -> float | None: """Return converted val.""" if val is None: return val return conv.convert(val, from_unit=from_unit, to_unit=to_unit) for conv in STATISTIC_UNIT_TO_UNIT_CONVERTER.values(): if from_unit in conv.VALID_UNITS and to_unit in conv.VALID_UNITS: return partial( convert_units, conv=conv, from_unit=from_unit, to_unit=to_unit ) raise HomeAssistantError def can_convert_units(from_unit: str | None, to_unit: str | None) -> bool: """Return True if it's possible to convert from from_unit to to_unit.""" for converter in STATISTIC_UNIT_TO_UNIT_CONVERTER.values(): if from_unit in converter.VALID_UNITS and to_unit in converter.VALID_UNITS: return True return False @dataclasses.dataclass class PlatformCompiledStatistics: """Compiled Statistics from a platform.""" platform_stats: list[StatisticResult] current_metadata: dict[str, tuple[int, StatisticMetaData]] def split_statistic_id(entity_id: str) -> list[str]: """Split a state entity ID into domain and object ID.""" return entity_id.split(":", 1) VALID_STATISTIC_ID = re.compile(r"^(?!.+__)(?!_)[\da-z_]+(? bool: """Test if a statistic ID is a valid format. Format: : where both are slugs. """ return VALID_STATISTIC_ID.match(statistic_id) is not None def validate_statistic_id(value: str) -> str: """Validate statistic ID.""" if valid_statistic_id(value): return value raise vol.Invalid(f"Statistics ID {value} is an invalid statistic ID") @dataclasses.dataclass class ValidationIssue: """Error or warning message.""" type: str data: dict[str, str | None] | None = None def as_dict(self) -> dict: """Return dictionary version.""" return dataclasses.asdict(self) def get_start_time() -> datetime: """Return start time.""" now = dt_util.utcnow() current_period_minutes = now.minute - now.minute % 5 current_period = now.replace(minute=current_period_minutes, second=0, microsecond=0) last_period = current_period - timedelta(minutes=5) return last_period def _compile_hourly_statistics_summary_mean_stmt( start_time_ts: float, end_time_ts: float ) -> StatementLambdaElement: """Generate the summary mean statement for hourly statistics.""" return lambda_stmt( lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN) .filter(StatisticsShortTerm.start_ts >= start_time_ts) .filter(StatisticsShortTerm.start_ts < end_time_ts) .group_by(StatisticsShortTerm.metadata_id) .order_by(StatisticsShortTerm.metadata_id) ) def _compile_hourly_statistics_last_sum_stmt( start_time_ts: float, end_time_ts: float ) -> StatementLambdaElement: """Generate the summary mean statement for hourly statistics.""" return lambda_stmt( lambda: select( subquery := ( select(*QUERY_STATISTICS_SUMMARY_SUM) .filter(StatisticsShortTerm.start_ts >= start_time_ts) .filter(StatisticsShortTerm.start_ts < end_time_ts) .subquery() ) ) .filter(subquery.c.rownum == 1) .order_by(subquery.c.metadata_id) ) def _compile_hourly_statistics(session: Session, start: datetime) -> None: """Compile hourly statistics. This will summarize 5-minute statistics for one hour: - average, min max is computed by a database query - sum is taken from the last 5-minute entry during the hour """ start_time = start.replace(minute=0) start_time_ts = start_time.timestamp() end_time = start_time + timedelta(hours=1) end_time_ts = end_time.timestamp() # Compute last hour's average, min, max summary: dict[int, StatisticDataTimestamp] = {} stmt = _compile_hourly_statistics_summary_mean_stmt(start_time_ts, end_time_ts) stats = execute_stmt_lambda_element(session, stmt) if stats: for stat in stats: metadata_id, _mean, _min, _max = stat summary[metadata_id] = { "start_ts": start_time_ts, "mean": _mean, "min": _min, "max": _max, } stmt = _compile_hourly_statistics_last_sum_stmt(start_time_ts, end_time_ts) # Get last hour's last sum stats = execute_stmt_lambda_element(session, stmt) if stats: for stat in stats: metadata_id, start, last_reset_ts, state, _sum, _ = stat if metadata_id in summary: summary[metadata_id].update( { "last_reset_ts": last_reset_ts, "state": state, "sum": _sum, } ) else: summary[metadata_id] = { "start_ts": start_time_ts, "last_reset_ts": last_reset_ts, "state": state, "sum": _sum, } # Insert compiled hourly statistics in the database session.add_all( Statistics.from_stats_ts(metadata_id, summary_item) for metadata_id, summary_item in summary.items() ) @retryable_database_job("compile missing statistics") def compile_missing_statistics(instance: Recorder) -> bool: """Compile missing statistics.""" now = dt_util.utcnow() period_size = 5 last_period_minutes = now.minute - now.minute % period_size last_period = now.replace(minute=last_period_minutes, second=0, microsecond=0) start = now - timedelta(days=instance.keep_days) start = start.replace(minute=0, second=0, microsecond=0) # Commit every 12 hours of data commit_interval = 60 / period_size * 12 with session_scope( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: # Find the newest statistics run, if any # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable if last_run := session.query(func.max(StatisticsRuns.start)).scalar(): start = max(start, process_timestamp(last_run) + timedelta(minutes=5)) periods_without_commit = 0 while start < last_period: periods_without_commit += 1 end = start + timedelta(minutes=period_size) _LOGGER.debug("Compiling missing statistics for %s-%s", start, end) modified_statistic_ids = _compile_statistics( instance, session, start, end >= last_period ) if periods_without_commit == commit_interval or modified_statistic_ids: session.commit() session.expunge_all() periods_without_commit = 0 start = end return True @retryable_database_job("compile statistics") def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -> bool: """Compile 5-minute statistics for all integrations with a recorder platform. The actual calculation is delegated to the platforms. """ # Return if we already have 5-minute statistics for the requested period with session_scope( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: modified_statistic_ids = _compile_statistics( instance, session, start, fire_events ) if modified_statistic_ids: # In the rare case that we have modified statistic_ids, we reload the modified # statistics meta data into the cache in a fresh session to ensure that the # cache is up to date and future calls to get statistics meta data will # not have to hit the database again. with session_scope(session=instance.get_session(), read_only=True) as session: instance.statistics_meta_manager.get_many(session, modified_statistic_ids) return True def _get_first_id_stmt(start: datetime) -> StatementLambdaElement: """Return a statement that returns the first run_id at start.""" return lambda_stmt(lambda: select(StatisticsRuns.run_id).filter_by(start=start)) def _compile_statistics( instance: Recorder, session: Session, start: datetime, fire_events: bool ) -> set[str]: """Compile 5-minute statistics for all integrations with a recorder platform. This is a helper function for compile_statistics and compile_missing_statistics that does not retry on database errors since both callers already retry. returns a set of modified statistic_ids if any were modified. """ assert start.tzinfo == dt_util.UTC, "start must be in UTC" end = start + timedelta(minutes=5) statistics_meta_manager = instance.statistics_meta_manager modified_statistic_ids: set[str] = set() # Return if we already have 5-minute statistics for the requested period if execute_stmt_lambda_element(session, _get_first_id_stmt(start)): _LOGGER.debug("Statistics already compiled for %s-%s", start, end) return modified_statistic_ids _LOGGER.debug("Compiling statistics for %s-%s", start, end) platform_stats: list[StatisticResult] = [] current_metadata: dict[str, tuple[int, StatisticMetaData]] = {} # Collect statistics from all platforms implementing support for domain, platform in instance.hass.data[DOMAIN].recorder_platforms.items(): if not ( platform_compile_statistics := getattr( platform, INTEGRATION_PLATFORM_COMPILE_STATISTICS, None ) ): continue compiled: PlatformCompiledStatistics = platform_compile_statistics( instance.hass, start, end ) _LOGGER.debug( "Statistics for %s during %s-%s: %s", domain, start, end, compiled.platform_stats, ) platform_stats.extend(compiled.platform_stats) current_metadata.update(compiled.current_metadata) # Insert collected statistics in the database for stats in platform_stats: modified_statistic_id, metadata_id = statistics_meta_manager.update_or_add( session, stats["meta"], current_metadata ) if modified_statistic_id is not None: modified_statistic_ids.add(modified_statistic_id) _insert_statistics( session, StatisticsShortTerm, metadata_id, stats["stat"], ) if start.minute == 55: # A full hour is ready, summarize it _compile_hourly_statistics(session, start) session.add(StatisticsRuns(start=start)) if fire_events: instance.hass.bus.fire(EVENT_RECORDER_5MIN_STATISTICS_GENERATED) if start.minute == 55: instance.hass.bus.fire(EVENT_RECORDER_HOURLY_STATISTICS_GENERATED) return modified_statistic_ids def _adjust_sum_statistics( session: Session, table: type[StatisticsBase], metadata_id: int, start_time: datetime, adj: float, ) -> None: """Adjust statistics in the database.""" start_time_ts = start_time.timestamp() try: session.query(table).filter_by(metadata_id=metadata_id).filter( table.start_ts >= start_time_ts ).update( { table.sum: table.sum + adj, }, synchronize_session=False, ) except SQLAlchemyError: _LOGGER.exception( "Unexpected exception when updating statistics %s", id, ) def _insert_statistics( session: Session, table: type[StatisticsBase], metadata_id: int, statistic: StatisticData, ) -> None: """Insert statistics in the database.""" try: session.add(table.from_stats(metadata_id, statistic)) except SQLAlchemyError: _LOGGER.exception( "Unexpected exception when inserting statistics %s:%s ", metadata_id, statistic, ) def _update_statistics( session: Session, table: type[StatisticsBase], stat_id: int, statistic: StatisticData, ) -> None: """Insert statistics in the database.""" try: session.query(table).filter_by(id=stat_id).update( { table.mean: statistic.get("mean"), table.min: statistic.get("min"), table.max: statistic.get("max"), table.last_reset_ts: datetime_to_timestamp_or_none( statistic.get("last_reset") ), table.state: statistic.get("state"), table.sum: statistic.get("sum"), }, synchronize_session=False, ) except SQLAlchemyError: _LOGGER.exception( "Unexpected exception when updating statistics %s:%s ", stat_id, statistic, ) def get_metadata_with_session( instance: Recorder, session: Session, *, statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: """Fetch meta data. Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id. If statistic_ids is given, fetch metadata only for the listed statistics_ids. If statistic_type is given, fetch metadata only for statistic_ids supporting it. """ return instance.statistics_meta_manager.get_many( session, statistic_ids=statistic_ids, statistic_type=statistic_type, statistic_source=statistic_source, ) def get_metadata( hass: HomeAssistant, *, statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: """Return metadata for statistic_ids.""" with session_scope(hass=hass, read_only=True) as session: return get_metadata_with_session( get_instance(hass), session, statistic_ids=statistic_ids, statistic_type=statistic_type, statistic_source=statistic_source, ) def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: """Clear statistics for a list of statistic_ids.""" with session_scope(session=instance.get_session()) as session: instance.statistics_meta_manager.delete(session, statistic_ids) def update_statistics_metadata( instance: Recorder, statistic_id: str, new_statistic_id: str | None | UndefinedType, new_unit_of_measurement: str | None | UndefinedType, ) -> None: """Update statistics metadata for a statistic_id.""" statistics_meta_manager = instance.statistics_meta_manager if new_unit_of_measurement is not UNDEFINED: with session_scope(session=instance.get_session()) as session: statistics_meta_manager.update_unit_of_measurement( session, statistic_id, new_unit_of_measurement ) if new_statistic_id is not UNDEFINED and new_statistic_id is not None: with session_scope( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: statistics_meta_manager.update_statistic_id( session, DOMAIN, statistic_id, new_statistic_id ) async def async_list_statistic_ids( hass: HomeAssistant, statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, ) -> list[dict]: """Return all statistic_ids (or filtered one) and unit of measurement. Queries the database for existing statistic_ids, as well as integrations with a recorder platform for statistic_ids which will be added in the next statistics period. """ instance = get_instance(hass) if statistic_ids is not None: # Try to get the results from the cache since there is nearly # always a cache hit. statistics_meta_manager = instance.statistics_meta_manager metadata = statistics_meta_manager.get_from_cache_threadsafe(statistic_ids) if not statistic_ids.difference(metadata): result = _statistic_by_id_from_metadata(hass, metadata) return _flatten_list_statistic_ids_metadata_result(result) return await instance.async_add_executor_job( list_statistic_ids, hass, statistic_ids, statistic_type, ) def _statistic_by_id_from_metadata( hass: HomeAssistant, metadata: dict[str, tuple[int, StatisticMetaData]], ) -> dict[str, dict[str, Any]]: """Return a list of results for a given metadata dict.""" return { meta["statistic_id"]: { "display_unit_of_measurement": get_display_unit( hass, meta["statistic_id"], meta["unit_of_measurement"] ), "has_mean": meta["has_mean"], "has_sum": meta["has_sum"], "name": meta["name"], "source": meta["source"], "unit_class": _get_unit_class(meta["unit_of_measurement"]), "unit_of_measurement": meta["unit_of_measurement"], } for _, meta in metadata.values() } def _flatten_list_statistic_ids_metadata_result( result: dict[str, dict[str, Any]] ) -> list[dict]: """Return a flat dict of metadata.""" return [ { "statistic_id": _id, "display_unit_of_measurement": info["display_unit_of_measurement"], "has_mean": info["has_mean"], "has_sum": info["has_sum"], "name": info.get("name"), "source": info["source"], "statistics_unit_of_measurement": info["unit_of_measurement"], "unit_class": info["unit_class"], } for _id, info in result.items() ] def list_statistic_ids( hass: HomeAssistant, statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, ) -> list[dict]: """Return all statistic_ids (or filtered one) and unit of measurement. Queries the database for existing statistic_ids, as well as integrations with a recorder platform for statistic_ids which will be added in the next statistics period. """ result = {} instance = get_instance(hass) statistics_meta_manager = instance.statistics_meta_manager # Query the database with session_scope(hass=hass, read_only=True) as session: metadata = statistics_meta_manager.get_many( session, statistic_type=statistic_type, statistic_ids=statistic_ids ) result = _statistic_by_id_from_metadata(hass, metadata) if not statistic_ids or statistic_ids.difference(result): # If we want all statistic_ids, or some are missing, we need to query # the integrations for the missing ones. # # Query all integrations with a registered recorder platform for platform in hass.data[DOMAIN].recorder_platforms.values(): if not ( platform_list_statistic_ids := getattr( platform, INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, None ) ): continue platform_statistic_ids = platform_list_statistic_ids( hass, statistic_ids=statistic_ids, statistic_type=statistic_type ) for key, meta in platform_statistic_ids.items(): if key in result: # The database has a higher priority than the integration continue result[key] = { "display_unit_of_measurement": meta["unit_of_measurement"], "has_mean": meta["has_mean"], "has_sum": meta["has_sum"], "name": meta["name"], "source": meta["source"], "unit_class": _get_unit_class(meta["unit_of_measurement"]), "unit_of_measurement": meta["unit_of_measurement"], } # Return a list of statistic_id + metadata return _flatten_list_statistic_ids_metadata_result(result) def _reduce_statistics( stats: dict[str, list[StatisticsRow]], same_period: Callable[[float, float], bool], period_start_end: Callable[[float], tuple[float, float]], period: timedelta, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Reduce hourly statistics to daily or monthly statistics.""" result: dict[str, list[StatisticsRow]] = defaultdict(list) period_seconds = period.total_seconds() _want_mean = "mean" in types _want_min = "min" in types _want_max = "max" in types _want_last_reset = "last_reset" in types _want_state = "state" in types _want_sum = "sum" in types for statistic_id, stat_list in stats.items(): max_values: list[float] = [] mean_values: list[float] = [] min_values: list[float] = [] prev_stat: StatisticsRow = stat_list[0] fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds} # Loop over the hourly statistics + a fake entry to end the period for statistic in chain(stat_list, (fake_entry,)): if not same_period(prev_stat["start"], statistic["start"]): start, end = period_start_end(prev_stat["start"]) # The previous statistic was the last entry of the period row: StatisticsRow = { "start": start, "end": end, } if _want_mean: row["mean"] = mean(mean_values) if mean_values else None mean_values.clear() if _want_min: row["min"] = min(min_values) if min_values else None min_values.clear() if _want_max: row["max"] = max(max_values) if max_values else None max_values.clear() if _want_last_reset: row["last_reset"] = prev_stat.get("last_reset") if _want_state: row["state"] = prev_stat.get("state") if _want_sum: row["sum"] = prev_stat["sum"] result[statistic_id].append(row) if _want_max and (_max := statistic.get("max")) is not None: max_values.append(_max) if _want_mean and (_mean := statistic.get("mean")) is not None: mean_values.append(_mean) if _want_min and (_min := statistic.get("min")) is not None: min_values.append(_min) prev_stat = statistic return result def reduce_day_ts_factory() -> ( tuple[ Callable[[float, float], bool], Callable[[float], tuple[float, float]], ] ): """Return functions to match same day and day start end.""" _boundries: tuple[float, float] = (0, 0) # We have to recreate _local_from_timestamp in the closure in case the timezone changes _local_from_timestamp = partial( datetime.fromtimestamp, tz=dt_util.DEFAULT_TIME_ZONE ) def _same_day_ts(time1: float, time2: float) -> bool: """Return True if time1 and time2 are in the same date.""" nonlocal _boundries if not _boundries[0] <= time1 < _boundries[1]: _boundries = _day_start_end_ts_cached(time1) return _boundries[0] <= time2 < _boundries[1] def _day_start_end_ts(time: float) -> tuple[float, float]: """Return the start and end of the period (day) time is within.""" start_local = _local_from_timestamp(time).replace( hour=0, minute=0, second=0, microsecond=0 ) return ( start_local.astimezone(dt_util.UTC).timestamp(), (start_local + timedelta(days=1)).astimezone(dt_util.UTC).timestamp(), ) # We create _day_start_end_ts_cached in the closure in case the timezone changes _day_start_end_ts_cached = lru_cache(maxsize=6)(_day_start_end_ts) return _same_day_ts, _day_start_end_ts_cached def _reduce_statistics_per_day( stats: dict[str, list[StatisticsRow]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Reduce hourly statistics to daily statistics.""" _same_day_ts, _day_start_end_ts = reduce_day_ts_factory() return _reduce_statistics( stats, _same_day_ts, _day_start_end_ts, timedelta(days=1), types ) def reduce_week_ts_factory() -> ( tuple[ Callable[[float, float], bool], Callable[[float], tuple[float, float]], ] ): """Return functions to match same week and week start end.""" _boundries: tuple[float, float] = (0, 0) # We have to recreate _local_from_timestamp in the closure in case the timezone changes _local_from_timestamp = partial( datetime.fromtimestamp, tz=dt_util.DEFAULT_TIME_ZONE ) def _same_week_ts(time1: float, time2: float) -> bool: """Return True if time1 and time2 are in the same year and week.""" nonlocal _boundries if not _boundries[0] <= time1 < _boundries[1]: _boundries = _week_start_end_ts_cached(time1) return _boundries[0] <= time2 < _boundries[1] def _week_start_end_ts(time: float) -> tuple[float, float]: """Return the start and end of the period (week) time is within.""" nonlocal _boundries time_local = _local_from_timestamp(time) start_local = time_local.replace( hour=0, minute=0, second=0, microsecond=0 ) - timedelta(days=time_local.weekday()) return ( start_local.astimezone(dt_util.UTC).timestamp(), (start_local + timedelta(days=7)).astimezone(dt_util.UTC).timestamp(), ) # We create _week_start_end_ts_cached in the closure in case the timezone changes _week_start_end_ts_cached = lru_cache(maxsize=6)(_week_start_end_ts) return _same_week_ts, _week_start_end_ts_cached def _reduce_statistics_per_week( stats: dict[str, list[StatisticsRow]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Reduce hourly statistics to weekly statistics.""" _same_week_ts, _week_start_end_ts = reduce_week_ts_factory() return _reduce_statistics( stats, _same_week_ts, _week_start_end_ts, timedelta(days=7), types ) def _find_month_end_time(timestamp: datetime) -> datetime: """Return the end of the month (midnight at the first day of the next month).""" # We add 4 days to the end to make sure we are in the next month return (timestamp.replace(day=28) + timedelta(days=4)).replace( day=1, hour=0, minute=0, second=0, microsecond=0 ) def reduce_month_ts_factory() -> ( tuple[ Callable[[float, float], bool], Callable[[float], tuple[float, float]], ] ): """Return functions to match same month and month start end.""" _boundries: tuple[float, float] = (0, 0) # We have to recreate _local_from_timestamp in the closure in case the timezone changes _local_from_timestamp = partial( datetime.fromtimestamp, tz=dt_util.DEFAULT_TIME_ZONE ) def _same_month_ts(time1: float, time2: float) -> bool: """Return True if time1 and time2 are in the same year and month.""" nonlocal _boundries if not _boundries[0] <= time1 < _boundries[1]: _boundries = _month_start_end_ts_cached(time1) return _boundries[0] <= time2 < _boundries[1] def _month_start_end_ts(time: float) -> tuple[float, float]: """Return the start and end of the period (month) time is within.""" start_local = _local_from_timestamp(time).replace( day=1, hour=0, minute=0, second=0, microsecond=0 ) end_local = _find_month_end_time(start_local) return ( start_local.astimezone(dt_util.UTC).timestamp(), end_local.astimezone(dt_util.UTC).timestamp(), ) # We create _month_start_end_ts_cached in the closure in case the timezone changes _month_start_end_ts_cached = lru_cache(maxsize=6)(_month_start_end_ts) return _same_month_ts, _month_start_end_ts_cached def _reduce_statistics_per_month( stats: dict[str, list[StatisticsRow]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Reduce hourly statistics to monthly statistics.""" _same_month_ts, _month_start_end_ts = reduce_month_ts_factory() return _reduce_statistics( stats, _same_month_ts, _month_start_end_ts, timedelta(days=31), types ) def _generate_statistics_during_period_stmt( start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, table: type[StatisticsBase], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> StatementLambdaElement: """Prepare a database query for statistics during a given period. This prepares a lambda_stmt query, so we don't insert the parameters yet. """ start_time_ts = start_time.timestamp() stmt = _generate_select_columns_for_types_stmt(table, types) stmt += lambda q: q.filter(table.start_ts >= start_time_ts) if end_time is not None: end_time_ts = end_time.timestamp() stmt += lambda q: q.filter(table.start_ts < end_time_ts) if metadata_ids: stmt += lambda q: q.filter( # https://github.com/python/mypy/issues/2608 table.metadata_id.in_(metadata_ids) # type:ignore[arg-type] ) stmt += lambda q: q.order_by(table.metadata_id, table.start_ts) return stmt def _generate_max_mean_min_statistic_in_sub_period_stmt( columns: Select, start_time: datetime | None, end_time: datetime | None, table: type[StatisticsBase], metadata_id: int, ) -> StatementLambdaElement: stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) if start_time is not None: start_time_ts = start_time.timestamp() stmt += lambda q: q.filter(table.start_ts >= start_time_ts) if end_time is not None: end_time_ts = end_time.timestamp() stmt += lambda q: q.filter(table.start_ts < end_time_ts) return stmt def _get_max_mean_min_statistic_in_sub_period( session: Session, result: dict[str, float], start_time: datetime | None, end_time: datetime | None, table: type[StatisticsBase], types: set[Literal["max", "mean", "min", "change"]], metadata_id: int, ) -> None: """Return max, mean and min during the period.""" # Calculate max, mean, min columns = select() if "max" in types: # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable columns = columns.add_columns(func.max(table.max)) if "mean" in types: columns = columns.add_columns(func.avg(table.mean)) # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable columns = columns.add_columns(func.count(table.mean)) if "min" in types: # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable columns = columns.add_columns(func.min(table.min)) stmt = _generate_max_mean_min_statistic_in_sub_period_stmt( columns, start_time, end_time, table, metadata_id ) stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt)) if not stats: return if "max" in types and (new_max := stats[0].max) is not None: old_max = result.get("max") result["max"] = max(new_max, old_max) if old_max is not None else new_max if "mean" in types and stats[0].avg is not None: # https://github.com/sqlalchemy/sqlalchemy/issues/9127 duration = stats[0].count * table.duration.total_seconds() # type: ignore[operator] result["duration"] = result.get("duration", 0.0) + duration result["mean_acc"] = result.get("mean_acc", 0.0) + stats[0].avg * duration if "min" in types and (new_min := stats[0].min) is not None: old_min = result.get("min") result["min"] = min(new_min, old_min) if old_min is not None else new_min def _get_max_mean_min_statistic( session: Session, head_start_time: datetime | None, head_end_time: datetime | None, main_start_time: datetime | None, main_end_time: datetime | None, tail_start_time: datetime | None, tail_end_time: datetime | None, tail_only: bool, metadata_id: int, types: set[Literal["max", "mean", "min", "change"]], ) -> dict[str, float | None]: """Return max, mean and min during the period. The mean is a time weighted average, combining hourly and 5-minute statistics if necessary. """ max_mean_min: dict[str, float] = {} result: dict[str, float | None] = {} if tail_start_time is not None: # Calculate max, mean, min _get_max_mean_min_statistic_in_sub_period( session, max_mean_min, tail_start_time, tail_end_time, StatisticsShortTerm, types, metadata_id, ) if not tail_only: _get_max_mean_min_statistic_in_sub_period( session, max_mean_min, main_start_time, main_end_time, Statistics, types, metadata_id, ) if head_start_time is not None: _get_max_mean_min_statistic_in_sub_period( session, max_mean_min, head_start_time, head_end_time, StatisticsShortTerm, types, metadata_id, ) if "max" in types: result["max"] = max_mean_min.get("max") if "mean" in types: if "mean_acc" not in max_mean_min: result["mean"] = None else: result["mean"] = max_mean_min["mean_acc"] / max_mean_min["duration"] if "min" in types: result["min"] = max_mean_min.get("min") return result def _first_statistic( session: Session, table: type[StatisticsBase], metadata_id: int, ) -> datetime | None: """Return the data of the oldest statistic row for a given metadata id.""" stmt = lambda_stmt( lambda: select(table.start_ts) .filter(table.metadata_id == metadata_id) .order_by(table.start_ts.asc()) .limit(1) ) if stats := cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)): return dt_util.utc_from_timestamp(stats[0].start_ts) return None def _get_oldest_sum_statistic( session: Session, head_start_time: datetime | None, main_start_time: datetime | None, tail_start_time: datetime | None, oldest_stat: datetime | None, tail_only: bool, metadata_id: int, ) -> float | None: """Return the oldest non-NULL sum during the period.""" def _get_oldest_sum_statistic_in_sub_period( session: Session, start_time: datetime | None, table: type[StatisticsBase], metadata_id: int, ) -> float | None: """Return the oldest non-NULL sum during the period.""" stmt = lambda_stmt( lambda: select(table.sum) .filter(table.metadata_id == metadata_id) .filter(table.sum.is_not(None)) .order_by(table.start_ts.asc()) .limit(1) ) if start_time is not None: start_time = start_time + table.duration - timedelta.resolution if table == StatisticsShortTerm: minutes = start_time.minute - start_time.minute % 5 period = start_time.replace(minute=minutes, second=0, microsecond=0) else: period = start_time.replace(minute=0, second=0, microsecond=0) prev_period = period - table.duration prev_period_ts = prev_period.timestamp() stmt += lambda q: q.filter(table.start_ts >= prev_period_ts) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) return stats[0].sum if stats else None oldest_sum: float | None = None # This function won't be called if tail_only is False and main_start_time is None # the extra checks are added to satisfy MyPy if not tail_only and main_start_time is not None and oldest_stat is not None: period = main_start_time.replace(minute=0, second=0, microsecond=0) prev_period = period - Statistics.duration if prev_period < oldest_stat: return 0 if ( head_start_time is not None and ( oldest_sum := _get_oldest_sum_statistic_in_sub_period( session, head_start_time, StatisticsShortTerm, metadata_id ) ) is not None ): return oldest_sum if not tail_only: if ( oldest_sum := _get_oldest_sum_statistic_in_sub_period( session, main_start_time, Statistics, metadata_id ) ) is not None: return oldest_sum return 0 if ( tail_start_time is not None and ( oldest_sum := _get_oldest_sum_statistic_in_sub_period( session, tail_start_time, StatisticsShortTerm, metadata_id ) ) ) is not None: return oldest_sum return 0 def _get_newest_sum_statistic( session: Session, head_start_time: datetime | None, head_end_time: datetime | None, main_start_time: datetime | None, main_end_time: datetime | None, tail_start_time: datetime | None, tail_end_time: datetime | None, tail_only: bool, metadata_id: int, ) -> float | None: """Return the newest non-NULL sum during the period.""" def _get_newest_sum_statistic_in_sub_period( session: Session, start_time: datetime | None, end_time: datetime | None, table: type[StatisticsBase], metadata_id: int, ) -> float | None: """Return the newest non-NULL sum during the period.""" stmt = lambda_stmt( lambda: select( table.sum, ) .filter(table.metadata_id == metadata_id) .filter(table.sum.is_not(None)) .order_by(table.start_ts.desc()) .limit(1) ) if start_time is not None: start_time_ts = start_time.timestamp() stmt += lambda q: q.filter(table.start_ts >= start_time_ts) if end_time is not None: end_time_ts = end_time.timestamp() stmt += lambda q: q.filter(table.start_ts < end_time_ts) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) return stats[0].sum if stats else None newest_sum: float | None = None if tail_start_time is not None: newest_sum = _get_newest_sum_statistic_in_sub_period( session, tail_start_time, tail_end_time, StatisticsShortTerm, metadata_id ) if newest_sum is not None: return newest_sum if not tail_only: newest_sum = _get_newest_sum_statistic_in_sub_period( session, main_start_time, main_end_time, Statistics, metadata_id ) if newest_sum is not None: return newest_sum if head_start_time is not None: newest_sum = _get_newest_sum_statistic_in_sub_period( session, head_start_time, head_end_time, StatisticsShortTerm, metadata_id ) return newest_sum def statistic_during_period( hass: HomeAssistant, start_time: datetime | None, end_time: datetime | None, statistic_id: str, types: set[Literal["max", "mean", "min", "change"]] | None, units: dict[str, str] | None, ) -> dict[str, Any]: """Return a statistic data point for the UTC period start_time - end_time.""" metadata = None if not types: types = {"max", "mean", "min", "change"} result: dict[str, Any] = {} with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_id if not ( metadata := get_instance(hass).statistics_meta_manager.get( session, statistic_id ) ): return result metadata_id = metadata[0] oldest_stat = _first_statistic(session, Statistics, metadata_id) oldest_5_min_stat = None if not valid_statistic_id(statistic_id): oldest_5_min_stat = _first_statistic( session, StatisticsShortTerm, metadata_id ) # To calculate the summary, data from the statistics (hourly) and # short_term_statistics (5 minute) tables is combined # - The short term statistics table is used for the head and tail of the period, # if the period it doesn't start or end on a full hour # - The statistics table is used for the remainder of the time now = dt_util.utcnow() if end_time is not None and end_time > now: end_time = now tail_only = ( start_time is not None and end_time is not None and end_time - start_time < timedelta(hours=1) ) # Calculate the head period head_start_time: datetime | None = None head_end_time: datetime | None = None if ( not tail_only and oldest_stat is not None and oldest_5_min_stat is not None and oldest_5_min_stat - oldest_stat < timedelta(hours=1) and (start_time is None or start_time < oldest_5_min_stat) ): # To improve accuracy of averaged for statistics which were added within # recorder's retention period. head_start_time = oldest_5_min_stat head_end_time = oldest_5_min_stat.replace( minute=0, second=0, microsecond=0 ) + timedelta(hours=1) elif not tail_only and start_time is not None and start_time.minute: head_start_time = start_time head_end_time = start_time.replace( minute=0, second=0, microsecond=0 ) + timedelta(hours=1) # Calculate the tail period tail_start_time: datetime | None = None tail_end_time: datetime | None = None if end_time is None: tail_start_time = now.replace(minute=0, second=0, microsecond=0) elif end_time.minute: tail_start_time = ( start_time if tail_only else end_time.replace(minute=0, second=0, microsecond=0) ) tail_end_time = end_time # Calculate the main period main_start_time: datetime | None = None main_end_time: datetime | None = None if not tail_only: main_start_time = start_time if head_end_time is None else head_end_time main_end_time = end_time if tail_start_time is None else tail_start_time if not types.isdisjoint({"max", "mean", "min"}): result = _get_max_mean_min_statistic( session, head_start_time, head_end_time, main_start_time, main_end_time, tail_start_time, tail_end_time, tail_only, metadata_id, types, ) if "change" in types: oldest_sum: float | None if start_time is None: oldest_sum = 0.0 else: oldest_sum = _get_oldest_sum_statistic( session, head_start_time, main_start_time, tail_start_time, oldest_stat, tail_only, metadata_id, ) newest_sum = _get_newest_sum_statistic( session, head_start_time, head_end_time, main_start_time, main_end_time, tail_start_time, tail_end_time, tail_only, metadata_id, ) # Calculate the difference between the oldest and newest sum if oldest_sum is not None and newest_sum is not None: result["change"] = newest_sum - oldest_sum else: result["change"] = None state_unit = unit = metadata[1]["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) return {key: convert(value) if convert else value for key, value in result.items()} _type_column_mapping = { "last_reset": "last_reset_ts", "max": "max", "mean": "mean", "min": "min", "state": "state", "sum": "sum", } def _generate_select_columns_for_types_stmt( table: type[StatisticsBase], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> StatementLambdaElement: columns = select(table.metadata_id, table.start_ts) track_on: list[str | None] = [ table.__tablename__, # type: ignore[attr-defined] ] for key, column in _type_column_mapping.items(): if key in types: columns = columns.add_columns(getattr(table, column)) track_on.append(column) else: track_on.append(None) return lambda_stmt(lambda: columns, track_on=track_on) def _extract_metadata_and_discard_impossible_columns( metadata: dict[str, tuple[int, StatisticMetaData]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> list[int]: """Extract metadata ids from metadata and discard impossible columns.""" metadata_ids = [] has_mean = False has_sum = False for metadata_id, stats_metadata in metadata.values(): metadata_ids.append(metadata_id) has_mean |= stats_metadata["has_mean"] has_sum |= stats_metadata["has_sum"] if not has_mean: types.discard("mean") types.discard("min") types.discard("max") if not has_sum: types.discard("sum") types.discard("state") return metadata_ids def _augment_result_with_change( hass: HomeAssistant, session: Session, start_time: datetime, units: dict[str, str] | None, _types: set[Literal["change", "last_reset", "max", "mean", "min", "state", "sum"]], table: type[Statistics | StatisticsShortTerm], metadata: dict[str, tuple[int, StatisticMetaData]], result: dict[str, list[StatisticsRow]], ) -> None: """Add change to the result.""" drop_sum = "sum" not in _types prev_sums = {} if tmp := _statistics_at_time( session, {metadata[statistic_id][0] for statistic_id in result}, table, start_time, {"sum"}, ): _metadata = dict(metadata.values()) for row in tmp: metadata_by_id = _metadata[row.metadata_id] statistic_id = metadata_by_id["statistic_id"] state_unit = unit = metadata_by_id["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) if convert is not None: prev_sums[statistic_id] = convert(row.sum) else: prev_sums[statistic_id] = row.sum for statistic_id, rows in result.items(): prev_sum = prev_sums.get(statistic_id) or 0 for statistics_row in rows: if "sum" not in statistics_row: continue if drop_sum: _sum = statistics_row.pop("sum") else: _sum = statistics_row["sum"] if _sum is None: statistics_row["change"] = None continue statistics_row["change"] = _sum - prev_sum prev_sum = _sum def _statistics_during_period_with_session( hass: HomeAssistant, session: Session, start_time: datetime, end_time: datetime | None, statistic_ids: set[str] | None, period: Literal["5minute", "day", "hour", "week", "month"], units: dict[str, str] | None, _types: set[Literal["change", "last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return statistic data points during UTC period start_time - end_time. If end_time is omitted, returns statistics newer than or equal to start_time. If statistic_ids is omitted, returns statistics for all statistics ids. """ if statistic_ids is not None and not isinstance(statistic_ids, set): # This is for backwards compatibility to avoid a breaking change # for custom integrations that call this method. statistic_ids = set(statistic_ids) # type: ignore[unreachable] # Fetch metadata for the given (or all) statistic_ids metadata = get_instance(hass).statistics_meta_manager.get_many( session, statistic_ids=statistic_ids ) if not metadata: return {} types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]] = set() for stat_type in _types: if stat_type == "change": types.add("sum") continue types.add(stat_type) metadata_ids = None if statistic_ids is not None: metadata_ids = _extract_metadata_and_discard_impossible_columns(metadata, types) # Align start_time and end_time with the period if period == "day": start_time = dt_util.as_local(start_time).replace( hour=0, minute=0, second=0, microsecond=0 ) start_time = start_time.replace() if end_time is not None: end_local = dt_util.as_local(end_time) end_time = end_local.replace( hour=0, minute=0, second=0, microsecond=0 ) + timedelta(days=1) elif period == "week": start_local = dt_util.as_local(start_time) start_time = start_local.replace( hour=0, minute=0, second=0, microsecond=0 ) - timedelta(days=start_local.weekday()) if end_time is not None: end_local = dt_util.as_local(end_time) end_time = ( end_local.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=end_local.weekday()) + timedelta(days=7) ) elif period == "month": start_time = dt_util.as_local(start_time).replace( day=1, hour=0, minute=0, second=0, microsecond=0 ) if end_time is not None: end_time = _find_month_end_time(dt_util.as_local(end_time)) table: type[Statistics | StatisticsShortTerm] = ( Statistics if period != "5minute" else StatisticsShortTerm ) stmt = _generate_statistics_during_period_stmt( start_time, end_time, metadata_ids, table, types ) stats = cast( Sequence[Row], execute_stmt_lambda_element(session, stmt, orm_rows=False) ) if not stats: return {} result = _sorted_statistics_to_dict( hass, session, stats, statistic_ids, metadata, True, table, start_time, units, types, ) if period == "day": result = _reduce_statistics_per_day(result, types) if period == "week": result = _reduce_statistics_per_week(result, types) if period == "month": result = _reduce_statistics_per_month(result, types) if "change" in _types: _augment_result_with_change( hass, session, start_time, units, _types, table, metadata, result ) # Return statistics combined with metadata return result def statistics_during_period( hass: HomeAssistant, start_time: datetime, end_time: datetime | None, statistic_ids: set[str] | None, period: Literal["5minute", "day", "hour", "week", "month"], units: dict[str, str] | None, types: set[Literal["change", "last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return statistic data points during UTC period start_time - end_time. If end_time is omitted, returns statistics newer than or equal to start_time. If statistic_ids is omitted, returns statistics for all statistics ids. """ with session_scope(hass=hass, read_only=True) as session: return _statistics_during_period_with_session( hass, session, start_time, end_time, statistic_ids, period, units, types, ) def _get_last_statistics_stmt( metadata_id: int, number_of_stats: int, ) -> StatementLambdaElement: """Generate a statement for number_of_stats statistics for a given statistic_id.""" return lambda_stmt( lambda: select(*QUERY_STATISTICS) .filter_by(metadata_id=metadata_id) .order_by(Statistics.metadata_id, Statistics.start_ts.desc()) .limit(number_of_stats) ) def _get_last_statistics_short_term_stmt( metadata_id: int, number_of_stats: int, ) -> StatementLambdaElement: """Generate a statement for number_of_stats short term statistics. For a given statistic_id. """ return lambda_stmt( lambda: select(*QUERY_STATISTICS_SHORT_TERM) .filter_by(metadata_id=metadata_id) .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts.desc()) .limit(number_of_stats) ) def _get_last_statistics( hass: HomeAssistant, number_of_stats: int, statistic_id: str, convert_units: bool, table: type[StatisticsBase], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return the last number_of_stats statistics for a given statistic_id.""" statistic_ids = {statistic_id} with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_id metadata = get_instance(hass).statistics_meta_manager.get_many( session, statistic_ids=statistic_ids ) if not metadata: return {} metadata_ids = _extract_metadata_and_discard_impossible_columns(metadata, types) metadata_id = metadata_ids[0] if table == Statistics: stmt = _get_last_statistics_stmt(metadata_id, number_of_stats) else: stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats) stats = cast( Sequence[Row], execute_stmt_lambda_element(session, stmt, orm_rows=False) ) if not stats: return {} # Return statistics combined with metadata return _sorted_statistics_to_dict( hass, session, stats, statistic_ids, metadata, convert_units, table, None, None, types, ) def get_last_statistics( hass: HomeAssistant, number_of_stats: int, statistic_id: str, convert_units: bool, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return the last number_of_stats statistics for a statistic_id.""" return _get_last_statistics( hass, number_of_stats, statistic_id, convert_units, Statistics, types ) def get_last_short_term_statistics( hass: HomeAssistant, number_of_stats: int, statistic_id: str, convert_units: bool, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return the last number_of_stats short term statistics for a statistic_id.""" return _get_last_statistics( hass, number_of_stats, statistic_id, convert_units, StatisticsShortTerm, types ) def _latest_short_term_statistics_stmt( metadata_ids: list[int], ) -> StatementLambdaElement: """Create the statement for finding the latest short term stat rows.""" return lambda_stmt( lambda: select(*QUERY_STATISTICS_SHORT_TERM).join( ( most_recent_statistic_row := ( select( StatisticsShortTerm.metadata_id, # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable func.max(StatisticsShortTerm.start_ts).label("start_max"), ) .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) .group_by(StatisticsShortTerm.metadata_id) ).subquery() ), (StatisticsShortTerm.metadata_id == most_recent_statistic_row.c.metadata_id) & (StatisticsShortTerm.start_ts == most_recent_statistic_row.c.start_max), ) ) def get_latest_short_term_statistics( hass: HomeAssistant, statistic_ids: set[str], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], metadata: dict[str, tuple[int, StatisticMetaData]] | None = None, ) -> dict[str, list[StatisticsRow]]: """Return the latest short term statistics for a list of statistic_ids.""" with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_ids if not metadata: metadata = get_instance(hass).statistics_meta_manager.get_many( session, statistic_ids=statistic_ids ) if not metadata: return {} metadata_ids = _extract_metadata_and_discard_impossible_columns(metadata, types) stmt = _latest_short_term_statistics_stmt(metadata_ids) stats = cast( Sequence[Row], execute_stmt_lambda_element(session, stmt, orm_rows=False) ) if not stats: return {} # Return statistics combined with metadata return _sorted_statistics_to_dict( hass, session, stats, statistic_ids, metadata, False, StatisticsShortTerm, None, None, types, ) def _generate_statistics_at_time_stmt( table: type[StatisticsBase], metadata_ids: set[int], start_time_ts: float, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> StatementLambdaElement: """Create the statement for finding the statistics for a given time.""" stmt = _generate_select_columns_for_types_stmt(table, types) stmt += lambda q: q.join( ( most_recent_statistic_ids := ( select( # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable func.max(table.start_ts).label("max_start_ts"), table.metadata_id.label("max_metadata_id"), ) .filter(table.start_ts < start_time_ts) .filter(table.metadata_id.in_(metadata_ids)) .group_by(table.metadata_id) .subquery() ) ), and_( table.start_ts == most_recent_statistic_ids.c.max_start_ts, table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, ), ) return stmt def _statistics_at_time( session: Session, metadata_ids: set[int], table: type[StatisticsBase], start_time: datetime, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> Sequence[Row] | None: """Return last known statistics, earlier than start_time, for the metadata_ids.""" start_time_ts = start_time.timestamp() stmt = _generate_statistics_at_time_stmt(table, metadata_ids, start_time_ts, types) return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) def _fast_build_sum_list( stats_list: list[Row], table_duration_seconds: float, convert: Callable | None, start_ts_idx: int, sum_idx: int, ) -> list[StatisticsRow]: """Build a list of sum statistics.""" if convert: return [ { "start": (start_ts := db_state[start_ts_idx]), "end": start_ts + table_duration_seconds, "sum": convert(db_state[sum_idx]), } for db_state in stats_list ] return [ { "start": (start_ts := db_state[start_ts_idx]), "end": start_ts + table_duration_seconds, "sum": db_state[sum_idx], } for db_state in stats_list ] def _sorted_statistics_to_dict( # noqa: C901 hass: HomeAssistant, session: Session, stats: Sequence[Row[Any]], statistic_ids: set[str] | None, _metadata: dict[str, tuple[int, StatisticMetaData]], convert_units: bool, table: type[StatisticsBase], start_time: datetime | None, units: dict[str, str] | None, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Convert SQL results into JSON friendly data structure.""" assert stats, "stats must not be empty" # Guard against implementation error result: dict[str, list[StatisticsRow]] = defaultdict(list) metadata = dict(_metadata.values()) # Identify metadata IDs for which no data was available at the requested start time field_map: dict[str, int] = {key: idx for idx, key in enumerate(stats[0]._fields)} metadata_id_idx = field_map["metadata_id"] start_ts_idx = field_map["start_ts"] stats_by_meta_id: dict[int, list[Row]] = {} seen_statistic_ids: set[str] = set() key_func = itemgetter(metadata_id_idx) for meta_id, group in groupby(stats, key_func): stats_list = stats_by_meta_id[meta_id] = list(group) seen_statistic_ids.add(metadata[meta_id]["statistic_id"]) # Set all statistic IDs to empty lists in result set to maintain the order if statistic_ids is not None: for stat_id in statistic_ids: # Only set the statistic ID if it is in the data to # avoid having to do a second loop to remove the # statistic IDs that are not in the data at the end if stat_id in seen_statistic_ids: result[stat_id] = [] # Figure out which fields we need to extract from the SQL result # and which indices they have in the result so we can avoid the overhead # of doing a dict lookup for each row mean_idx = field_map["mean"] if "mean" in types else None min_idx = field_map["min"] if "min" in types else None max_idx = field_map["max"] if "max" in types else None last_reset_ts_idx = field_map["last_reset_ts"] if "last_reset" in types else None state_idx = field_map["state"] if "state" in types else None sum_idx = field_map["sum"] if "sum" in types else None sum_only = len(types) == 1 and sum_idx is not None # Append all statistic entries, and optionally do unit conversion table_duration_seconds = table.duration.total_seconds() for meta_id, stats_list in stats_by_meta_id.items(): metadata_by_id = metadata[meta_id] statistic_id = metadata_by_id["statistic_id"] if convert_units: state_unit = unit = metadata_by_id["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) else: convert = None if sum_only: # This function is extremely flexible and can handle all types of # statistics, but in practice we only ever use a few combinations. # # For energy, we only need sum statistics, so we can optimize # this path to avoid the overhead of the more generic function. assert sum_idx is not None result[statistic_id] = _fast_build_sum_list( stats_list, table_duration_seconds, convert, start_ts_idx, sum_idx, ) continue ent_results_append = result[statistic_id].append # # The below loop is a red hot path for energy, and every # optimization counts in here. # # Specifically, we want to avoid function calls, # attribute lookups, and dict lookups as much as possible. # for db_state in stats_list: row: StatisticsRow = { "start": (start_ts := db_state[start_ts_idx]), "end": start_ts + table_duration_seconds, } if last_reset_ts_idx is not None: row["last_reset"] = db_state[last_reset_ts_idx] if convert: if mean_idx is not None: row["mean"] = convert(db_state[mean_idx]) if min_idx is not None: row["min"] = convert(db_state[min_idx]) if max_idx is not None: row["max"] = convert(db_state[max_idx]) if state_idx is not None: row["state"] = convert(db_state[state_idx]) if sum_idx is not None: row["sum"] = convert(db_state[sum_idx]) else: if mean_idx is not None: row["mean"] = db_state[mean_idx] if min_idx is not None: row["min"] = db_state[min_idx] if max_idx is not None: row["max"] = db_state[max_idx] if state_idx is not None: row["state"] = db_state[state_idx] if sum_idx is not None: row["sum"] = db_state[sum_idx] ent_results_append(row) return result def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]]: """Validate statistics.""" platform_validation: dict[str, list[ValidationIssue]] = {} for platform in hass.data[DOMAIN].recorder_platforms.values(): if platform_validate_statistics := getattr( platform, INTEGRATION_PLATFORM_VALIDATE_STATISTICS, None ): platform_validation.update(platform_validate_statistics(hass)) return platform_validation def _statistics_exists( session: Session, table: type[StatisticsBase], metadata_id: int, start: datetime, ) -> int | None: """Return id if a statistics entry already exists.""" start_ts = start.timestamp() result = ( session.query(table.id) .filter((table.metadata_id == metadata_id) & (table.start_ts == start_ts)) .first() ) return result.id if result else None @callback def _async_import_statistics( hass: HomeAssistant, metadata: StatisticMetaData, statistics: Iterable[StatisticData], ) -> None: """Validate timestamps and insert an import_statistics job in the queue.""" for statistic in statistics: start = statistic["start"] if start.tzinfo is None or start.tzinfo.utcoffset(start) is None: raise HomeAssistantError("Naive timestamp") if start.minute != 0 or start.second != 0 or start.microsecond != 0: raise HomeAssistantError("Invalid timestamp") statistic["start"] = dt_util.as_utc(start) if "last_reset" in statistic and statistic["last_reset"] is not None: last_reset = statistic["last_reset"] if ( last_reset.tzinfo is None or last_reset.tzinfo.utcoffset(last_reset) is None ): raise HomeAssistantError("Naive timestamp") statistic["last_reset"] = dt_util.as_utc(last_reset) # Insert job in recorder's queue get_instance(hass).async_import_statistics(metadata, statistics, Statistics) @callback def async_import_statistics( hass: HomeAssistant, metadata: StatisticMetaData, statistics: Iterable[StatisticData], ) -> None: """Import hourly statistics from an internal source. This inserts an import_statistics job in the recorder's queue. """ if not valid_entity_id(metadata["statistic_id"]): raise HomeAssistantError("Invalid statistic_id") # The source must not be empty and must be aligned with the statistic_id if not metadata["source"] or metadata["source"] != DOMAIN: raise HomeAssistantError("Invalid source") _async_import_statistics(hass, metadata, statistics) @callback def async_add_external_statistics( hass: HomeAssistant, metadata: StatisticMetaData, statistics: Iterable[StatisticData], ) -> None: """Add hourly statistics from an external source. This inserts an import_statistics job in the recorder's queue. """ # The statistic_id has same limitations as an entity_id, but with a ':' as separator if not valid_statistic_id(metadata["statistic_id"]): raise HomeAssistantError("Invalid statistic_id") # The source must not be empty and must be aligned with the statistic_id domain, _object_id = split_statistic_id(metadata["statistic_id"]) if not metadata["source"] or metadata["source"] != domain: raise HomeAssistantError("Invalid source") _async_import_statistics(hass, metadata, statistics) def _filter_unique_constraint_integrity_error( instance: Recorder, ) -> Callable[[Exception], bool]: def _filter_unique_constraint_integrity_error(err: Exception) -> bool: """Handle unique constraint integrity errors.""" if not isinstance(err, StatementError): return False assert instance.engine is not None dialect_name = instance.engine.dialect.name ignore = False if ( dialect_name == SupportedDialect.SQLITE and "UNIQUE constraint failed" in str(err) ): ignore = True if ( dialect_name == SupportedDialect.POSTGRESQL and err.orig and hasattr(err.orig, "pgcode") and err.orig.pgcode == "23505" ): ignore = True if ( dialect_name == SupportedDialect.MYSQL and err.orig and hasattr(err.orig, "args") ): with contextlib.suppress(TypeError): if err.orig.args[0] == 1062: ignore = True if ignore: _LOGGER.warning( ( "Blocked attempt to insert duplicated statistic rows, please report" " at %s" ), "https://github.com/home-assistant/core/issues?q=is%3Aopen+is%3Aissue+label%3A%22integration%3A+recorder%22", exc_info=err, ) return ignore return _filter_unique_constraint_integrity_error def _import_statistics_with_session( instance: Recorder, session: Session, metadata: StatisticMetaData, statistics: Iterable[StatisticData], table: type[StatisticsBase], ) -> bool: """Import statistics to the database.""" statistics_meta_manager = instance.statistics_meta_manager old_metadata_dict = statistics_meta_manager.get_many( session, statistic_ids={metadata["statistic_id"]} ) _, metadata_id = statistics_meta_manager.update_or_add( session, metadata, old_metadata_dict ) for stat in statistics: if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]): _update_statistics(session, table, stat_id, stat) else: _insert_statistics(session, table, metadata_id, stat) return True @retryable_database_job("statistics") def import_statistics( instance: Recorder, metadata: StatisticMetaData, statistics: Iterable[StatisticData], table: type[StatisticsBase], ) -> bool: """Process an import_statistics job.""" with session_scope( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: return _import_statistics_with_session( instance, session, metadata, statistics, table ) @retryable_database_job("adjust_statistics") def adjust_statistics( instance: Recorder, statistic_id: str, start_time: datetime, sum_adjustment: float, adjustment_unit: str, ) -> bool: """Process an add_statistics job.""" with session_scope(session=instance.get_session()) as session: metadata = instance.statistics_meta_manager.get_many( session, statistic_ids={statistic_id} ) if statistic_id not in metadata: return True statistic_unit = metadata[statistic_id][1]["unit_of_measurement"] convert = _get_display_to_statistic_unit_converter( adjustment_unit, statistic_unit ) sum_adjustment = convert(sum_adjustment) _adjust_sum_statistics( session, StatisticsShortTerm, metadata[statistic_id][0], start_time, sum_adjustment, ) _adjust_sum_statistics( session, Statistics, metadata[statistic_id][0], start_time.replace(minute=0), sum_adjustment, ) return True def _change_statistics_unit_for_table( session: Session, table: type[StatisticsBase], metadata_id: int, convert: Callable[[float | None], float | None], ) -> None: """Insert statistics in the database.""" columns = (table.id, table.mean, table.min, table.max, table.state, table.sum) query = session.query(*columns).filter_by(metadata_id=bindparam("metadata_id")) rows = execute(query.params(metadata_id=metadata_id)) for row in rows: session.query(table).filter(table.id == row.id).update( { table.mean: convert(row.mean), table.min: convert(row.min), table.max: convert(row.max), table.state: convert(row.state), table.sum: convert(row.sum), }, synchronize_session=False, ) def change_statistics_unit( instance: Recorder, statistic_id: str, new_unit: str, old_unit: str, ) -> None: """Change statistics unit for a statistic_id.""" statistics_meta_manager = instance.statistics_meta_manager with session_scope(session=instance.get_session()) as session: metadata = statistics_meta_manager.get(session, statistic_id) # Guard against the statistics being removed or updated before the # change_statistics_unit job executes if ( metadata is None or metadata[1]["source"] != DOMAIN or metadata[1]["unit_of_measurement"] != old_unit ): _LOGGER.warning("Could not change statistics unit for %s", statistic_id) return metadata_id = metadata[0] convert = _get_unit_converter(old_unit, new_unit) tables: tuple[type[StatisticsBase], ...] = ( Statistics, StatisticsShortTerm, ) for table in tables: _change_statistics_unit_for_table(session, table, metadata_id, convert) statistics_meta_manager.update_unit_of_measurement( session, statistic_id, new_unit ) @callback def async_change_statistics_unit( hass: HomeAssistant, statistic_id: str, *, new_unit_of_measurement: str, old_unit_of_measurement: str, ) -> None: """Change statistics unit for a statistic_id.""" if not can_convert_units(old_unit_of_measurement, new_unit_of_measurement): raise HomeAssistantError( f"Can't convert {old_unit_of_measurement} to {new_unit_of_measurement}" ) get_instance(hass).async_change_statistics_unit( statistic_id, new_unit_of_measurement=new_unit_of_measurement, old_unit_of_measurement=old_unit_of_measurement, ) def cleanup_statistics_timestamp_migration(instance: Recorder) -> bool: """Clean up the statistics migration from timestamp to datetime. Returns False if there are more rows to update. Returns True if all rows have been updated. """ engine = instance.engine assert engine is not None if engine.dialect.name == SupportedDialect.SQLITE: for table in STATISTICS_TABLES: with session_scope(session=instance.get_session()) as session: session.connection().execute( text( f"update {table} set start = NULL, created = NULL, last_reset = NULL;" ) ) elif engine.dialect.name == SupportedDialect.MYSQL: for table in STATISTICS_TABLES: with session_scope(session=instance.get_session()) as session: if ( session.connection() .execute( text( f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL where start is not NULL LIMIT 100000;" ) ) .rowcount ): # We have more rows to update so return False # to indicate we need to run again return False elif engine.dialect.name == SupportedDialect.POSTGRESQL: for table in STATISTICS_TABLES: with session_scope(session=instance.get_session()) as session: if ( session.connection() .execute( text( f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL " # nosec f"where id in (select id from {table} where start is not NULL LIMIT 100000)" ) ) .rowcount ): # We have more rows to update so return False # to indicate we need to run again return False from .migration import _drop_index # pylint: disable=import-outside-toplevel for table in STATISTICS_TABLES: _drop_index(instance.get_session, table, f"ix_{table}_start") # We have no more rows to update so return True # to indicate we are done return True