core/homeassistant/components/template/config_flow.py

398 lines
13 KiB
Python

"""Config flow for the Template integration."""
from __future__ import annotations
from collections.abc import Callable, Coroutine, Mapping
from typing import Any, cast
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.binary_sensor import BinarySensorDeviceClass
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
DEVICE_CLASS_STATE_CLASSES,
DEVICE_CLASS_UNITS,
SensorDeviceClass,
SensorStateClass,
)
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_NAME,
CONF_STATE,
CONF_UNIT_OF_MEASUREMENT,
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler,
SchemaFlowFormStep,
SchemaFlowMenuStep,
)
from .binary_sensor import async_create_preview_binary_sensor
from .const import DOMAIN
from .sensor import async_create_preview_sensor
from .template_entity import TemplateEntity
# Placeholder select option meaning "no value". It is listed first in the
# dropdowns built by `generate_schema`, and `_strip_sentinel` removes any
# option that is still set to it.
NONE_SENTINEL = "none"
def generate_schema(domain: str, flow_type: str) -> dict[vol.Marker, Any]:
    """Build the platform-specific selector fields of a template helper form.

    Sensors get unit of measurement, device class and state class dropdowns
    in every flow type. Binary sensors get a device class dropdown, but only
    in the "config" flow. Any other combination yields an empty mapping.
    Every dropdown lists NONE_SENTINEL as its first option.
    """
    dropdown = selector.SelectSelectorMode.DROPDOWN

    if domain == Platform.SENSOR:
        # Collect every known unit across all device classes, de-duplicated.
        unit_names = {
            str(unit)
            for units in DEVICE_CLASS_UNITS.values()
            for unit in units
            if unit is not None
        }
        # The ENUM device class is not offered in the dropdown.
        device_classes = [
            cls.value for cls in SensorDeviceClass if cls != SensorDeviceClass.ENUM
        ]
        state_classes = [cls.value for cls in SensorStateClass]

        unit_selector = selector.SelectSelector(
            selector.SelectSelectorConfig(
                options=[NONE_SENTINEL, *sorted(unit_names, key=str.casefold)],
                mode=dropdown,
                translation_key="sensor_unit_of_measurement",
                custom_value=True,
            ),
        )
        device_class_selector = selector.SelectSelector(
            selector.SelectSelectorConfig(
                options=[NONE_SENTINEL, *sorted(device_classes, key=str.casefold)],
                mode=dropdown,
                translation_key="sensor_device_class",
            ),
        )
        state_class_selector = selector.SelectSelector(
            selector.SelectSelectorConfig(
                options=[NONE_SENTINEL, *sorted(state_classes)],
                mode=dropdown,
                translation_key="sensor_state_class",
            ),
        )
        return {
            vol.Optional(
                CONF_UNIT_OF_MEASUREMENT, default=NONE_SENTINEL
            ): unit_selector,
            vol.Optional(CONF_DEVICE_CLASS, default=NONE_SENTINEL): device_class_selector,
            vol.Optional(CONF_STATE_CLASS, default=NONE_SENTINEL): state_class_selector,
        }

    if domain == Platform.BINARY_SENSOR and flow_type == "config":
        binary_device_classes = [cls.value for cls in BinarySensorDeviceClass]
        return {
            vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
                selector.SelectSelectorConfig(
                    options=[
                        NONE_SENTINEL,
                        *sorted(binary_device_classes, key=str.casefold),
                    ],
                    mode=dropdown,
                    translation_key="binary_sensor_device_class",
                ),
            )
        }

    return {}
def options_schema(domain: str) -> vol.Schema:
    """Return the options flow schema for the given template platform.

    The schema always requires the state template and adds whatever fields
    `generate_schema` provides for the "option" flow type.
    """
    fields: dict[vol.Marker, Any] = {
        vol.Required(CONF_STATE): selector.TemplateSelector(),
    }
    fields.update(generate_schema(domain, "option"))
    return vol.Schema(fields)
def config_schema(domain: str) -> vol.Schema:
    """Return the config flow schema for the given template platform.

    The schema always requires a name and the state template and adds
    whatever fields `generate_schema` provides for the "config" flow type.
    """
    fields: dict[vol.Marker, Any] = {
        vol.Required(CONF_NAME): selector.TextSelector(),
        vol.Required(CONF_STATE): selector.TemplateSelector(),
    }
    fields.update(generate_schema(domain, "config"))
    return vol.Schema(fields)
async def choose_options_step(options: dict[str, Any]) -> str:
    """Pick the options flow step named after the stored template type."""
    step_id = options["template_type"]
    return cast(str, step_id)
def _strip_sentinel(options: dict[str, Any]) -> None:
    """Remove, in place, the select options still set to NONE_SENTINEL.

    Only the device class, state class and unit of measurement keys are
    inspected; absent keys are ignored.
    """
    for key in (CONF_DEVICE_CLASS, CONF_STATE_CLASS, CONF_UNIT_OF_MEASUREMENT):
        # A missing key yields None, which never equals the sentinel.
        if options.get(key) == NONE_SENTINEL:
            del options[key]
def _validate_unit(options: dict[str, Any]) -> None:
    """Check the unit of measurement against the selected device class.

    Nothing is checked when no device class is set or when the device class
    has no entry in DEVICE_CLASS_UNITS. Raises vol.Invalid, listing the
    accepted units, when the configured unit is not allowed.
    """
    device_class = options.get(CONF_DEVICE_CLASS)
    if not device_class:
        return

    allowed_units = DEVICE_CLASS_UNITS.get(device_class)
    if allowed_units is None:
        return

    unit = options.get(CONF_UNIT_OF_MEASUREMENT)
    if unit in allowed_units:
        return

    # A falsy allowed unit means the device class also accepts "no unit".
    labels = [
        f"'{allowed!s}'" if allowed else "no unit of measurement"
        for allowed in allowed_units
    ]
    labels.sort(key=str.casefold)
    expected = labels[0] if len(labels) == 1 else f"one of {', '.join(labels)}"

    raise vol.Invalid(
        f"'{unit}' is not a valid unit for device class '{device_class}'; "
        f"expected {expected}"
    )
def _validate_state_class(options: dict[str, Any]) -> None:
    """Check the state class against the selected device class.

    Nothing is checked unless both a state class and a device class are set
    and the device class has an entry in DEVICE_CLASS_STATE_CLASSES. Raises
    vol.Invalid, listing the accepted state classes, when the configured
    state class is not allowed.
    """
    state_class = options.get(CONF_STATE_CLASS)
    if not state_class:
        return

    device_class = options.get(CONF_DEVICE_CLASS)
    if not device_class:
        return

    allowed_state_classes = DEVICE_CLASS_STATE_CLASSES.get(device_class)
    if allowed_state_classes is None or state_class in allowed_state_classes:
        return

    labels = [f"'{allowed!s}'" for allowed in allowed_state_classes]
    labels.sort(key=str.casefold)
    if not labels:
        expected = "no state class"
    elif len(labels) == 1:
        expected = labels[0]
    else:
        expected = f"one of {', '.join(labels)}"

    raise vol.Invalid(
        f"'{state_class}' is not a valid state class for device class "
        f"'{device_class}'; expected {expected}"
    )
def validate_user_input(
    template_type: str,
) -> Callable[
    [SchemaCommonFlowHandler, dict[str, Any]],
    Coroutine[Any, Any, dict[str, Any]],
]:
    """Create the post-validation coroutine for one template type.

    For binary sensors: Strip none-sentinels.
    For sensors: Strip none-sentinels, then validate unit of measurement
    and state class.
    For all domains: Set template type.
    """
    is_sensor = template_type == Platform.SENSOR
    has_sentinels = is_sensor or template_type == Platform.BINARY_SENSOR

    async def _validate_user_input(
        _: SchemaCommonFlowHandler,
        user_input: dict[str, Any],
    ) -> dict[str, Any]:
        """Clean and validate the user input, then add the template type."""
        if has_sentinels:
            # Mutates the submitted dict in place.
            _strip_sentinel(user_input)
        if is_sensor:
            _validate_unit(user_input)
            _validate_state_class(user_input)
        # A "template_type" key already in user_input takes precedence.
        return {"template_type": template_type, **user_input}

    return _validate_user_input
# Template types offered by the "user" menu step of the config flow.
TEMPLATE_TYPES = ["binary_sensor", "sensor"]
# Config flow steps: a menu to pick the template type, followed by one
# previewable form step per supported platform.
CONFIG_FLOW = {
    "user": SchemaFlowMenuStep(TEMPLATE_TYPES),
    **{
        platform: SchemaFlowFormStep(
            config_schema(platform),
            preview="template",
            validate_user_input=validate_user_input(platform),
        )
        for platform in (Platform.BINARY_SENSOR, Platform.SENSOR)
    },
}
# Options flow steps: "init" has no form and only routes to the step named
# after the stored template type; each platform then has a previewable form.
OPTIONS_FLOW = {
    "init": SchemaFlowFormStep(next_step=choose_options_step),
    **{
        platform: SchemaFlowFormStep(
            options_schema(platform),
            preview="template",
            validate_user_input=validate_user_input(platform),
        )
        for platform in (Platform.BINARY_SENSOR, Platform.SENSOR)
    },
}
# Maps a template type to the factory that builds its preview entity.
# `ws_start_preview` calls the selected factory with (hass, name, user_input).
CREATE_PREVIEW_ENTITY: dict[
    str,
    Callable[[HomeAssistant, str, dict[str, Any]], TemplateEntity],
] = {
    "binary_sensor": async_create_preview_binary_sensor,
    "sensor": async_create_preview_sensor,
}
class TemplateConfigFlowHandler(SchemaConfigFlowHandler, domain=DOMAIN):
    """Config and options flow handler for template helper entries."""

    config_flow = CONFIG_FLOW
    options_flow = OPTIONS_FLOW

    @callback
    def async_config_entry_title(self, options: Mapping[str, Any]) -> str:
        """Use the name stored in the options as the config entry title."""
        title = options["name"]
        return cast(str, title)

    @staticmethod
    async def async_setup_preview(hass: HomeAssistant) -> None:
        """Register the websocket command that starts a template preview."""
        websocket_api.async_register_command(hass, ws_start_preview)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "template/start_preview",
        vol.Required("flow_id"): str,
        vol.Required("flow_type"): vol.Any("config_flow", "options_flow"),
        vol.Required("user_input"): dict,
    }
)
@callback
def ws_start_preview(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Generate a preview.

    Validates msg["user_input"] against the schema of the config or options
    flow step identified by msg["flow_id"]. If validation fails, an
    unsuccessful result carrying the per-field error messages is sent and
    nothing else happens. Otherwise a preview entity is created, the command
    is acknowledged and preview updates are forwarded to the connection as
    event messages.

    msg["user_input"] is modified in place: options still set to
    NONE_SENTINEL are removed.

    Raises HomeAssistantError when, for an options flow, no config entry is
    found for the flow's handler.
    """

    def _validate(schema: vol.Schema, domain: str, user_input: dict[str, Any]) -> Any:
        """Return a mapping of field key to error message (empty if valid)."""
        errors = {}
        key: vol.Marker
        # Run each field's validator only on the fields that were submitted;
        # missing fields are not reported as errors here.
        for key, validator in schema.schema.items():
            if key.schema not in user_input:
                continue
            try:
                validator(user_input[key.schema])
            except vol.Invalid as ex:
                errors[key.schema] = str(ex.msg)

        if domain == Platform.SENSOR:
            # The cross-field checks below expect sentinel values to be gone.
            _strip_sentinel(user_input)
            try:
                _validate_unit(user_input)
            except vol.Invalid as ex:
                errors[CONF_UNIT_OF_MEASUREMENT] = str(ex.msg)
            try:
                _validate_state_class(user_input)
            except vol.Invalid as ex:
                errors[CONF_STATE_CLASS] = str(ex.msg)

        return errors

    entity_registry_entry: er.RegistryEntry | None = None
    if msg["flow_type"] == "config_flow":
        flow_status = hass.config_entries.flow.async_get(msg["flow_id"])
        # The id of the flow's current step doubles as the template type and
        # is the key of the matching form step in CONFIG_FLOW.
        template_type = flow_status["step_id"]
        form_step = cast(SchemaFlowFormStep, CONFIG_FLOW[template_type])
        schema = cast(vol.Schema, form_step.schema)
        # NOTE(review): raises KeyError if "name" was not submitted — confirm
        # the caller always includes it.
        name = msg["user_input"]["name"]
    else:
        flow_status = hass.config_entries.options.async_get(msg["flow_id"])
        # For an options flow the handler value is used to look up the
        # config entry being edited.
        config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
        if not config_entry:
            raise HomeAssistantError
        # Type and name are not part of the options form; they are taken
        # from the options already stored on the config entry.
        template_type = config_entry.options["template_type"]
        name = config_entry.options["name"]
        schema = cast(vol.Schema, OPTIONS_FLOW[template_type].schema)
        entity_registry = er.async_get(hass)
        entries = er.async_entries_for_config_entry(
            entity_registry, flow_status["handler"]
        )
        # Attach the first registered entity of the config entry, if any,
        # to the preview entity created below.
        if entries:
            entity_registry_entry = entries[0]

    errors = _validate(schema, template_type, msg["user_input"])

    @callback
    def async_preview_updated(
        state: str | None,
        attributes: Mapping[str, Any] | None,
        listeners: dict[str, bool | set[str]] | None,
        error: str | None,
    ) -> None:
        """Forward config entry state events to websocket."""
        # An error replaces the state payload entirely.
        if error is not None:
            connection.send_message(
                websocket_api.event_message(
                    msg["id"],
                    {"error": error},
                )
            )
            return
        connection.send_message(
            websocket_api.event_message(
                msg["id"],
                {"attributes": attributes, "listeners": listeners, "state": state},
            )
        )

    if errors:
        # Reply with a failed result; "message" carries the whole mapping of
        # field key to error message rather than a single string.
        connection.send_message(
            {
                "id": msg["id"],
                "type": websocket_api.const.TYPE_RESULT,
                "success": False,
                "error": {"code": "invalid_user_input", "message": errors},
            }
        )
        return

    # _validate only strips sentinels for sensors, so strip here as well to
    # cover the remaining template types before building the entity.
    _strip_sentinel(msg["user_input"])
    preview_entity = CREATE_PREVIEW_ENTITY[template_type](hass, name, msg["user_input"])
    preview_entity.hass = hass
    preview_entity.registry_entry = entity_registry_entry

    # Acknowledge the command before starting the preview.
    connection.send_result(msg["id"])
    # presumably async_start_preview returns the unsubscribe callback for
    # this subscription — confirm against TemplateEntity.async_start_preview
    connection.subscriptions[msg["id"]] = preview_entity.async_start_preview(
        async_preview_updated
    )