221 lines
6.1 KiB
Python
221 lines
6.1 KiB
Python
"""The Energy websocket API."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import functools
|
|
from types import ModuleType
|
|
from typing import Any, Awaitable, Callable, cast
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import websocket_api
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.helpers.integration_platform import (
|
|
async_process_integration_platforms,
|
|
)
|
|
from homeassistant.helpers.singleton import singleton
|
|
|
|
from .const import DOMAIN
|
|
from .data import (
|
|
DEVICE_CONSUMPTION_SCHEMA,
|
|
ENERGY_SOURCE_SCHEMA,
|
|
EnergyManager,
|
|
EnergyPreferencesUpdate,
|
|
async_get_manager,
|
|
)
|
|
from .types import EnergyPlatform, GetSolarForecastType
|
|
from .validate import async_validate
|
|
|
|
EnergyWebSocketCommandHandler = Callable[
|
|
[HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"],
|
|
None,
|
|
]
|
|
AsyncEnergyWebSocketCommandHandler = Callable[
|
|
[HomeAssistant, websocket_api.ActiveConnection, "dict[str, Any]", "EnergyManager"],
|
|
Awaitable[None],
|
|
]
|
|
|
|
|
|
@callback
|
|
def async_setup(hass: HomeAssistant) -> None:
|
|
"""Set up the energy websocket API."""
|
|
websocket_api.async_register_command(hass, ws_get_prefs)
|
|
websocket_api.async_register_command(hass, ws_save_prefs)
|
|
websocket_api.async_register_command(hass, ws_info)
|
|
websocket_api.async_register_command(hass, ws_validate)
|
|
websocket_api.async_register_command(hass, ws_solar_forecast)
|
|
|
|
|
|
@singleton("energy_platforms")
|
|
async def async_get_energy_platforms(
|
|
hass: HomeAssistant,
|
|
) -> dict[str, GetSolarForecastType]:
|
|
"""Get energy platforms."""
|
|
platforms: dict[str, GetSolarForecastType] = {}
|
|
|
|
async def _process_energy_platform(
|
|
hass: HomeAssistant, domain: str, platform: ModuleType
|
|
) -> None:
|
|
"""Process energy platforms."""
|
|
if not hasattr(platform, "async_get_solar_forecast"):
|
|
return
|
|
|
|
platforms[domain] = cast(EnergyPlatform, platform).async_get_solar_forecast
|
|
|
|
await async_process_integration_platforms(hass, DOMAIN, _process_energy_platform)
|
|
|
|
return platforms
|
|
|
|
|
|
def _ws_with_manager(
|
|
func: Any,
|
|
) -> websocket_api.WebSocketCommandHandler:
|
|
"""Decorate a function to pass in a manager."""
|
|
|
|
@websocket_api.async_response
|
|
@functools.wraps(func)
|
|
async def with_manager(
|
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
) -> None:
|
|
manager = await async_get_manager(hass)
|
|
|
|
result = func(hass, connection, msg, manager)
|
|
|
|
if asyncio.iscoroutine(result):
|
|
await result
|
|
|
|
return with_manager
|
|
|
|
|
|
@websocket_api.websocket_command(
|
|
{
|
|
vol.Required("type"): "energy/get_prefs",
|
|
}
|
|
)
|
|
@_ws_with_manager
|
|
@callback
|
|
def ws_get_prefs(
|
|
hass: HomeAssistant,
|
|
connection: websocket_api.ActiveConnection,
|
|
msg: dict,
|
|
manager: EnergyManager,
|
|
) -> None:
|
|
"""Handle get prefs command."""
|
|
if manager.data is None:
|
|
connection.send_error(msg["id"], websocket_api.ERR_NOT_FOUND, "No prefs")
|
|
return
|
|
|
|
connection.send_result(msg["id"], manager.data)
|
|
|
|
|
|
@websocket_api.require_admin
|
|
@websocket_api.websocket_command(
|
|
{
|
|
vol.Required("type"): "energy/save_prefs",
|
|
vol.Optional("energy_sources"): ENERGY_SOURCE_SCHEMA,
|
|
vol.Optional("device_consumption"): [DEVICE_CONSUMPTION_SCHEMA],
|
|
}
|
|
)
|
|
@_ws_with_manager
|
|
async def ws_save_prefs(
|
|
hass: HomeAssistant,
|
|
connection: websocket_api.ActiveConnection,
|
|
msg: dict,
|
|
manager: EnergyManager,
|
|
) -> None:
|
|
"""Handle get prefs command."""
|
|
msg_id = msg.pop("id")
|
|
msg.pop("type")
|
|
await manager.async_update(cast(EnergyPreferencesUpdate, msg))
|
|
connection.send_result(msg_id, manager.data)
|
|
|
|
|
|
@websocket_api.websocket_command(
|
|
{
|
|
vol.Required("type"): "energy/info",
|
|
}
|
|
)
|
|
@websocket_api.async_response
|
|
async def ws_info(
|
|
hass: HomeAssistant,
|
|
connection: websocket_api.ActiveConnection,
|
|
msg: dict,
|
|
) -> None:
|
|
"""Handle get info command."""
|
|
forecast_platforms = await async_get_energy_platforms(hass)
|
|
connection.send_result(
|
|
msg["id"],
|
|
{
|
|
"cost_sensors": hass.data[DOMAIN]["cost_sensors"],
|
|
"solar_forecast_domains": list(forecast_platforms),
|
|
},
|
|
)
|
|
|
|
|
|
@websocket_api.websocket_command(
|
|
{
|
|
vol.Required("type"): "energy/validate",
|
|
}
|
|
)
|
|
@websocket_api.async_response
|
|
async def ws_validate(
|
|
hass: HomeAssistant,
|
|
connection: websocket_api.ActiveConnection,
|
|
msg: dict,
|
|
) -> None:
|
|
"""Handle validate command."""
|
|
connection.send_result(msg["id"], (await async_validate(hass)).as_dict())
|
|
|
|
|
|
@websocket_api.websocket_command(
|
|
{
|
|
vol.Required("type"): "energy/solar_forecast",
|
|
}
|
|
)
|
|
@_ws_with_manager
|
|
async def ws_solar_forecast(
|
|
hass: HomeAssistant,
|
|
connection: websocket_api.ActiveConnection,
|
|
msg: dict,
|
|
manager: EnergyManager,
|
|
) -> None:
|
|
"""Handle solar forecast command."""
|
|
if manager.data is None:
|
|
connection.send_result(msg["id"], {})
|
|
return
|
|
|
|
config_entries: dict[str, str | None] = {}
|
|
|
|
for source in manager.data["energy_sources"]:
|
|
if (
|
|
source["type"] != "solar"
|
|
or source.get("config_entry_solar_forecast") is None
|
|
):
|
|
continue
|
|
|
|
# typing is not catching the above guard for config_entry_solar_forecast being none
|
|
for config_entry in source["config_entry_solar_forecast"]: # type: ignore[union-attr]
|
|
config_entries[config_entry] = None
|
|
|
|
if not config_entries:
|
|
connection.send_result(msg["id"], {})
|
|
return
|
|
|
|
forecasts = {}
|
|
|
|
forecast_platforms = await async_get_energy_platforms(hass)
|
|
|
|
for config_entry_id in config_entries:
|
|
config_entry = hass.config_entries.async_get_entry(config_entry_id)
|
|
# Filter out non-existing config entries or unsupported domains
|
|
|
|
if config_entry is None or config_entry.domain not in forecast_platforms:
|
|
continue
|
|
|
|
forecast = await forecast_platforms[config_entry.domain](hass, config_entry_id)
|
|
|
|
if forecast is not None:
|
|
forecasts[config_entry_id] = forecast
|
|
|
|
connection.send_result(msg["id"], forecasts)
|