diff --git a/homeassistant/components/voice_assistant/__init__.py b/homeassistant/components/voice_assistant/__init__.py
index d06176847e9..2ae169a28eb 100644
--- a/homeassistant/components/voice_assistant/__init__.py
+++ b/homeassistant/components/voice_assistant/__init__.py
@@ -4,20 +4,13 @@ from __future__ import annotations
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers.typing import ConfigType
 
-from .const import DEFAULT_PIPELINE, DOMAIN
-from .pipeline import Pipeline
+from .const import DOMAIN
 from .websocket_api import async_register_websocket_api
 
 
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up Voice Assistant integration."""
-    hass.data[DOMAIN] = {
-        DEFAULT_PIPELINE: Pipeline(
-            name=DEFAULT_PIPELINE,
-            language=None,
-            conversation_engine=None,
-        )
-    }
+    hass.data[DOMAIN] = {}
     async_register_websocket_api(hass)
 
     return True
diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py
index 8c7d22981ab..0b55d724554 100644
--- a/homeassistant/components/voice_assistant/pipeline.py
+++ b/homeassistant/components/voice_assistant/pipeline.py
@@ -1,6 +1,7 @@
 """Classes for voice assistant pipelines."""
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 import asyncio
 from collections.abc import Callable
 from dataclasses import dataclass, field
@@ -8,20 +9,16 @@ from typing import Any
 
 from homeassistant.backports.enum import StrEnum
 from homeassistant.components import conversation
+from homeassistant.components.media_source import async_resolve_media
+from homeassistant.components.tts.media_source import (
+    generate_media_source_id as tts_generate_media_source_id,
+)
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.util.dt import utcnow
 
 DEFAULT_TIMEOUT = 30  # seconds
 
 
-@dataclass
-class PipelineRequest:
-    """Request to start a pipeline run."""
-
-    intent_input: str
-    conversation_id: str | None = None
-
-
 class PipelineEventType(StrEnum):
     """Event types emitted during a pipeline run."""
 
@@ -29,6 +26,8 @@
     RUN_FINISH = "run-finish"
     INTENT_START = "intent-start"
     INTENT_FINISH = "intent-finish"
+    TTS_START = "tts-start"
+    TTS_FINISH = "tts-finish"
     ERROR = "error"
 
 
@@ -56,69 +55,161 @@ class Pipeline:
     name: str
     language: str | None
     conversation_engine: str | None
+    tts_engine: str | None
 
-    async def run(
-        self,
-        hass: HomeAssistant,
-        context: Context,
-        request: PipelineRequest,
-        event_callback: Callable[[PipelineEvent], None],
-        timeout: int | float | None = DEFAULT_TIMEOUT,
-    ) -> None:
-        """Run a pipeline with an optional timeout."""
-        await asyncio.wait_for(
-            self._run(hass, context, request, event_callback), timeout=timeout
-        )
 
-    async def _run(
-        self,
-        hass: HomeAssistant,
-        context: Context,
-        request: PipelineRequest,
-        event_callback: Callable[[PipelineEvent], None],
-    ) -> None:
-        """Run a pipeline."""
-        language = self.language or hass.config.language
-        event_callback(
+@dataclass
+class PipelineRun:
+    """Running context for a pipeline."""
+
+    hass: HomeAssistant
+    context: Context
+    pipeline: Pipeline
+    event_callback: Callable[[PipelineEvent], None]
+    language: str = None  # type: ignore[assignment]
+
+    def __post_init__(self):
+        """Set language for pipeline."""
+        self.language = self.pipeline.language or self.hass.config.language
+
+    def start(self):
+        """Emit run start event."""
+        self.event_callback(
             PipelineEvent(
                 PipelineEventType.RUN_START,
                 {
-                    "pipeline": self.name,
-                    "language": language,
+                    "pipeline": self.pipeline.name,
+                    "language": self.language,
                 },
             )
         )
 
-        intent_input = request.intent_input
+    def finish(self):
+        """Emit run finish event."""
+        self.event_callback(
+            PipelineEvent(
+                PipelineEventType.RUN_FINISH,
+            )
+        )
 
-        event_callback(
+    async def recognize_intent(
+        self, intent_input: str, conversation_id: str | None
+    ) -> conversation.ConversationResult:
+        """Run intent recognition portion of pipeline."""
+        self.event_callback(
             PipelineEvent(
                 PipelineEventType.INTENT_START,
                 {
-                    "engine": self.conversation_engine or "default",
+                    "engine": self.pipeline.conversation_engine or "default",
                     "intent_input": intent_input,
                 },
             )
         )
 
         conversation_result = await conversation.async_converse(
-            hass=hass,
+            hass=self.hass,
             text=intent_input,
-            conversation_id=request.conversation_id,
-            context=context,
-            language=language,
-            agent_id=self.conversation_engine,
+            conversation_id=conversation_id,
+            context=self.context,
+            language=self.language,
+            agent_id=self.pipeline.conversation_engine,
         )
 
-        event_callback(
+        self.event_callback(
             PipelineEvent(
                 PipelineEventType.INTENT_FINISH,
                 {"intent_output": conversation_result.as_dict()},
             )
         )
 
-        event_callback(
+        return conversation_result
+
+    async def text_to_speech(self, tts_input: str) -> str:
+        """Run text to speech portion of pipeline. Returns URL of TTS audio."""
+        self.event_callback(
             PipelineEvent(
-                PipelineEventType.RUN_FINISH,
+                PipelineEventType.TTS_START,
+                {
+                    "engine": self.pipeline.tts_engine or "default",
+                    "tts_input": tts_input,
+                },
             )
         )
+
+        tts_media = await async_resolve_media(
+            self.hass,
+            tts_generate_media_source_id(
+                self.hass,
+                tts_input,
+                engine=self.pipeline.tts_engine,
+            ),
+        )
+        tts_url = tts_media.url
+
+        self.event_callback(
+            PipelineEvent(
+                PipelineEventType.TTS_FINISH,
+                {"tts_output": tts_url},
+            )
+        )
+
+        return tts_url
+
+
+@dataclass
+class PipelineRequest(ABC):
+    """Request for a pipeline run."""
+
+    async def execute(
+        self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
+    ):
+        """Run pipeline with optional timeout."""
+        await asyncio.wait_for(
+            self._execute(run),
+            timeout=timeout,
+        )
+
+    @abstractmethod
+    async def _execute(self, run: PipelineRun):
+        """Run pipeline with request info and context."""
+
+
+@dataclass
+class TextPipelineRequest(PipelineRequest):
+    """Request to run the text portion only of a pipeline."""
+
+    intent_input: str
+    conversation_id: str | None = None
+
+    async def _execute(
+        self,
+        run: PipelineRun,
+    ):
+        run.start()
+        await run.recognize_intent(self.intent_input, self.conversation_id)
+        run.finish()
+
+
+@dataclass
+class AudioPipelineRequest(PipelineRequest):
+    """Request to run the full pipeline from audio input (stt) to audio output (tts)."""
+
+    intent_input: str  # this will be changed to stt audio
+    conversation_id: str | None = None
+
+    async def _execute(self, run: PipelineRun):
+        run.start()
+
+        # stt will go here
+
+        conversation_result = await run.recognize_intent(
+            self.intent_input, self.conversation_id
+        )
+
+        tts_input = conversation_result.response.speech.get("plain", {}).get(
+            "speech", ""
+        )
+
+        await run.text_to_speech(tts_input)
+
+        run.finish()
diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py
index 4ea88c3da00..54e87e292a1 100644
--- a/homeassistant/components/voice_assistant/websocket_api.py
+++ b/homeassistant/components/voice_assistant/websocket_api.py
@@ -7,7 +7,7 @@ from homeassistant.components import websocket_api
 from homeassistant.core import HomeAssistant, callback
 
 from .const import DOMAIN
-from .pipeline import DEFAULT_TIMEOUT, PipelineRequest
+from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest
 
 
 @callback
@@ -19,7 +19,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
 @websocket_api.websocket_command(
     {
         vol.Required("type"): "voice_assistant/run",
-        vol.Optional("pipeline", default="default"): str,
+        vol.Optional("language"): str,
+        vol.Optional("pipeline"): str,
         vol.Required("intent_input"): str,
         vol.Optional("conversation_id"): vol.Any(str, None),
         vol.Optional("timeout"): vol.Any(float, int),
@@ -32,27 +33,42 @@ async def websocket_run(
     msg: dict[str, Any],
 ) -> None:
     """Run a pipeline."""
-    pipeline_id = msg["pipeline"]
-    pipeline = hass.data[DOMAIN].get(pipeline_id)
-    if pipeline is None:
-        connection.send_error(
-            msg["id"], "pipeline_not_found", f"Pipeline not found: {pipeline_id}"
+    pipeline_id = msg.get("pipeline")
+    if pipeline_id is not None:
+        pipeline = hass.data[DOMAIN].get(pipeline_id)
+        if pipeline is None:
+            connection.send_error(
+                msg["id"],
+                "pipeline_not_found",
+                f"Pipeline not found: {pipeline_id}",
+            )
+            return
+
+    else:
+        # Construct a pipeline for the required/configured language
+        language = msg.get("language", hass.config.language)
+        pipeline = Pipeline(
+            name=language,
+            language=language,
+            conversation_engine=None,
+            tts_engine=None,
         )
-        return
 
     # Run pipeline with a timeout.
     # Events are sent over the websocket connection.
     timeout = msg.get("timeout", DEFAULT_TIMEOUT)
     run_task = hass.async_create_task(
-        pipeline.run(
-            hass,
-            connection.context(msg),
-            request=PipelineRequest(
-                intent_input=msg["intent_input"],
-                conversation_id=msg.get("conversation_id"),
-            ),
-            event_callback=lambda event: connection.send_event(
-                msg["id"], event.as_dict()
+        TextPipelineRequest(
+            intent_input=msg["intent_input"],
+            conversation_id=msg.get("conversation_id"),
+        ).execute(
+            PipelineRun(
+                hass,
+                connection.context(msg),
+                pipeline,
+                event_callback=lambda event: connection.send_event(
+                    msg["id"], event.as_dict()
+                ),
             ),
             timeout=timeout,
         )
diff --git a/tests/components/voice_assistant/test_pipeline.py b/tests/components/voice_assistant/test_pipeline.py
new file mode 100644
index 00000000000..343719a49fd
--- /dev/null
+++ b/tests/components/voice_assistant/test_pipeline.py
@@ -0,0 +1,110 @@
+"""Pipeline tests for Voice Assistant integration."""
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from homeassistant.components.voice_assistant.pipeline import (
+    AudioPipelineRequest,
+    Pipeline,
+    PipelineEventType,
+    PipelineRun,
+)
+from homeassistant.core import Context
+from homeassistant.setup import async_setup_component
+
+from tests.components.tts.conftest import (  # noqa: F401, pylint: disable=unused-import
+    mock_get_cache_files,
+    mock_init_cache_dir,
+)
+
+
+@pytest.fixture(autouse=True)
+async def init_components(hass):
+    """Initialize relevant components with empty configs."""
+    assert await async_setup_component(hass, "voice_assistant", {})
+
+
+@pytest.fixture
+async def mock_get_tts_audio(hass):
+    """Set up media source."""
+    assert await async_setup_component(hass, "media_source", {})
+    assert await async_setup_component(
+        hass,
+        "tts",
+        {
+            "tts": {
+                "platform": "demo",
+            }
+        },
+    )
+
+    with patch(
+        "homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
+        return_value=("mp3", b""),
+    ) as mock_get_tts:
+        yield mock_get_tts
+
+
+async def test_audio_pipeline(hass, mock_get_tts_audio):
+    """Run audio pipeline with mock TTS."""
+    pipeline = Pipeline(
+        name="test",
+        language=hass.config.language,
+        conversation_engine=None,
+        tts_engine=None,
+    )
+
+    event_callback = MagicMock()
+    await AudioPipelineRequest(intent_input="Are the lights on?").execute(
+        PipelineRun(
+            hass,
+            context=Context(),
+            pipeline=pipeline,
+            event_callback=event_callback,
+            language=hass.config.language,
+        )
+    )
+
+    calls = event_callback.mock_calls
+    assert calls[0].args[0].type == PipelineEventType.RUN_START
+    assert calls[0].args[0].data == {
+        "pipeline": "test",
+        "language": hass.config.language,
+    }
+
+    assert calls[1].args[0].type == PipelineEventType.INTENT_START
+    assert calls[1].args[0].data == {
+        "engine": "default",
+        "intent_input": "Are the lights on?",
+    }
+    assert calls[2].args[0].type == PipelineEventType.INTENT_FINISH
+    assert calls[2].args[0].data == {
+        "intent_output": {
+            "conversation_id": None,
+            "response": {
+                "card": {},
+                "data": {"code": "no_intent_match"},
+                "language": hass.config.language,
+                "response_type": "error",
+                "speech": {
+                    "plain": {
+                        "extra_data": None,
+                        "speech": "Sorry, I couldn't understand that",
+                    }
+                },
+            },
+        }
+    }
+
+    assert calls[3].args[0].type == PipelineEventType.TTS_START
+    assert calls[3].args[0].data == {
+        "engine": "default",
+        "tts_input": "Sorry, I couldn't understand that",
+    }
+    assert calls[4].args[0].type == PipelineEventType.TTS_FINISH
+    assert (
+        calls[4].args[0].data["tts_output"]
+        == f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3"
+    )
+
+    assert calls[5].args[0].type == PipelineEventType.RUN_FINISH
diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py
index e862da6f542..2fec6cdfb03 100644
--- a/tests/components/voice_assistant/test_websocket.py
+++ b/tests/components/voice_assistant/test_websocket.py
@@ -24,7 +24,11 @@ async def test_text_only_pipeline(
     client = await hass_ws_client(hass)
 
     await client.send_json(
-        {"id": 5, "type": "voice_assistant/run", "intent_input": "Are the lights on?"}
+        {
+            "id": 5,
+            "type": "voice_assistant/run",
+            "intent_input": "Are the lights on?",
+        }
     )
 
     # result
@@ -35,7 +39,7 @@
     msg = await client.receive_json()
     assert msg["event"]["type"] == "run-start"
     assert msg["event"]["data"] == {
-        "pipeline": "default",
+        "pipeline": hass.config.language,
         "language": hass.config.language,
     }
 
@@ -83,7 +87,8 @@ async def test_conversation_timeout(
         await asyncio.sleep(3600)
 
     with patch(
-        "homeassistant.components.conversation.async_converse", new=sleepy_converse
+        "homeassistant.components.conversation.async_converse",
+        new=sleepy_converse,
     ):
         await client.send_json(
             {
@@ -102,7 +107,7 @@
     msg = await client.receive_json()
    assert msg["event"]["type"] == "run-start"
     assert msg["event"]["data"] == {
-        "pipeline": "default",
+        "pipeline": hass.config.language,
         "language": hass.config.language,
     }
 
@@ -130,7 +135,7 @@
         await asyncio.sleep(3600)
 
     with patch(
-        "homeassistant.components.voice_assistant.pipeline.Pipeline._run",
+        "homeassistant.components.voice_assistant.pipeline.TextPipelineRequest._execute",
         new=sleepy_run,
     ):
         await client.send_json(