# core/homeassistant/components/assist_pipeline/websocket_api.py
#
# 337 lines
# 11 KiB
# Python

"""Assist pipeline Websocket API."""
import asyncio
# audioop is deprecated as of Python 3.11, so the deprecated-module warning is suppressed
import audioop # pylint: disable=deprecated-module
from collections.abc import AsyncGenerator, Callable
import logging
from typing import Any
import async_timeout
import voluptuous as vol
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DOMAIN
from .error import PipelineNotFound
from .pipeline import (
PipelineData,
PipelineError,
PipelineEvent,
PipelineEventType,
PipelineInput,
PipelineRun,
PipelineStage,
async_get_pipeline,
)
from .vad import VoiceCommandSegmenter
# Fallback timeout for a pipeline run, used when the run request carries no
# "timeout" key (presumably seconds, since it is handed to
# async_timeout.timeout -- confirm).
DEFAULT_TIMEOUT = 30
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register every assist pipeline websocket command with Home Assistant."""
    # Registration order is kept identical to the original explicit calls.
    for command in (
        websocket_run,
        websocket_list_languages,
        websocket_list_runs,
        websocket_get_run,
    ):
        websocket_api.async_register_command(hass, command)
@websocket_api.websocket_command(
    vol.All(
        websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
            {
                vol.Required("type"): "assist_pipeline/run",
                # pylint: disable-next=unnecessary-lambda
                vol.Required("start_stage"): lambda val: PipelineStage(val),
                # pylint: disable-next=unnecessary-lambda
                vol.Required("end_stage"): lambda val: PipelineStage(val),
                vol.Optional("input"): dict,
                vol.Optional("pipeline"): str,
                vol.Optional("conversation_id"): vol.Any(str, None),
                vol.Optional("device_id"): vol.Any(str, None),
                vol.Optional("timeout"): vol.Any(float, int),
            },
        ),
        # The shape of "input" depends on the stage the pipeline starts at
        cv.key_value_schemas(
            "start_stage",
            {
                PipelineStage.STT: vol.Schema(
                    {vol.Required("input"): {vol.Required("sample_rate"): int}},
                    extra=vol.ALLOW_EXTRA,
                ),
                PipelineStage.INTENT: vol.Schema(
                    {vol.Required("input"): {"text": str}},
                    extra=vol.ALLOW_EXTRA,
                ),
                PipelineStage.TTS: vol.Schema(
                    {vol.Required("input"): {"text": str}},
                    extra=vol.ALLOW_EXTRA,
                ),
            },
        ),
    ),
)
@websocket_api.async_response
async def websocket_run(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Run a pipeline.

    Looks up the requested pipeline, builds the stage-specific input, validates
    it, confirms the subscription and then awaits the pipeline run. Pipeline
    events are forwarded to the client as websocket events on ``msg["id"]``.

    When the pipeline starts at the STT stage, a binary websocket handler is
    registered to receive audio. That handler is always unregistered before
    this coroutine finishes, including when validation fails or the run is
    cancelled or times out.

    Errors reported to the client:
    - ``pipeline-not-found`` when the pipeline id is unknown.
    - the code/message of a ``PipelineError`` raised during validation.
    - a ``timeout`` error event when the run exceeds the timeout.
    """
    pipeline_id = msg.get("pipeline")
    try:
        pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
    except PipelineNotFound:
        connection.send_error(
            msg["id"],
            "pipeline-not-found",
            f"Pipeline not found: id={pipeline_id}",
        )
        return

    timeout = msg.get("timeout", DEFAULT_TIMEOUT)
    start_stage = PipelineStage(msg["start_stage"])
    end_stage = PipelineStage(msg["end_stage"])
    handler_id: int | None = None
    unregister_handler: Callable[[], None] | None = None

    # Arguments to PipelineInput
    input_args: dict[str, Any] = {
        "conversation_id": msg.get("conversation_id"),
        "device_id": msg.get("device_id"),
    }

    if start_stage == PipelineStage.STT:
        # Audio pipeline that will receive audio as binary websocket messages
        audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
        incoming_sample_rate = msg["input"]["sample_rate"]

        async def stt_stream() -> AsyncGenerator[bytes, None]:
            """Yield queued audio chunks resampled to 16 kHz until the command ends."""
            state = None
            segmenter = VoiceCommandSegmenter()

            # Yield until we receive an empty chunk
            while chunk := await audio_queue.get():
                # 2-byte samples, 1 channel; state carries the resampler
                # context from one chunk to the next
                chunk, state = audioop.ratecv(
                    chunk, 2, 1, incoming_sample_rate, 16000, state
                )
                if not segmenter.process(chunk):
                    # Voice command is finished
                    break

                yield chunk

        def handle_binary(
            _hass: HomeAssistant,
            _connection: websocket_api.ActiveConnection,
            data: bytes,
        ) -> None:
            # Forward to STT audio stream
            audio_queue.put_nowait(data)

        handler_id, unregister_handler = connection.async_register_binary_handler(
            handle_binary
        )

        # Audio input must be raw PCM at 16Khz with 16-bit mono samples
        input_args["stt_metadata"] = stt.SpeechMetadata(
            language=pipeline.stt_language or pipeline.language,
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        )
        input_args["stt_stream"] = stt_stream()
    elif start_stage == PipelineStage.INTENT:
        # Input to conversation agent
        input_args["intent_input"] = msg["input"]["text"]
    elif start_stage == PipelineStage.TTS:
        # Input to text-to-speech system
        input_args["tts_input"] = msg["input"]["text"]

    # Everything from here on runs under one try/finally so the binary handler
    # registered above cannot outlive this request, whichever way we exit
    # (validation error, unexpected exception, timeout, cancellation, success).
    try:
        input_args["run"] = PipelineRun(
            hass,
            context=connection.context(msg),
            pipeline=pipeline,
            start_stage=start_stage,
            end_stage=end_stage,
            event_callback=lambda event: connection.send_event(msg["id"], event),
            runner_data={
                "stt_binary_handler_id": handler_id,
                "timeout": timeout,
            },
        )

        pipeline_input = PipelineInput(**input_args)

        try:
            await pipeline_input.validate()
        except PipelineError as error:
            # Report more specific error when possible
            connection.send_error(msg["id"], error.code, error.message)
            return

        # Confirm subscription
        connection.send_result(msg["id"])

        run_task = hass.async_create_task(pipeline_input.execute())

        # Cancel pipeline if user unsubscribes
        connection.subscriptions[msg["id"]] = run_task.cancel

        try:
            # Task contains a timeout
            async with async_timeout.timeout(timeout):
                await run_task
        except asyncio.TimeoutError:
            pipeline_input.run.process_event(
                PipelineEvent(
                    PipelineEventType.ERROR,
                    {"code": "timeout", "message": "Timeout running pipeline"},
                )
            )
    finally:
        if unregister_handler is not None:
            # Unregister binary handler
            unregister_handler()
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/pipeline_debug/list",
        vol.Required("pipeline_id"): str,
    }
)
def websocket_list_runs(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Send the id and timestamp of every debug-recorded run of a pipeline.

    A pipeline id with no recorded runs yields an empty list, not an error.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    requested_pipeline = msg["pipeline_id"]

    run_summaries: list[dict[str, Any]] = []
    if requested_pipeline in pipeline_data.pipeline_runs:
        recorded_runs = pipeline_data.pipeline_runs[requested_pipeline]
        run_summaries = [
            {"pipeline_run_id": run_id, "timestamp": recorded_run.timestamp}
            for run_id, recorded_run in recorded_runs.items()
        ]

    connection.send_result(msg["id"], {"pipeline_runs": run_summaries})
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/pipeline_debug/get",
        vol.Required("pipeline_id"): str,
        vol.Required("pipeline_run_id"): str,
    }
)
def websocket_get_run(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Send the recorded debug events of one pipeline run.

    Replies with a not-found error when either the pipeline id or the run id
    has no debug data.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    wanted_pipeline = msg["pipeline_id"]
    wanted_run = msg["pipeline_run_id"]

    def reply_not_found(description: str) -> None:
        # Both failure paths share the same error code and differ only in text
        connection.send_error(
            msg["id"],
            websocket_api.const.ERR_NOT_FOUND,
            description,
        )

    if wanted_pipeline not in pipeline_data.pipeline_runs:
        reply_not_found(f"pipeline_id {wanted_pipeline} not found")
        return

    runs_for_pipeline = pipeline_data.pipeline_runs[wanted_pipeline]
    if wanted_run in runs_for_pipeline:
        connection.send_result(
            msg["id"],
            {"events": runs_for_pipeline[wanted_run].events},
        )
        return

    reply_not_found(f"pipeline_run_id {wanted_run} not found")
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/language/list",
    }
)
@websocket_api.async_response
async def websocket_list_languages(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List languages which are supported by a complete pipeline.

    The result is the intersection of the base languages offered by the
    conversation, stt and tts engines. An engine kind that reports no
    languages (or, for conversation, MATCH_ALL) does not narrow the result;
    if none of them narrows it, the "languages" value is None.
    """
    conv_language_tags = await conversation.async_get_conversation_languages(hass)
    stt_language_tags = stt.async_get_speech_to_text_languages(hass)
    tts_language_tags = tts.async_get_text_to_speech_languages(hass)

    def base_languages(language_tags: Any) -> set[str]:
        # Reduce each tag to the language part of its parsed dialect
        return {
            language_util.Dialect.parse(language_tag).language
            for language_tag in language_tags
        }

    pipeline_languages: set[str] | None = None

    if conv_language_tags and conv_language_tags != MATCH_ALL:
        pipeline_languages = base_languages(conv_language_tags)

    # stt and tts follow the same rule: skip when empty, otherwise intersect
    for engine_language_tags in (stt_language_tags, tts_language_tags):
        if not engine_language_tags:
            continue
        engine_languages = base_languages(engine_language_tags)
        if pipeline_languages is None:
            pipeline_languages = engine_languages
        else:
            pipeline_languages &= engine_languages

    connection.send_result(
        msg["id"],
        {"languages": pipeline_languages},
    )