core/homeassistant/components/assist_satellite/websocket_api.py

206 lines
6.8 KiB
Python
Raw Normal View History

"""Assist satellite Websocket API."""
import asyncio
from dataclasses import asdict, replace
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util
from .connection_test import CONNECTION_TEST_URL_BASE
from .const import (
CONNECTION_TEST_DATA,
DATA_COMPONENT,
DOMAIN,
AssistSatelliteEntityFeature,
)
from .entity import AssistSatelliteEntity
# Seconds that websocket_test_connection waits (via asyncio.timeout) for the
# connection-test event to be set after starting the test announcement; if it
# elapses, the client receives {"status": "timeout"}.
CONNECTION_TEST_TIMEOUT = 30
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register every assist_satellite websocket command with Home Assistant.

    Commands are registered in a fixed order: intercept wake word, get
    configuration, set wake words, test connection.
    """
    command_handlers = (
        websocket_intercept_wake_word,
        websocket_get_configuration,
        websocket_set_wake_words,
        websocket_test_connection,
    )
    for handler in command_handlers:
        websocket_api.async_register_command(hass, handler)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/intercept_wake_word",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_intercept_wake_word(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Capture the next wake word detected by a satellite (admin only).

    The request is acknowledged with a result message. The intercepted phrase
    is pushed as an event message ``{"wake_word_phrase": ...}`` on the same
    message id. If the satellite raises HomeAssistantError, a
    "home_assistant_error" error message is sent instead. Replies with
    ERR_NOT_FOUND when the entity does not exist.
    """
    component = hass.data[DATA_COMPONENT]
    satellite = component.get_entity(msg["entity_id"])
    if satellite is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    request_id = msg["id"]

    async def _forward_intercepted_wake_word() -> None:
        """Await the satellite's next wake word and relay it to the client."""
        try:
            phrase = await satellite.async_intercept_wake_word()
            payload = {"wake_word_phrase": phrase}
            connection.send_message(
                websocket_api.event_message(request_id, payload)
            )
        except HomeAssistantError as err:
            connection.send_error(request_id, "home_assistant_error", str(err))

    intercept_task = hass.async_create_task(
        _forward_intercepted_wake_word(), "intercept_wake_word"
    )
    # The task's cancel() is stored as this subscription's cleanup callback,
    # so tearing down the subscription stops the pending intercept.
    connection.subscriptions[request_id] = intercept_task.cancel
    # Acknowledge the subscription request.
    connection.send_message(websocket_api.result_message(request_id))
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/get_configuration",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
def websocket_get_configuration(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Send the satellite's configuration together with related entity ids.

    The configuration dataclass is serialized with ``asdict`` and extended
    with ``pipeline_entity_id`` and ``vad_entity_id`` (taken from the entity's
    ``pipeline_entity_id`` and ``vad_sensitivity_entity_id``). Replies with
    ERR_NOT_FOUND when the entity does not exist.
    """
    entity_id = msg["entity_id"]
    satellite = hass.data[DATA_COMPONENT].get_entity(entity_id)
    if satellite is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    result = asdict(satellite.async_get_configuration()) | {
        "pipeline_entity_id": satellite.pipeline_entity_id,
        "vad_entity_id": satellite.vad_sensitivity_entity_id,
    }
    connection.send_result(msg["id"], result)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/set_wake_words",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
        vol.Required("wake_word_ids"): [str],
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_set_wake_words(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Replace the satellite's active wake words with the requested ids (admin only).

    Replies with ERR_NOT_FOUND when the entity does not exist. Replies with
    ERR_NOT_SUPPORTED when more ids are requested than
    ``max_active_wake_words``, or when an id is not among
    ``available_wake_words``; the first offending id is reported.
    """
    satellite = hass.data[DATA_COMPONENT].get_entity(msg["entity_id"])
    if satellite is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    config = satellite.async_get_configuration()
    requested_ids = msg["wake_word_ids"]
    limit = config.max_active_wake_words

    # Enforce the satellite's cap on simultaneously active wake words.
    if len(requested_ids) > limit:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            f"Maximum number of active wake words is {limit}",
        )
        return

    # Every requested id must be offered by the satellite; report the first miss.
    supported_ids = {wake_word.id for wake_word in config.available_wake_words}
    unsupported_id = next(
        (ww_id for ww_id in requested_ids if ww_id not in supported_ids), None
    )
    if unsupported_id is not None:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            f"Wake word id is not supported: {unsupported_id}",
        )
        return

    updated_config = replace(config, active_wake_words=requested_ids)
    await satellite.async_set_configuration(updated_config)
    connection.send_result(msg["id"])
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/test_connection",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
@websocket_api.async_response
async def websocket_test_connection(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Test the connection between the device and Home Assistant.

    Send an announcement to the device with a special media id
    (``CONNECTION_TEST_URL_BASE/<random hex id>``), then wait up to
    CONNECTION_TEST_TIMEOUT seconds for the matching event registered in
    ``hass.data[CONNECTION_TEST_DATA]`` to be set.

    Replies with ``{"status": "success"}`` when the event is set in time and
    ``{"status": "timeout"}`` otherwise. Replies with ERR_NOT_FOUND when the
    entity does not exist and ERR_NOT_SUPPORTED when it lacks the ANNOUNCE
    feature.
    """
    # Use the same typed component key as the other handlers in this module.
    component: EntityComponent[AssistSatelliteEntity] = hass.data[DATA_COMPONENT]
    satellite = component.get_entity(msg["entity_id"])
    if satellite is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    # supported_features may be None; treat that as "no features".
    if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_SUPPORTED,
            "Entity does not support announce",
        )
        return

    # Register the event before announcing so that a fast response can't be
    # missed, then announce in the background and wait for the event.
    connection_test_data = hass.data[CONNECTION_TEST_DATA]
    connection_id = uuid_util.random_uuid_hex()
    connection_test_event = asyncio.Event()
    connection_test_data[connection_id] = connection_test_event

    hass.async_create_background_task(
        satellite.async_internal_announce(
            media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
        ),
        f"assist_satellite_connection_test_{msg['entity_id']}",
    )

    try:
        async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
            await connection_test_event.wait()
    except TimeoutError:
        connection.send_result(msg["id"], {"status": "timeout"})
    else:
        connection.send_result(msg["id"], {"status": "success"})
    finally:
        # Always drop the registration so stale ids don't accumulate.
        connection_test_data.pop(connection_id, None)