core/homeassistant/components/energy/websocket_api.py

221 lines
6.1 KiB
Python

"""The Energy websocket API."""
from __future__ import annotations
import asyncio
import functools
from types import ModuleType
from typing import Any, Awaitable, Callable, cast
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
from homeassistant.helpers.singleton import singleton
from .const import DOMAIN
from .data import (
DEVICE_CONSUMPTION_SCHEMA,
ENERGY_SOURCE_SCHEMA,
EnergyManager,
EnergyPreferencesUpdate,
async_get_manager,
)
from .types import EnergyPlatform, GetSolarForecastType
from .validate import async_validate
# Signature of a synchronous energy command handler: the usual websocket
# handler arguments plus the EnergyManager, matching the call shape
# ``func(hass, connection, msg, manager)`` made inside ``_ws_with_manager``.
# NOTE(review): the dict/manager parameters are quoted presumably because this
# alias is a runtime expression (not covered by ``from __future__ import
# annotations``) and ``dict[...]`` is not subscriptable on older Pythons.
EnergyWebSocketCommandHandler = Callable[
    [HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"],
    None,
]
# Same signature as above, for coroutine (``async def``) handlers.
AsyncEnergyWebSocketCommandHandler = Callable[
    [HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"],
    Awaitable[None],
]
@callback
def async_setup(hass: HomeAssistant) -> None:
    """Register each energy websocket command handler with Home Assistant."""
    # Registration order matches the order the handlers are listed here.
    for command_handler in (
        ws_get_prefs,
        ws_save_prefs,
        ws_info,
        ws_validate,
        ws_solar_forecast,
    ):
        websocket_api.async_register_command(hass, command_handler)
@singleton("energy_platforms")
async def async_get_energy_platforms(
    hass: HomeAssistant,
) -> dict[str, GetSolarForecastType]:
    """Return solar forecast callables of energy platforms, keyed by domain.

    Integration platforms named by ``DOMAIN`` are processed; only those that
    define ``async_get_solar_forecast`` contribute an entry.
    """
    forecast_hooks: dict[str, GetSolarForecastType] = {}

    async def _register_platform(
        hass: HomeAssistant, domain: str, platform: ModuleType
    ) -> None:
        """Remember the platform's solar forecast hook when it provides one."""
        if hasattr(platform, "async_get_solar_forecast"):
            energy_platform = cast(EnergyPlatform, platform)
            forecast_hooks[domain] = energy_platform.async_get_solar_forecast

    await async_process_integration_platforms(hass, DOMAIN, _register_platform)
    return forecast_hooks
def _ws_with_manager(
    func: EnergyWebSocketCommandHandler | AsyncEnergyWebSocketCommandHandler,
) -> websocket_api.WebSocketCommandHandler:
    """Decorate an energy command handler so it also receives the manager.

    The returned websocket handler takes the standard ``(hass, connection,
    msg)`` arguments, obtains the manager via ``async_get_manager`` and calls
    ``func(hass, connection, msg, manager)``. ``func`` may be a plain function
    or a coroutine function; a coroutine result is awaited before the handler
    completes. The wrapper itself is wrapped with
    ``websocket_api.async_response``.
    """

    @websocket_api.async_response
    @functools.wraps(func)
    async def with_manager(
        hass: HomeAssistant,
        connection: websocket_api.ActiveConnection,
        msg: dict[str, Any],
    ) -> None:
        """Fetch the energy manager and forward the message to ``func``."""
        manager = await async_get_manager(hass)

        result = func(hass, connection, msg, manager)

        # Synchronous handlers return None; coroutine handlers must be awaited.
        if asyncio.iscoroutine(result):
            await result

    return with_manager
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/get_prefs",
    }
)
@_ws_with_manager
@callback
def ws_get_prefs(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Send the stored energy preferences, or a not-found error if unset."""
    request_id = msg["id"]
    stored_prefs = manager.data
    if stored_prefs is not None:
        connection.send_result(request_id, stored_prefs)
    else:
        connection.send_error(request_id, websocket_api.ERR_NOT_FOUND, "No prefs")
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/save_prefs",
        vol.Optional("energy_sources"): ENERGY_SOURCE_SCHEMA,
        vol.Optional("device_consumption"): [DEVICE_CONSUMPTION_SCHEMA],
    }
)
@_ws_with_manager
async def ws_save_prefs(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Handle save prefs command.

    Passes the optional ``energy_sources`` / ``device_consumption`` fields of
    the message to ``manager.async_update`` and replies with the resulting
    ``manager.data``. Admin-only.
    """
    # Build the update from a filtered copy instead of popping "id"/"type"
    # off ``msg``: the caller's message stays intact, so anything that reads
    # it after this handler (presumably the error reply if the update
    # raises) can still find ``msg["id"]``.
    update = {key: value for key, value in msg.items() if key not in ("id", "type")}
    await manager.async_update(cast(EnergyPreferencesUpdate, update))
    connection.send_result(msg["id"], manager.data)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/info",
    }
)
@websocket_api.async_response
async def ws_info(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
) -> None:
    """Reply with the energy cost sensors and the solar forecast domains."""
    platforms_by_domain = await async_get_energy_platforms(hass)
    request_id = msg["id"]
    info = {
        "cost_sensors": hass.data[DOMAIN]["cost_sensors"],
        "solar_forecast_domains": [*platforms_by_domain],
    }
    connection.send_result(request_id, info)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/validate",
    }
)
@websocket_api.async_response
async def ws_validate(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
) -> None:
    """Run energy configuration validation and send back its dict form."""
    request_id = msg["id"]
    validation = await async_validate(hass)
    connection.send_result(request_id, validation.as_dict())
@websocket_api.websocket_command(
    {
        vol.Required("type"): "energy/solar_forecast",
    }
)
@_ws_with_manager
async def ws_solar_forecast(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict,
    manager: EnergyManager,
) -> None:
    """Handle solar forecast command.

    Collects the config entry IDs listed under
    ``config_entry_solar_forecast`` of every ``solar`` energy source and
    replies with a mapping of config entry ID to forecast. An empty mapping is
    sent when no preferences are stored or no entry IDs are configured.
    Entries that no longer exist, whose domain has no energy platform with
    ``async_get_solar_forecast``, or whose forecast is ``None`` are omitted.
    """
    if manager.data is None:
        connection.send_result(msg["id"], {})
        return

    # Used as an insertion-ordered set: de-duplicates entry IDs shared by
    # several solar sources while keeping first-seen order.
    config_entries: dict[str, None] = {}

    for source in manager.data["energy_sources"]:
        if source["type"] != "solar":
            continue

        # Bind the value once so the ``is None`` check narrows its type for
        # the loop below (no ``type: ignore`` needed).
        forecast_entry_ids = source.get("config_entry_solar_forecast")
        if forecast_entry_ids is None:
            continue

        for entry_id in forecast_entry_ids:
            config_entries[entry_id] = None

    if not config_entries:
        connection.send_result(msg["id"], {})
        return

    forecasts: dict[str, Any] = {}

    forecast_platforms = await async_get_energy_platforms(hass)

    for config_entry_id in config_entries:
        config_entry = hass.config_entries.async_get_entry(config_entry_id)
        # Filter out non-existing config entries or unsupported domains
        if config_entry is None or config_entry.domain not in forecast_platforms:
            continue

        forecast = await forecast_platforms[config_entry.domain](hass, config_entry_id)

        if forecast is not None:
            forecasts[config_entry_id] = forecast

    connection.send_result(msg["id"], forecasts)