"""Classes for voice assistant pipelines."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
import logging
from typing import Any

import voluptuous as vol

from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
from homeassistant.components.tts.media_source import (
    generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.collection import (
    CollectionError,
    ItemNotFound,
    SerializedStorageCollection,
    StorageCollection,
    StorageCollectionWebsocket,
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict

from .const import DOMAIN
from .error import (
    IntentRecognitionError,
    PipelineError,
    SpeechToTextError,
    TextToSpeechError,
)

_LOGGER = logging.getLogger(__name__)

STORAGE_KEY = f"{DOMAIN}.pipelines"
STORAGE_VERSION = 1

STORAGE_FIELDS = {
    vol.Optional("conversation_engine", default=None): vol.Any(str, None),
    vol.Required("language"): str,
    vol.Required("name"): str,
    vol.Optional("stt_engine", default=None): vol.Any(str, None),
    vol.Optional("tts_engine", default=None): vol.Any(str, None),
}
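
# Number of debug runs to keep per pipeline (see PipelineRunDebug)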
STORED_PIPELINE_RUNS = 10

SAVE_DELAY = 10


async def async_get_pipeline(
    hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
    """Get a pipeline by id or create one for a language."""
    pipeline_data: PipelineData = hass.data[DOMAIN]

    if pipeline_id is None:
        # A pipeline was not specified, use the preferred one
        pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()

    if pipeline_id is None:
        # There's no preferred pipeline, construct a pipeline for the
        # configured language
        return await pipeline_data.pipeline_store.async_create_item(
            {
                "name": hass.config.language,
                "language": hass.config.language,
                "stt_engine": None,  # first engine
                "conversation_engine": None,  # first agent
                "tts_engine": None,  # first engine
            }
        )

    return pipeline_data.pipeline_store.data.get(pipeline_id)


class PipelineEventType(StrEnum):
    """Event types emitted during a pipeline run."""

    RUN_START = "run-start"
    RUN_END = "run-end"
    STT_START = "stt-start"
    STT_END = "stt-end"
    INTENT_START = "intent-start"
    INTENT_END = "intent-end"
    TTS_START = "tts-start"
    TTS_END = "tts-end"
    ERROR = "error"


@dataclass(frozen=True)
class PipelineEvent:
    """Events emitted during a pipeline run."""

    type: PipelineEventType
    data: dict[str, Any] | None = None
    timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())


PipelineEventCallback = Callable[[PipelineEvent], None]


@dataclass(frozen=True)
class Pipeline:
    """A voice assistant pipeline."""

    conversation_engine: str | None
    language: str
    name: str
    stt_engine: str | None
    tts_engine: str | None

    id: str = field(default_factory=ulid_util.ulid)

    def to_json(self) -> dict[str, Any]:
        """Return a JSON serializable representation for storage."""
        return {
            "conversation_engine": self.conversation_engine,
            "id": self.id,
            "language": self.language,
            "name": self.name,
            "stt_engine": self.stt_engine,
            "tts_engine": self.tts_engine,
        }


class PipelineStage(StrEnum):
    """Stages of a pipeline."""

    STT = "stt"
    INTENT = "intent"
    TTS = "tts"


PIPELINE_STAGE_ORDER = [
    PipelineStage.STT,
    PipelineStage.INTENT,
    PipelineStage.TTS,
]


class PipelineRunValidationError(Exception):
    """Error when a pipeline run is not valid."""


class InvalidPipelineStagesError(PipelineRunValidationError):
    """Error when given an invalid combination of start/end stages."""

    def __init__(
        self,
        start_stage: PipelineStage,
        end_stage: PipelineStage,
    ) -> None:
        """Set error message."""
        super().__init__(
            f"Invalid stage combination: start={start_stage}, end={end_stage}"
        )


@dataclass
class PipelineRun:
    """Running context for a pipeline."""

    hass: HomeAssistant
    context: Context
    pipeline: Pipeline
    start_stage: PipelineStage
    end_stage: PipelineStage
    event_callback: PipelineEventCallback
    language: str = None  # type: ignore[assignment]
    runner_data: Any | None = None
    stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
    intent_agent: str | None = None
    tts_engine: str | None = None
    tts_options: dict | None = None

    id: str = field(default_factory=ulid_util.ulid)

    def __post_init__(self) -> None:
        """Set language for pipeline."""
        self.language = self.pipeline.language or self.hass.config.language

        # stt -> intent -> tts
        if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
            self.start_stage
        ):
            raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
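
        # Register this run in a bounded per-pipeline history of debug runs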
        pipeline_data: PipelineData = self.hass.data[DOMAIN]
        if self.pipeline.id not in pipeline_data.pipeline_runs:
            pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
                size_limit=STORED_PIPELINE_RUNS
            )
        pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()

    @callback
    def process_event(self, event: PipelineEvent) -> None:
        """Log an event and call listener."""
        self.event_callback(event)
        pipeline_data: PipelineData = self.hass.data[DOMAIN]
        if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
            # This run has been evicted from the logged pipeline runs already
            return
        pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event)

    def start(self) -> None:
        """Emit run start event."""
        data = {
            "pipeline": self.pipeline.name,
            "language": self.language,
        }
        if self.runner_data is not None:
            data["runner_data"] = self.runner_data

        self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))

    def end(self) -> None:
        """Emit run end event."""
        self.process_event(
            PipelineEvent(
                PipelineEventType.RUN_END,
            )
        )

    async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
        """Prepare speech to text."""
        stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None

        if self.pipeline.stt_engine is not None:
            # Try entity first
            stt_provider = stt.async_get_speech_to_text_entity(
                self.hass,
                self.pipeline.stt_engine,
            )

        if stt_provider is None:
            # Try legacy provider second
            stt_provider = stt.async_get_provider(
                self.hass,
                self.pipeline.stt_engine,
            )

        if stt_provider is None:
            engine = self.pipeline.stt_engine or "default"
            raise SpeechToTextError(
                code="stt-provider-missing",
                message=f"No speech to text provider for: {engine}",
            )

        if not stt_provider.check_metadata(metadata):
            raise SpeechToTextError(
                code="stt-provider-unsupported-metadata",
                message=(
                    f"Provider {stt_provider.name} does not support input speech "
                    "to text metadata"
                ),
            )

        self.stt_provider = stt_provider

    async def speech_to_text(
        self,
        metadata: stt.SpeechMetadata,
        stream: AsyncIterable[bytes],
    ) -> str:
        """Run speech to text portion of pipeline. Returns the spoken text."""
        if self.stt_provider is None:
            raise RuntimeError("Speech to text was not prepared")

        engine = self.stt_provider.name

        self.process_event(
            PipelineEvent(
                PipelineEventType.STT_START,
                {
                    "engine": engine,
                    "metadata": asdict(metadata),
                },
            )
        )

        try:
            # Transcribe audio stream
            result = await self.stt_provider.async_process_audio_stream(
                metadata, stream
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during speech to text")
            raise SpeechToTextError(
                code="stt-stream-failed",
                message="Unexpected error during speech to text",
            ) from src_error

        _LOGGER.debug("speech-to-text result %s", result)

        if result.result != stt.SpeechResultState.SUCCESS:
            raise SpeechToTextError(
                code="stt-stream-failed",
                message="Speech to text failed",
            )

        if not result.text:
            raise SpeechToTextError(
                code="stt-no-text-recognized", message="No text recognized"
            )

        self.process_event(
            PipelineEvent(
                PipelineEventType.STT_END,
                {
                    "stt_output": {
                        "text": result.text,
                    }
                },
            )
        )

        return result.text

    async def prepare_recognize_intent(self) -> None:
        """Prepare recognizing an intent."""
        agent_info = conversation.async_get_agent_info(
            self.hass,
            # If no conversation engine is set, use the Home Assistant agent
            # (the conversation integration default is currently the last one set)
            self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
        )

        if agent_info is None:
            engine = self.pipeline.conversation_engine or "default"
            raise IntentRecognitionError(
                code="intent-not-supported",
                message=f"Intent recognition engine {engine} is not found",
            )

        self.intent_agent = agent_info.id

    async def recognize_intent(
        self, intent_input: str, conversation_id: str | None
    ) -> str:
        """Run intent recognition portion of pipeline. Returns text to speak."""
        if self.intent_agent is None:
            raise RuntimeError("Recognize intent was not prepared")

        self.process_event(
            PipelineEvent(
                PipelineEventType.INTENT_START,
                {
                    "engine": self.intent_agent,
                    "intent_input": intent_input,
                },
            )
        )

        try:
            conversation_result = await conversation.async_converse(
                hass=self.hass,
                text=intent_input,
                conversation_id=conversation_id,
                context=self.context,
                language=self.language,
                agent_id=self.intent_agent,
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during intent recognition")
            raise IntentRecognitionError(
                code="intent-failed",
                message="Unexpected error during intent recognition",
            ) from src_error

        _LOGGER.debug("conversation result %s", conversation_result)

        self.process_event(
            PipelineEvent(
                PipelineEventType.INTENT_END,
                {"intent_output": conversation_result.as_dict()},
            )
        )

        speech: str = conversation_result.response.speech.get("plain", {}).get(
            "speech", ""
        )

        return speech

    async def prepare_text_to_speech(self) -> None:
        """Prepare text to speech."""
        engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)

        if engine is None:
            engine = self.pipeline.tts_engine or "default"
            raise TextToSpeechError(
                code="tts-not-supported",
                message=f"Text to speech engine '{engine}' not found",
            )

        if not await tts.async_support_options(
            self.hass,
            engine,
            self.language,
            self.tts_options,
        ):
            raise TextToSpeechError(
                code="tts-not-supported",
                message=(
                    f"Text to speech engine {engine} "
                    f"does not support language {self.language} or options {self.tts_options}"
                ),
            )

        self.tts_engine = engine

    async def text_to_speech(self, tts_input: str) -> str:
        """Run text to speech portion of pipeline. Returns URL of TTS audio."""
        if self.tts_engine is None:
            raise RuntimeError("Text to speech was not prepared")

        self.process_event(
            PipelineEvent(
                PipelineEventType.TTS_START,
                {
                    "engine": self.tts_engine,
                    "tts_input": tts_input,
                },
            )
        )

        try:
            # Synthesize audio and get URL
            tts_media_id = tts_generate_media_source_id(
                self.hass,
                tts_input,
                engine=self.tts_engine,
                language=self.language,
                options=self.tts_options,
            )
            tts_media = await media_source.async_resolve_media(
                self.hass,
                tts_media_id,
                None,
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during text to speech")
            raise TextToSpeechError(
                code="tts-failed",
                message="Unexpected error during text to speech",
            ) from src_error

        _LOGGER.debug("TTS result %s", tts_media)

        self.process_event(
            PipelineEvent(
                PipelineEventType.TTS_END,
                {
                    "tts_output": {
                        "media_id": tts_media_id,
                        **asdict(tts_media),
                    }
                },
            )
        )

        return tts_media.url


@dataclass
class PipelineInput:
    """Input to a pipeline run."""

    run: PipelineRun

    stt_metadata: stt.SpeechMetadata | None = None
    """Metadata of stt input audio. Required when start_stage = stt."""

    stt_stream: AsyncIterable[bytes] | None = None
    """Input audio for stt. Required when start_stage = stt."""

    intent_input: str | None = None
    """Input for conversation agent. Required when start_stage = intent."""

    tts_input: str | None = None
    """Input for text to speech. Required when start_stage = tts."""

    conversation_id: str | None = None

    async def execute(self) -> None:
        """Run pipeline."""
        self.run.start()
        current_stage = self.run.start_stage

        try:
            # Speech to text
            intent_input = self.intent_input
            if current_stage == PipelineStage.STT:
                assert self.stt_metadata is not None
                assert self.stt_stream is not None
                intent_input = await self.run.speech_to_text(
                    self.stt_metadata,
                    self.stt_stream,
                )
                current_stage = PipelineStage.INTENT
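
            # Intent recognition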
            if self.run.end_stage != PipelineStage.STT:
                tts_input = self.tts_input

                if current_stage == PipelineStage.INTENT:
                    assert intent_input is not None
                    tts_input = await self.run.recognize_intent(
                        intent_input, self.conversation_id
                    )
                    current_stage = PipelineStage.TTS
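
                # Text to speech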
                if self.run.end_stage != PipelineStage.INTENT:
                    if current_stage == PipelineStage.TTS:
                        assert tts_input is not None
                        await self.run.text_to_speech(tts_input)

        except PipelineError as err:
            self.run.process_event(
                PipelineEvent(
                    PipelineEventType.ERROR,
                    {"code": err.code, "message": err.message},
                )
            )
            return

        self.run.end()

    async def validate(self) -> None:
        """Validate pipeline input against start stage."""
        if self.run.start_stage == PipelineStage.STT:
            if self.stt_metadata is None:
                raise PipelineRunValidationError(
                    "stt_metadata is required for speech to text"
                )

            if self.stt_stream is None:
                raise PipelineRunValidationError(
                    "stt_stream is required for speech to text"
                )
        elif self.run.start_stage == PipelineStage.INTENT:
            if self.intent_input is None:
                raise PipelineRunValidationError(
                    "intent_input is required for intent recognition"
                )
        elif self.run.start_stage == PipelineStage.TTS:
            if self.tts_input is None:
                raise PipelineRunValidationError(
                    "tts_input is required for text to speech"
                )
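
        # Prepare each stage that will run; the preparations are awaited concurrently below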
        start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)

        prepare_tasks = []

        if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
            # self.stt_metadata can't be None or we'd raise above
            prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))  # type: ignore[arg-type]

        if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
            prepare_tasks.append(self.run.prepare_recognize_intent())

        if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS):
            prepare_tasks.append(self.run.prepare_text_to_speech())

        if prepare_tasks:
            await asyncio.gather(*prepare_tasks)


class PipelinePreferred(CollectionError):
    """Raised when attempting to delete the preferred pipeline."""

    def __init__(self, item_id: str) -> None:
        """Initialize pipeline preferred error."""
        super().__init__(f"Item {item_id} preferred.")
        self.item_id = item_id


class SerializedPipelineStorageCollection(SerializedStorageCollection):
    """Serialized pipeline storage collection."""

    preferred_item: str | None


class PipelineStorageCollection(
    StorageCollection[Pipeline, SerializedPipelineStorageCollection]
):
    """Pipeline storage collection."""

    CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

    _preferred_item: str | None = None

    async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
        """Load the data."""
        if not (data := await super()._async_load_data()):
            return data
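
        # Restore the preferred pipeline from storage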
        self._preferred_item = data["preferred_item"]

        return data

    async def _process_create_data(self, data: dict) -> dict:
        """Validate the config is valid."""
        # We don't need to validate, the WS API has already validated
        return data

    @callback
    def _get_suggested_id(self, info: dict) -> str:
        """Suggest an ID based on the config."""
        return ulid_util.ulid()

    async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
        """Return a new updated item."""
        return Pipeline(id=item.id, **update_data)

    def _create_item(self, item_id: str, data: dict) -> Pipeline:
        """Create an item from validated config."""
        if self._preferred_item is None:
            self._preferred_item = item_id
        return Pipeline(id=item_id, **data)

    def _deserialize_item(self, data: dict) -> Pipeline:
        """Create an item from its serialized representation."""
        return Pipeline(**data)

    def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
        """Return the serialized representation of an item for storing."""
        return item.to_json()

    async def async_delete_item(self, item_id: str) -> None:
        """Delete item."""
        if self._preferred_item == item_id:
            raise PipelinePreferred(item_id)
        await super().async_delete_item(item_id)

    @callback
    def async_get_preferred_item(self) -> str | None:
        """Get the id of the preferred item."""
        return self._preferred_item

    @callback
    def async_set_preferred_item(self, item_id: str) -> None:
        """Set the preferred pipeline."""
        if item_id not in self.data:
            raise ItemNotFound(item_id)
        self._preferred_item = item_id
        self._async_schedule_save()

    @callback
    def _data_to_save(self) -> SerializedPipelineStorageCollection:
        """Return JSON-compatible data for storing to file."""
        base_data = super()._base_data_to_save()
        return {
            "items": base_data["items"],
            "preferred_item": self._preferred_item,
        }


class PipelineStorageCollectionWebsocket(
    StorageCollectionWebsocket[PipelineStorageCollection]
):
    """Class to expose storage collection management over websocket."""

    @callback
    def async_setup(
        self,
        hass: HomeAssistant,
        *,
        create_list: bool = True,
        create_create: bool = True,
    ) -> None:
        """Set up the websocket commands."""
        super().async_setup(hass, create_list=create_list, create_create=create_create)

        websocket_api.async_register_command(
            hass,
            f"{self.api_prefix}/get",
            self.ws_get_item,
            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                {
                    vol.Required("type"): f"{self.api_prefix}/get",
                    vol.Optional(self.item_id_key): str,
                }
            ),
        )

        websocket_api.async_register_command(
            hass,
            f"{self.api_prefix}/set_preferred",
            websocket_api.require_admin(
                websocket_api.async_response(self.ws_set_preferred_item)
            ),
            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                {
                    vol.Required("type"): f"{self.api_prefix}/set_preferred",
                    vol.Required(self.item_id_key): str,
                }
            ),
        )

    async def ws_delete_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Delete an item."""
        try:
            await super().ws_delete_item(hass, connection, msg)
        except PipelinePreferred as exc:
            connection.send_error(
                msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
            )

    @callback
    def ws_get_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Get an item."""
        item_id = msg.get(self.item_id_key)
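        # Fall back to the preferred pipeline when no id was given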
        if item_id is None:
            item_id = self.storage_collection.async_get_preferred_item()

        if item_id not in self.storage_collection.data:
            connection.send_error(
                msg["id"],
                websocket_api.const.ERR_NOT_FOUND,
                f"Unable to find {self.item_id_key} {item_id}",
            )
            return

        connection.send_result(msg["id"], self.storage_collection.data[item_id])

    @callback
    def ws_list_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """List items."""
        connection.send_result(
            msg["id"],
            {
                "pipelines": self.storage_collection.async_items(),
                "preferred_pipeline": self.storage_collection.async_get_preferred_item(),
            },
        )

    async def ws_set_preferred_item(
        self,
        hass: HomeAssistant,
        connection: websocket_api.ActiveConnection,
        msg: dict[str, Any],
    ) -> None:
        """Set the preferred item."""
        try:
            self.storage_collection.async_set_preferred_item(msg[self.item_id_key])
        except ItemNotFound:
            connection.send_error(
                msg["id"], websocket_api.const.ERR_NOT_FOUND, "unknown item"
            )
            return
        connection.send_result(msg["id"])


@dataclass
class PipelineData:
    """Store and debug data stored in hass.data."""

    pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
    pipeline_store: PipelineStorageCollection


@dataclass
class PipelineRunDebug:
    """Debug data for a pipeline run."""

    events: list[PipelineEvent] = field(default_factory=list, init=False)
    timestamp: str = field(
        default_factory=lambda: dt_util.utcnow().isoformat(),
        init=False,
    )


async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
    """Set up the pipeline storage collection."""
    pipeline_store = PipelineStorageCollection(
        Store(hass, STORAGE_VERSION, STORAGE_KEY)
    )
    await pipeline_store.async_load()
    PipelineStorageCollectionWebsocket(
        pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
    ).async_setup(hass)
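    # Share the store and per-pipeline debug run history via hass.data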
    hass.data[DOMAIN] = PipelineData({}, pipeline_store)