"""The Energy websocket API.""" from __future__ import annotations import asyncio from collections import defaultdict from collections.abc import Awaitable, Callable from datetime import datetime, timedelta import functools from itertools import chain from types import ModuleType from typing import Any, cast import voluptuous as vol from homeassistant.components import recorder, websocket_api from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.integration_platform import ( async_process_integration_platforms, ) from homeassistant.helpers.singleton import singleton from homeassistant.util import dt as dt_util from .const import DOMAIN from .data import ( DEVICE_CONSUMPTION_SCHEMA, ENERGY_SOURCE_SCHEMA, EnergyManager, EnergyPreferencesUpdate, async_get_manager, ) from .types import EnergyPlatform, GetSolarForecastType from .validate import async_validate EnergyWebSocketCommandHandler = Callable[ [HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"], None, ] AsyncEnergyWebSocketCommandHandler = Callable[ [HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"], Awaitable[None], ] @callback def async_setup(hass: HomeAssistant) -> None: """Set up the energy websocket API.""" websocket_api.async_register_command(hass, ws_get_prefs) websocket_api.async_register_command(hass, ws_save_prefs) websocket_api.async_register_command(hass, ws_info) websocket_api.async_register_command(hass, ws_validate) websocket_api.async_register_command(hass, ws_solar_forecast) websocket_api.async_register_command(hass, ws_get_fossil_energy_consumption) @singleton("energy_platforms") async def async_get_energy_platforms( hass: HomeAssistant, ) -> dict[str, GetSolarForecastType]: """Get energy platforms.""" platforms: dict[str, GetSolarForecastType] = {} async def _process_energy_platform( hass: HomeAssistant, domain: str, platform: ModuleType ) -> None: """Process energy platforms.""" if not hasattr(platform, "async_get_solar_forecast"): return platforms[domain] = cast(EnergyPlatform, platform).async_get_solar_forecast await async_process_integration_platforms(hass, DOMAIN, _process_energy_platform) return platforms def _ws_with_manager( func: Any, ) -> websocket_api.WebSocketCommandHandler: """Decorate a function to pass in a manager.""" @websocket_api.async_response @functools.wraps(func) async def with_manager( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: manager = await async_get_manager(hass) result = func(hass, connection, msg, manager) if asyncio.iscoroutine(result): await result return with_manager @websocket_api.websocket_command( { vol.Required("type"): "energy/get_prefs", } ) @_ws_with_manager @callback def ws_get_prefs( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict, manager: EnergyManager, ) -> None: """Handle get prefs command.""" if manager.data is None: connection.send_error(msg["id"], websocket_api.ERR_NOT_FOUND, "No prefs") return connection.send_result(msg["id"], manager.data) @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "energy/save_prefs", vol.Optional("energy_sources"): ENERGY_SOURCE_SCHEMA, vol.Optional("device_consumption"): [DEVICE_CONSUMPTION_SCHEMA], } ) @_ws_with_manager async def ws_save_prefs( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict, manager: EnergyManager, ) -> None: """Handle get prefs command.""" msg_id = msg.pop("id") msg.pop("type") await manager.async_update(cast(EnergyPreferencesUpdate, msg)) connection.send_result(msg_id, manager.data) @websocket_api.websocket_command( { vol.Required("type"): "energy/info", } ) @websocket_api.async_response async def ws_info( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get info command.""" forecast_platforms = await async_get_energy_platforms(hass) connection.send_result( msg["id"], { "cost_sensors": hass.data[DOMAIN]["cost_sensors"], "solar_forecast_domains": list(forecast_platforms), }, ) @websocket_api.websocket_command( { vol.Required("type"): "energy/validate", } ) @websocket_api.async_response async def ws_validate( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle validate command.""" connection.send_result(msg["id"], (await async_validate(hass)).as_dict()) @websocket_api.websocket_command( { vol.Required("type"): "energy/solar_forecast", } ) @_ws_with_manager async def ws_solar_forecast( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict, manager: EnergyManager, ) -> None: """Handle solar forecast command.""" if manager.data is None: connection.send_result(msg["id"], {}) return config_entries: dict[str, str | None] = {} for source in manager.data["energy_sources"]: if ( source["type"] != "solar" or source.get("config_entry_solar_forecast") is None ): continue # typing is not catching the above guard for config_entry_solar_forecast being none for config_entry in source["config_entry_solar_forecast"]: # type: ignore[union-attr] config_entries[config_entry] = None if not config_entries: connection.send_result(msg["id"], {}) return forecasts = {} forecast_platforms = await async_get_energy_platforms(hass) for config_entry_id in config_entries: config_entry = hass.config_entries.async_get_entry(config_entry_id) # Filter out non-existing config entries or unsupported domains if config_entry is None or config_entry.domain not in forecast_platforms: continue forecast = await forecast_platforms[config_entry.domain](hass, config_entry_id) if forecast is not None: forecasts[config_entry_id] = forecast connection.send_result(msg["id"], forecasts) @websocket_api.websocket_command( { vol.Required("type"): "energy/fossil_energy_consumption", vol.Required("start_time"): str, vol.Required("end_time"): str, vol.Required("energy_statistic_ids"): [str], vol.Required("co2_statistic_id"): str, vol.Required("period"): vol.Any("5minute", "hour", "day", "month"), } ) @websocket_api.async_response async def ws_get_fossil_energy_consumption( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Calculate amount of fossil based energy.""" start_time_str = msg["start_time"] end_time_str = msg["end_time"] if start_time := dt_util.parse_datetime(start_time_str): start_time = dt_util.as_utc(start_time) else: connection.send_error(msg["id"], "invalid_start_time", "Invalid start_time") return if end_time := dt_util.parse_datetime(end_time_str): end_time = dt_util.as_utc(end_time) else: connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time") return statistic_ids = list(msg["energy_statistic_ids"]) statistic_ids.append(msg["co2_statistic_id"]) # Fetch energy + CO2 statistics statistics = await hass.async_add_executor_job( recorder.statistics.statistics_during_period, hass, start_time, end_time, statistic_ids, "hour", True, ) def _combine_sum_statistics( stats: dict[str, list[dict[str, Any]]], statistic_ids: list[str] ) -> dict[datetime, float]: """Combine multiple statistics, returns a dict indexed by start time.""" result: defaultdict[datetime, float] = defaultdict(float) for statistics_id, stat in stats.items(): if statistics_id not in statistic_ids: continue for period in stat: if period["sum"] is None: continue result[period["start"]] += period["sum"] return {key: result[key] for key in sorted(result)} def _calculate_deltas(sums: dict[datetime, float]) -> dict[datetime, float]: prev: float | None = None result: dict[datetime, float] = {} for period, sum_ in sums.items(): if prev is not None: result[period] = sum_ - prev prev = sum_ return result def _reduce_deltas( stat_list: list[dict[str, Any]], same_period: Callable[[datetime, datetime], bool], period_start_end: Callable[[datetime], tuple[datetime, datetime]], period: timedelta, ) -> list[dict[str, Any]]: """Reduce hourly deltas to daily or monthly deltas.""" result: list[dict[str, Any]] = [] deltas: list[float] = [] if not stat_list: return result prev_stat: dict[str, Any] = stat_list[0] # Loop over the hourly deltas + a fake entry to end the period for statistic in chain( stat_list, ({"start": stat_list[-1]["start"] + period},) ): if not same_period(prev_stat["start"], statistic["start"]): start, _ = period_start_end(prev_stat["start"]) # The previous statistic was the last entry of the period result.append( { "start": start.isoformat(), "delta": sum(deltas), } ) deltas = [] if statistic.get("delta") is not None: deltas.append(statistic["delta"]) prev_stat = statistic return result merged_energy_statistics = _combine_sum_statistics( statistics, msg["energy_statistic_ids"] ) energy_deltas = _calculate_deltas(merged_energy_statistics) indexed_co2_statistics = { period["start"]: period["mean"] for period in statistics.get(msg["co2_statistic_id"], {}) } # Calculate amount of fossil based energy, assume 100% fossil if missing fossil_energy = [ {"start": start, "delta": delta * indexed_co2_statistics.get(start, 100) / 100} for start, delta in energy_deltas.items() ] if msg["period"] == "hour": reduced_fossil_energy = [ {"start": period["start"].isoformat(), "delta": period["delta"]} for period in fossil_energy ] elif msg["period"] == "day": reduced_fossil_energy = _reduce_deltas( fossil_energy, recorder.statistics.same_day, recorder.statistics.day_start_end, timedelta(days=1), ) else: reduced_fossil_energy = _reduce_deltas( fossil_energy, recorder.statistics.same_month, recorder.statistics.month_start_end, timedelta(days=1), ) result = {period["start"]: period["delta"] for period in reduced_fossil_energy} connection.send_result(msg["id"], result)