# core/homeassistant/components/energy/websocket_api.py
#
# 369 lines
# 11 KiB
# Python

"""The Energy websocket API."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
import functools
from itertools import chain
from types import ModuleType
from typing import Any, cast
import voluptuous as vol
from homeassistant.components import recorder, websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
from homeassistant.helpers.singleton import singleton
from homeassistant.util import dt as dt_util
from .const import DOMAIN
from .data import (
DEVICE_CONSUMPTION_SCHEMA,
ENERGY_SOURCE_SCHEMA,
EnergyManager,
EnergyPreferencesUpdate,
async_get_manager,
)
from .types import EnergyPlatform, GetSolarForecastType
from .validate import async_validate
# Call shape of a synchronous handler that takes the energy manager as a
# fourth argument: (hass, connection, msg, manager) -> None.
EnergyWebSocketCommandHandler = Callable[
    [
        HomeAssistant,
        websocket_api.ActiveConnection,
        "dict[str, Any]",
        "EnergyManager",
    ],
    None,
]
# Same call shape, but the handler is a coroutine function that must be awaited.
AsyncEnergyWebSocketCommandHandler = Callable[
    [
        HomeAssistant,
        websocket_api.ActiveConnection,
        "dict[str, Any]",
        "EnergyManager",
    ],
    Awaitable[None],
]
@callback
def async_setup(hass: HomeAssistant) -> None:
    """Register every energy websocket command with Home Assistant."""
    for command_handler in (
        ws_get_prefs,
        ws_save_prefs,
        ws_info,
        ws_validate,
        ws_solar_forecast,
        ws_get_fossil_energy_consumption,
    ):
        websocket_api.async_register_command(hass, command_handler)
@singleton("energy_platforms")
async def async_get_energy_platforms(
    hass: HomeAssistant,
) -> dict[str, GetSolarForecastType]:
    """Map integration domains to their ``async_get_solar_forecast`` hooks.

    Only energy platforms that define ``async_get_solar_forecast`` are
    included in the returned dict.
    """
    forecasters: dict[str, GetSolarForecastType] = {}

    async def _process_energy_platform(
        hass: HomeAssistant, domain: str, platform: ModuleType
    ) -> None:
        """Record the platform's solar forecast hook when it provides one."""
        if hasattr(platform, "async_get_solar_forecast"):
            energy_platform = cast(EnergyPlatform, platform)
            forecasters[domain] = energy_platform.async_get_solar_forecast

    await async_process_integration_platforms(hass, DOMAIN, _process_energy_platform)
    return forecasters
def _ws_with_manager(
    func: EnergyWebSocketCommandHandler | AsyncEnergyWebSocketCommandHandler,
) -> websocket_api.WebSocketCommandHandler:
    """Decorate a function to pass in a manager.

    The decorated handler is called as ``func(hass, connection, msg, manager)``,
    where ``manager`` comes from ``async_get_manager(hass)``. It may be a plain
    (e.g. ``@callback``) function or a coroutine function; if calling it
    returns a coroutine, that coroutine is awaited.

    Returns the wrapper, already wrapped by ``websocket_api.async_response``.
    """

    @websocket_api.async_response
    @functools.wraps(func)
    async def with_manager(
        hass: HomeAssistant,
        connection: websocket_api.ActiveConnection,
        msg: dict[str, Any],
    ) -> None:
        """Look up the energy manager, then dispatch to the wrapped handler."""
        manager = await async_get_manager(hass)
        result = func(hass, connection, msg, manager)
        # Coroutine-function handlers return a coroutine that still has to run.
        if asyncio.iscoroutine(result):
            await result

    return with_manager
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/get_prefs",
    }
)
@_ws_with_manager
@callback
def ws_get_prefs(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Reply with the manager's preference data, or a not-found error if unset."""
    prefs = manager.data
    if prefs is not None:
        connection.send_result(msg["id"], prefs)
    else:
        connection.send_error(msg["id"], websocket_api.ERR_NOT_FOUND, "No prefs")
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/save_prefs",
        vol.Optional("energy_sources"): ENERGY_SOURCE_SCHEMA,
        vol.Optional("device_consumption"): [DEVICE_CONSUMPTION_SCHEMA],
    }
)
@_ws_with_manager
async def ws_save_prefs(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Handle save prefs command.

    Every key of the validated message except the websocket envelope keys
    ``id`` and ``type`` (i.e. ``energy_sources`` / ``device_consumption`` when
    present) is passed to ``manager.async_update``. The manager's data after
    the update is sent back as the result. Admin only.
    """
    # Build the update from a copy instead of popping keys off ``msg``, so the
    # incoming message is not mutated.
    update = {key: value for key, value in msg.items() if key not in ("id", "type")}
    await manager.async_update(cast(EnergyPreferencesUpdate, update))
    connection.send_result(msg["id"], manager.data)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/info",
    }
)
@websocket_api.async_response
async def ws_info(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
) -> None:
    """Reply with the cost sensors and the domains offering solar forecasts."""
    platforms = await async_get_energy_platforms(hass)
    info = {
        "cost_sensors": hass.data[DOMAIN]["cost_sensors"],
        "solar_forecast_domains": list(platforms),
    }
    connection.send_result(msg["id"], info)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/validate",
    }
)
@websocket_api.async_response
async def ws_validate(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
) -> None:
    """Run the energy validation and reply with its dict form."""
    validation = await async_validate(hass)
    connection.send_result(msg["id"], validation.as_dict())
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/solar_forecast",
    }
)
@_ws_with_manager
async def ws_solar_forecast(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Reply with solar forecasts keyed by config entry id.

    Only config entries listed under ``config_entry_solar_forecast`` of solar
    energy sources are queried. Entries that no longer exist, whose domain has
    no forecast platform, or whose forecast is None are left out. An empty
    dict is sent when there are no preferences or nothing to query.
    """
    if manager.data is None:
        connection.send_result(msg["id"], {})
        return

    # Ordered, de-duplicated collection of entry ids (dict keys keep
    # insertion order).
    entry_ids: dict[str, None] = {}
    for source in manager.data["energy_sources"]:
        if source["type"] != "solar":
            continue
        forecast_entries = source.get("config_entry_solar_forecast")
        if forecast_entries is None:
            continue
        entry_ids.update((entry_id, None) for entry_id in forecast_entries)

    if not entry_ids:
        connection.send_result(msg["id"], {})
        return

    forecasts = {}
    forecast_platforms = await async_get_energy_platforms(hass)
    for entry_id in entry_ids:
        entry = hass.config_entries.async_get_entry(entry_id)
        # Skip removed entries and integrations without a forecast hook.
        if entry is None or entry.domain not in forecast_platforms:
            continue
        forecast = await forecast_platforms[entry.domain](hass, entry_id)
        if forecast is not None:
            forecasts[entry_id] = forecast

    connection.send_result(msg["id"], forecasts)
def _combine_sum_statistics(
    stats: dict[str, list[dict[str, Any]]], statistic_ids: list[str]
) -> dict[datetime, float]:
    """Combine multiple statistics, returns a dict indexed by start time.

    Adds up the ``sum`` values of every statistic named in *statistic_ids* per
    ``start`` time, skipping rows whose ``sum`` is None. The returned dict is
    ordered by ascending start time.
    """
    wanted = set(statistic_ids)
    combined: defaultdict[datetime, float] = defaultdict(float)
    for statistic_id, rows in stats.items():
        if statistic_id not in wanted:
            continue
        for row in rows:
            if row["sum"] is None:
                continue
            combined[row["start"]] += row["sum"]
    return {start: combined[start] for start in sorted(combined)}


def _calculate_deltas(sums: dict[datetime, float]) -> dict[datetime, float]:
    """Turn consecutive running sums into per-period deltas.

    Every entry after the first maps to its difference from the previous
    entry. The first entry is only the baseline and gets no delta.
    """
    prev: float | None = None
    result: dict[datetime, float] = {}
    for start, sum_ in sums.items():
        if prev is not None:
            result[start] = sum_ - prev
        prev = sum_
    return result


def _reduce_deltas(
    stat_list: list[dict[str, Any]],
    same_period: Callable[[datetime, datetime], bool],
    period_start_end: Callable[[datetime], tuple[datetime, datetime]],
) -> list[dict[str, Any]]:
    """Reduce hourly deltas to daily or monthly deltas.

    Runs of consecutive entries for which *same_period* holds are summed into
    one entry. Its ``start`` is the ISO formatted first element returned by
    *period_start_end* for the run's last entry. Entries whose ``delta`` is
    None contribute nothing.
    """
    result: list[dict[str, Any]] = []
    if not stat_list:
        return result

    def _close_period(last_start: datetime, deltas: list[float]) -> None:
        """Append the summed entry for the period containing *last_start*."""
        start, _ = period_start_end(last_start)
        result.append({"start": start.isoformat(), "delta": sum(deltas)})

    deltas: list[float] = []
    prev_start: datetime = stat_list[0]["start"]
    for statistic in stat_list:
        if not same_period(prev_start, statistic["start"]):
            # The previous statistic was the last entry of its period
            _close_period(prev_start, deltas)
            deltas = []
        if statistic.get("delta") is not None:
            deltas.append(statistic["delta"])
        prev_start = statistic["start"]

    # Close the final period explicitly instead of relying on a sentinel entry.
    # A fixed offset after the last entry is not guaranteed to land in a new
    # period (one day later is usually still the same month), which would
    # silently drop the last period.
    _close_period(prev_start, deltas)
    return result


@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/fossil_energy_consumption",
        vol.Required("start_time"): str,
        vol.Required("end_time"): str,
        vol.Required("energy_statistic_ids"): [str],
        vol.Required("co2_statistic_id"): str,
        vol.Required("period"): vol.Any("5minute", "hour", "day", "month"),
    }
)
@websocket_api.async_response
async def ws_get_fossil_energy_consumption(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Calculate amount of fossil based energy.

    Fetches hourly statistics for ``energy_statistic_ids`` and
    ``co2_statistic_id`` between ``start_time`` and ``end_time`` (converted to
    UTC). It turns the combined energy sums into hourly deltas and scales each
    delta by the CO2 statistic's mean for that hour, read as a fossil
    percentage (100% when missing). It then replies with a dict mapping ISO
    period start -> fossil energy, aggregated per ``period``. "5minute" is
    answered at hourly resolution, since the source statistics are hourly.

    Sends an ``invalid_start_time`` / ``invalid_end_time`` error when a
    timestamp cannot be parsed.
    """
    if start_time := dt_util.parse_datetime(msg["start_time"]):
        start_time = dt_util.as_utc(start_time)
    else:
        connection.send_error(msg["id"], "invalid_start_time", "Invalid start_time")
        return

    if end_time := dt_util.parse_datetime(msg["end_time"]):
        end_time = dt_util.as_utc(end_time)
    else:
        connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time")
        return

    statistic_ids = [*msg["energy_statistic_ids"], msg["co2_statistic_id"]]

    # Fetch hourly energy + CO2 statistics.
    # NOTE(review): the trailing True is presumably ``start_time_as_datetime``;
    # the code below relies on "start" being a datetime (isoformat / grouping).
    statistics = await hass.async_add_executor_job(
        recorder.statistics.statistics_during_period,
        hass,
        start_time,
        end_time,
        statistic_ids,
        "hour",
        True,
    )

    merged_energy_statistics = _combine_sum_statistics(
        statistics, msg["energy_statistic_ids"]
    )
    energy_deltas = _calculate_deltas(merged_energy_statistics)

    # Rows without a mean are left out, so the 100% fallback below applies to
    # them instead of failing on ``delta * None``.
    indexed_co2_statistics = {
        period["start"]: period["mean"]
        for period in statistics.get(msg["co2_statistic_id"], [])
        if period.get("mean") is not None
    }

    # Calculate amount of fossil based energy, assume 100% fossil if missing
    fossil_energy = [
        {"start": start, "delta": delta * indexed_co2_statistics.get(start, 100) / 100}
        for start, delta in energy_deltas.items()
    ]

    requested_period = msg["period"]
    if requested_period == "day":
        reduced_fossil_energy = _reduce_deltas(
            fossil_energy,
            recorder.statistics.same_day,
            recorder.statistics.day_start_end,
        )
    elif requested_period == "month":
        reduced_fossil_energy = _reduce_deltas(
            fossil_energy,
            recorder.statistics.same_month,
            recorder.statistics.month_start_end,
        )
    else:
        # "hour", and "5minute": the statistics were fetched hourly, so hourly
        # is the finest resolution available (previously "5minute" fell
        # through to the monthly reduction).
        reduced_fossil_energy = [
            {"start": period["start"].isoformat(), "delta": period["delta"]}
            for period in fossil_energy
        ]

    result = {period["start"]: period["delta"] for period in reduced_fossil_energy}
    connection.send_result(msg["id"], result)