Add TTS to pipelines (#90004)
* Add text to speech and stages to pipeline * Default to "cloud" TTS when engine is None * Refactor pipeline request to split text/audio * Refactor with PipelineRun * Generate pipeline from language * Clean up * Restore TTS code * Add audio pipeline test * Clean TTS cache in test * Clean up tests and pipeline base class * Stop pylint and pytest magics from fighting * Include mock_get_cache_files pull/90081/head
parent
ddcaa9d372
commit
0e7ffff869
|
@ -4,20 +4,13 @@ from __future__ import annotations
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DEFAULT_PIPELINE, DOMAIN
|
from .const import DOMAIN
|
||||||
from .pipeline import Pipeline
|
|
||||||
from .websocket_api import async_register_websocket_api
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up Voice Assistant integration."""
|
"""Set up Voice Assistant integration."""
|
||||||
hass.data[DOMAIN] = {
|
hass.data[DOMAIN] = {}
|
||||||
DEFAULT_PIPELINE: Pipeline(
|
|
||||||
name=DEFAULT_PIPELINE,
|
|
||||||
language=None,
|
|
||||||
conversation_engine=None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Classes for voice assistant pipelines."""
|
"""Classes for voice assistant pipelines."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
@ -8,20 +9,16 @@ from typing import Any
|
||||||
|
|
||||||
from homeassistant.backports.enum import StrEnum
|
from homeassistant.backports.enum import StrEnum
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.components.media_source import async_resolve_media
|
||||||
|
from homeassistant.components.tts.media_source import (
|
||||||
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 30 # seconds
|
DEFAULT_TIMEOUT = 30 # seconds
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipelineRequest:
|
|
||||||
"""Request to start a pipeline run."""
|
|
||||||
|
|
||||||
intent_input: str
|
|
||||||
conversation_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineEventType(StrEnum):
|
class PipelineEventType(StrEnum):
|
||||||
"""Event types emitted during a pipeline run."""
|
"""Event types emitted during a pipeline run."""
|
||||||
|
|
||||||
|
@ -29,6 +26,8 @@ class PipelineEventType(StrEnum):
|
||||||
RUN_FINISH = "run-finish"
|
RUN_FINISH = "run-finish"
|
||||||
INTENT_START = "intent-start"
|
INTENT_START = "intent-start"
|
||||||
INTENT_FINISH = "intent-finish"
|
INTENT_FINISH = "intent-finish"
|
||||||
|
TTS_START = "tts-start"
|
||||||
|
TTS_FINISH = "tts-finish"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,69 +55,161 @@ class Pipeline:
|
||||||
name: str
|
name: str
|
||||||
language: str | None
|
language: str | None
|
||||||
conversation_engine: str | None
|
conversation_engine: str | None
|
||||||
|
tts_engine: str | None
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
context: Context,
|
|
||||||
request: PipelineRequest,
|
|
||||||
event_callback: Callable[[PipelineEvent], None],
|
|
||||||
timeout: int | float | None = DEFAULT_TIMEOUT,
|
|
||||||
) -> None:
|
|
||||||
"""Run a pipeline with an optional timeout."""
|
|
||||||
await asyncio.wait_for(
|
|
||||||
self._run(hass, context, request, event_callback), timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run(
|
@dataclass
|
||||||
self,
|
class PipelineRun:
|
||||||
hass: HomeAssistant,
|
"""Running context for a pipeline."""
|
||||||
context: Context,
|
|
||||||
request: PipelineRequest,
|
hass: HomeAssistant
|
||||||
event_callback: Callable[[PipelineEvent], None],
|
context: Context
|
||||||
) -> None:
|
pipeline: Pipeline
|
||||||
"""Run a pipeline."""
|
event_callback: Callable[[PipelineEvent], None]
|
||||||
language = self.language or hass.config.language
|
language: str = None # type: ignore[assignment]
|
||||||
event_callback(
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Set language for pipeline."""
|
||||||
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Emit run start event."""
|
||||||
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.RUN_START,
|
PipelineEventType.RUN_START,
|
||||||
{
|
{
|
||||||
"pipeline": self.name,
|
"pipeline": self.pipeline.name,
|
||||||
"language": language,
|
"language": self.language,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
intent_input = request.intent_input
|
def finish(self):
|
||||||
|
"""Emit run finish event."""
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.RUN_FINISH,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
event_callback(
|
async def recognize_intent(
|
||||||
|
self, intent_input: str, conversation_id: str | None
|
||||||
|
) -> conversation.ConversationResult:
|
||||||
|
"""Run intent recognition portion of pipeline."""
|
||||||
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_START,
|
PipelineEventType.INTENT_START,
|
||||||
{
|
{
|
||||||
"engine": self.conversation_engine or "default",
|
"engine": self.pipeline.conversation_engine or "default",
|
||||||
"intent_input": intent_input,
|
"intent_input": intent_input,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_result = await conversation.async_converse(
|
conversation_result = await conversation.async_converse(
|
||||||
hass=hass,
|
hass=self.hass,
|
||||||
text=intent_input,
|
text=intent_input,
|
||||||
conversation_id=request.conversation_id,
|
conversation_id=conversation_id,
|
||||||
context=context,
|
context=self.context,
|
||||||
language=language,
|
language=self.language,
|
||||||
agent_id=self.conversation_engine,
|
agent_id=self.pipeline.conversation_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_FINISH,
|
PipelineEventType.INTENT_FINISH,
|
||||||
{"intent_output": conversation_result.as_dict()},
|
{"intent_output": conversation_result.as_dict()},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
event_callback(
|
return conversation_result
|
||||||
|
|
||||||
|
async def text_to_speech(self, tts_input: str) -> str:
|
||||||
|
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||||
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.RUN_FINISH,
|
PipelineEventType.TTS_START,
|
||||||
|
{
|
||||||
|
"engine": self.pipeline.tts_engine or "default",
|
||||||
|
"tts_input": tts_input,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tts_media = await async_resolve_media(
|
||||||
|
self.hass,
|
||||||
|
tts_generate_media_source_id(
|
||||||
|
self.hass,
|
||||||
|
tts_input,
|
||||||
|
engine=self.pipeline.tts_engine,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tts_url = tts_media.url
|
||||||
|
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.TTS_FINISH,
|
||||||
|
{"tts_output": tts_url},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return tts_url
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineRequest(ABC):
|
||||||
|
"""Request for a pipeline run."""
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
|
||||||
|
):
|
||||||
|
"""Run pipeline with optional timeout."""
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._execute(run),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _execute(self, run: PipelineRun):
|
||||||
|
"""Run pipeline with request info and context."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextPipelineRequest(PipelineRequest):
|
||||||
|
"""Request to run the text portion only of a pipeline."""
|
||||||
|
|
||||||
|
intent_input: str
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
run: PipelineRun,
|
||||||
|
):
|
||||||
|
run.start()
|
||||||
|
await run.recognize_intent(self.intent_input, self.conversation_id)
|
||||||
|
run.finish()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AudioPipelineRequest(PipelineRequest):
|
||||||
|
"""Request to run the full pipeline from audio input (stt) to audio output (tts)."""
|
||||||
|
|
||||||
|
intent_input: str # this will be changed to stt audio
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
async def _execute(self, run: PipelineRun):
|
||||||
|
run.start()
|
||||||
|
|
||||||
|
# stt will go here
|
||||||
|
|
||||||
|
conversation_result = await run.recognize_intent(
|
||||||
|
self.intent_input, self.conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
tts_input = conversation_result.response.speech.get("plain", {}).get(
|
||||||
|
"speech", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
await run.text_to_speech(tts_input)
|
||||||
|
|
||||||
|
run.finish()
|
||||||
|
|
|
@ -7,7 +7,7 @@ from homeassistant.components import websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .pipeline import DEFAULT_TIMEOUT, PipelineRequest
|
from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -19,7 +19,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "voice_assistant/run",
|
vol.Required("type"): "voice_assistant/run",
|
||||||
vol.Optional("pipeline", default="default"): str,
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("pipeline"): str,
|
||||||
vol.Required("intent_input"): str,
|
vol.Required("intent_input"): str,
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
vol.Optional("timeout"): vol.Any(float, int),
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
|
@ -32,28 +33,43 @@ async def websocket_run(
|
||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a pipeline."""
|
"""Run a pipeline."""
|
||||||
pipeline_id = msg["pipeline"]
|
pipeline_id = msg.get("pipeline")
|
||||||
|
if pipeline_id is not None:
|
||||||
pipeline = hass.data[DOMAIN].get(pipeline_id)
|
pipeline = hass.data[DOMAIN].get(pipeline_id)
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"], "pipeline_not_found", f"Pipeline not found: {pipeline_id}"
|
msg["id"],
|
||||||
|
"pipeline_not_found",
|
||||||
|
f"Pipeline not found: {pipeline_id}",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Construct a pipeline for the required/configured language
|
||||||
|
language = msg.get("language", hass.config.language)
|
||||||
|
pipeline = Pipeline(
|
||||||
|
name=language,
|
||||||
|
language=language,
|
||||||
|
conversation_engine=None,
|
||||||
|
tts_engine=None,
|
||||||
|
)
|
||||||
|
|
||||||
# Run pipeline with a timeout.
|
# Run pipeline with a timeout.
|
||||||
# Events are sent over the websocket connection.
|
# Events are sent over the websocket connection.
|
||||||
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||||
run_task = hass.async_create_task(
|
run_task = hass.async_create_task(
|
||||||
pipeline.run(
|
TextPipelineRequest(
|
||||||
hass,
|
|
||||||
connection.context(msg),
|
|
||||||
request=PipelineRequest(
|
|
||||||
intent_input=msg["intent_input"],
|
intent_input=msg["intent_input"],
|
||||||
conversation_id=msg.get("conversation_id"),
|
conversation_id=msg.get("conversation_id"),
|
||||||
),
|
).execute(
|
||||||
|
PipelineRun(
|
||||||
|
hass,
|
||||||
|
connection.context(msg),
|
||||||
|
pipeline,
|
||||||
event_callback=lambda event: connection.send_event(
|
event_callback=lambda event: connection.send_event(
|
||||||
msg["id"], event.as_dict()
|
msg["id"], event.as_dict()
|
||||||
),
|
),
|
||||||
|
),
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Pipeline tests for Voice Assistant integration."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.voice_assistant.pipeline import (
|
||||||
|
AudioPipelineRequest,
|
||||||
|
Pipeline,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineRun,
|
||||||
|
)
|
||||||
|
from homeassistant.core import Context
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||||
|
mock_get_cache_files,
|
||||||
|
mock_init_cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def init_components(hass):
|
||||||
|
"""Initialize relevant components with empty configs."""
|
||||||
|
assert await async_setup_component(hass, "voice_assistant", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_get_tts_audio(hass):
|
||||||
|
"""Set up media source."""
|
||||||
|
assert await async_setup_component(hass, "media_source", {})
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"tts",
|
||||||
|
{
|
||||||
|
"tts": {
|
||||||
|
"platform": "demo",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
|
||||||
|
return_value=("mp3", b""),
|
||||||
|
) as mock_get_tts:
|
||||||
|
yield mock_get_tts
|
||||||
|
|
||||||
|
|
||||||
|
async def test_audio_pipeline(hass, mock_get_tts_audio):
|
||||||
|
"""Run audio pipeline with mock TTS."""
|
||||||
|
pipeline = Pipeline(
|
||||||
|
name="test",
|
||||||
|
language=hass.config.language,
|
||||||
|
conversation_engine=None,
|
||||||
|
tts_engine=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback = MagicMock()
|
||||||
|
await AudioPipelineRequest(intent_input="Are the lights on?").execute(
|
||||||
|
PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
event_callback=event_callback,
|
||||||
|
language=hass.config.language,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = event_callback.mock_calls
|
||||||
|
assert calls[0].args[0].type == PipelineEventType.RUN_START
|
||||||
|
assert calls[0].args[0].data == {
|
||||||
|
"pipeline": "test",
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert calls[1].args[0].type == PipelineEventType.INTENT_START
|
||||||
|
assert calls[1].args[0].data == {
|
||||||
|
"engine": "default",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
}
|
||||||
|
assert calls[2].args[0].type == PipelineEventType.INTENT_FINISH
|
||||||
|
assert calls[2].args[0].data == {
|
||||||
|
"intent_output": {
|
||||||
|
"conversation_id": None,
|
||||||
|
"response": {
|
||||||
|
"card": {},
|
||||||
|
"data": {"code": "no_intent_match"},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"response_type": "error",
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Sorry, I couldn't understand that",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert calls[3].args[0].type == PipelineEventType.TTS_START
|
||||||
|
assert calls[3].args[0].data == {
|
||||||
|
"engine": "default",
|
||||||
|
"tts_input": "Sorry, I couldn't understand that",
|
||||||
|
}
|
||||||
|
assert calls[4].args[0].type == PipelineEventType.TTS_FINISH
|
||||||
|
assert (
|
||||||
|
calls[4].args[0].data["tts_output"]
|
||||||
|
== f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert calls[5].args[0].type == PipelineEventType.RUN_FINISH
|
|
@ -24,7 +24,11 @@ async def test_text_only_pipeline(
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{"id": 5, "type": "voice_assistant/run", "intent_input": "Are the lights on?"}
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# result
|
# result
|
||||||
|
@ -35,7 +39,7 @@ async def test_text_only_pipeline(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == {
|
assert msg["event"]["data"] == {
|
||||||
"pipeline": "default",
|
"pipeline": hass.config.language,
|
||||||
"language": hass.config.language,
|
"language": hass.config.language,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +87,8 @@ async def test_conversation_timeout(
|
||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.conversation.async_converse", new=sleepy_converse
|
"homeassistant.components.conversation.async_converse",
|
||||||
|
new=sleepy_converse,
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{
|
{
|
||||||
|
@ -102,7 +107,7 @@ async def test_conversation_timeout(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == {
|
assert msg["event"]["data"] == {
|
||||||
"pipeline": "default",
|
"pipeline": hass.config.language,
|
||||||
"language": hass.config.language,
|
"language": hass.config.language,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +135,7 @@ async def test_pipeline_timeout(
|
||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.voice_assistant.pipeline.Pipeline._run",
|
"homeassistant.components.voice_assistant.pipeline.TextPipelineRequest._execute",
|
||||||
new=sleepy_run,
|
new=sleepy_run,
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
|
|
Loading…
Reference in New Issue