Migrate Assist Pipeline to use TTS stream (#139542)

* Migrate Pipeline to use TTS stream

* Fix tests
pull/139598/head
Paulus Schoutsen 2025-03-01 15:43:00 -05:00 committed by GitHub
parent c168695323
commit 2cce1b024e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 196 additions and 202 deletions

View File

@ -19,14 +19,7 @@ import wave
import hass_nabucasa
import voluptuous as vol
from homeassistant.components import (
conversation,
media_source,
stt,
tts,
wake_word,
websocket_api,
)
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
from homeassistant.components.tts import (
generate_media_source_id as tts_generate_media_source_id,
)
@ -569,8 +562,7 @@ class PipelineRun:
id: str = field(default_factory=ulid_util.ulid_now)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
tts_engine: str = field(init=False, repr=False)
tts_options: dict | None = field(init=False, default=None)
tts_stream: tts.ResultStream | None = field(init=False, default=None)
wake_word_entity_id: str | None = field(init=False, default=None, repr=False)
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
@ -648,13 +640,18 @@ class PipelineRun:
self._device_id = device_id
self._start_debug_recording_thread()
data = {
data: dict[str, Any] = {
"pipeline": self.pipeline.id,
"language": self.language,
"conversation_id": conversation_id,
}
if self.runner_data is not None:
data["runner_data"] = self.runner_data
if self.tts_stream:
data["tts_output"] = {
"url": self.tts_stream.url,
"mime_type": self.tts_stream.content_type,
}
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
@ -1246,36 +1243,31 @@ class PipelineRun:
tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH
try:
options_supported = await tts.async_support_options(
self.hass,
engine,
self.pipeline.tts_language,
tts_options,
self.tts_stream = tts.async_create_stream(
hass=self.hass,
engine=engine,
language=self.pipeline.tts_language,
options=tts_options,
)
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text-to-speech engine '{engine}' not found",
) from err
if not options_supported:
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text-to-speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
f"does not support language {self.pipeline.tts_language} or options {tts_options}:"
f" {err}"
),
)
self.tts_engine = engine
self.tts_options = tts_options
) from err
async def text_to_speech(self, tts_input: str) -> None:
"""Run text-to-speech portion of pipeline."""
assert self.tts_stream is not None
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,
{
"engine": self.tts_engine,
"engine": self.tts_stream.engine,
"language": self.pipeline.tts_language,
"voice": self.pipeline.tts_voice,
"tts_input": tts_input,
@ -1288,14 +1280,9 @@ class PipelineRun:
tts_media_id = tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.tts_engine,
language=self.pipeline.tts_language,
options=self.tts_options,
)
tts_media = await media_source.async_resolve_media(
self.hass,
tts_media_id,
None,
engine=self.tts_stream.engine,
language=self.tts_stream.language,
options=self.tts_stream.options,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during text-to-speech")
@ -1304,10 +1291,12 @@ class PipelineRun:
message="Unexpected error during text-to-speech",
) from src_error
_LOGGER.debug("TTS result %s", tts_media)
self.tts_stream.async_set_message(tts_input)
tts_output = {
"media_id": tts_media_id,
**asdict(tts_media),
"url": self.tts_stream.url,
"mime_type": self.tts_stream.content_type,
}
self.process_event(

View File

@ -79,13 +79,13 @@ __all__ = [
"PLATFORM_SCHEMA",
"PLATFORM_SCHEMA_BASE",
"Provider",
"ResultStream",
"SampleFormat",
"TextToSpeechEntity",
"TtsAudioType",
"Voice",
"async_default_engine",
"async_get_media_source_audio",
"async_support_options",
"generate_media_source_id",
]
@ -167,22 +167,19 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
return async_default_engine(hass)
async def async_support_options(
@callback
def async_create_stream(
hass: HomeAssistant,
engine: str,
language: str | None = None,
options: dict | None = None,
) -> bool:
"""Return if an engine supports options."""
if (engine_instance := get_engine_instance(hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
try:
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
except HomeAssistantError:
return False
return True
) -> ResultStream:
"""Create a streaming URL where the rendered TTS can be retrieved."""
return hass.data[DATA_TTS_MANAGER].async_create_result_stream(
engine=engine,
language=language,
options=options,
)
async def async_get_media_source_audio(
@ -407,6 +404,18 @@ class ResultStream:
"""Set cache key for message to be streamed."""
self._result_cache_key.set_result(cache_key)
@callback
def async_set_message(self, message: str) -> None:
"""Set message to be generated."""
cache_key = self._manager.async_cache_message_in_memory(
engine=self.engine,
message=message,
use_file_cache=self.use_file_cache,
language=self.language,
options=self.options,
)
self._result_cache_key.set_result(cache_key)
async def async_stream_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of this result."""
cache_key = await self._result_cache_key

View File

@ -6,6 +6,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -99,6 +103,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -192,6 +200,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -285,6 +297,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -402,6 +418,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -598,6 +618,10 @@
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),

View File

@ -8,6 +8,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
})
# ---
# name: test_audio_pipeline.1
@ -93,6 +97,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
})
# ---
# name: test_audio_pipeline_debug.1
@ -190,6 +198,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.1
@ -275,6 +287,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.1
@ -382,6 +398,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/test_token.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.1
@ -585,6 +605,10 @@
'stt_binary_handler_id': None,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_pipeline_empty_tts_output.1
@ -634,6 +658,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_stt_cooldown_different_ids.1
@ -645,6 +673,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_stt_cooldown_same_id
@ -656,6 +688,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_stt_cooldown_same_id.1
@ -667,6 +703,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_stt_stream_failed
@ -678,6 +718,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_stt_stream_failed.1
@ -798,28 +842,6 @@
'message': 'Timeout running pipeline',
})
# ---
# name: test_tts_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 300,
}),
})
# ---
# name: test_tts_failed.1
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': 'Lights are on.',
'voice': 'james_earl_jones',
})
# ---
# name: test_tts_failed.2
None
# ---
# name: test_wake_word_cooldown_different_entities
dict({
'conversation_id': 'mock-ulid',
@ -829,6 +851,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_different_entities.1
@ -840,6 +866,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_different_entities.2
@ -892,6 +922,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_different_ids.1
@ -903,6 +937,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_different_ids.2
@ -958,6 +996,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_same_id.1
@ -969,6 +1011,10 @@
'stt_binary_handler_id': 1,
'timeout': 300,
}),
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
})
# ---
# name: test_wake_word_cooldown_same_id.2

View File

@ -43,13 +43,21 @@ from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
def mock_chat_session_id() -> Generator[Mock]:
"""Mock the conversation ID of chat sessions."""
with patch(
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
) as mock_ulid_now:
yield mock_ulid_now
@pytest.fixture(autouse=True)
def mock_tts_token() -> Generator[None]:
"""Mock the TTS token for URLs."""
with patch("secrets.token_urlsafe", return_value="mocked-token"):
yield
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
@ -797,10 +805,16 @@ async def test_tts_audio_output(
await pipeline_input.validate()
# Verify TTS audio settings
assert pipeline_input.run.tts_options is not None
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
assert pipeline_input.run.tts_stream.options is not None
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert (
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
== 16000
)
assert (
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
== 1
)
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
await pipeline_input.execute()
@ -809,9 +823,7 @@ async def test_tts_audio_output(
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
await client.get(event.data["tts_output"]["url"])
# Ensure that no unsupported options were passed in
assert mock_get_tts_audio.called
@ -875,9 +887,7 @@ async def test_tts_wav_preferred_format(
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
await client.get(event.data["tts_output"]["url"])
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
@ -949,9 +959,7 @@ async def test_tts_dict_preferred_format(
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
await client.get(event.data["tts_output"]["url"])
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

View File

@ -20,6 +20,8 @@ from homeassistant.components.assist_pipeline.pipeline import (
DeviceAudioQueue,
Pipeline,
PipelineData,
async_get_pipelines,
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -38,13 +40,21 @@ from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
def mock_chat_session_id() -> Generator[Mock]:
"""Mock the conversation ID of chat sessions."""
with patch(
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
) as mock_ulid_now:
yield mock_ulid_now
@pytest.fixture(autouse=True)
def mock_tts_token() -> Generator[None]:
"""Mock the TTS token for URLs."""
with patch("secrets.token_urlsafe", return_value="mocked-token"):
yield
@pytest.mark.parametrize(
"extra_msg",
[
@ -825,74 +835,6 @@ async def test_stt_stream_failed(
assert msg["result"] == {"events": events}
async def test_tts_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text-to-speech error."""
events = []
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=RuntimeError,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# tts start
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# tts error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "tts-failed"
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_debug)[0]
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_tts_provider_missing(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
@ -903,23 +845,22 @@ async def test_tts_provider_missing(
"""Test pipeline run with text-to-speech error."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
side_effect=HomeAssistantError,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
pipelines = async_get_pipelines(hass)
await async_update_pipeline(hass, pipelines[0], tts_engine="unavailable")
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
async def test_tts_provider_bad_options(
@ -933,8 +874,8 @@ async def test_tts_provider_bad_options(
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
return_value=False,
"homeassistant.components.tts.SpeechManager.process_options",
side_effect=HomeAssistantError("Language not supported"),
):
await client.send_json_auto_id(
{

View File

@ -1376,29 +1376,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
assert tts.async_resolve_engine(hass, None) is None
@pytest.mark.parametrize(
("setup", "engine_id"),
[
("mock_setup", "test"),
("mock_config_entry_setup", "tts.test"),
],
indirect=["setup"],
)
async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None:
"""Test supporting options."""
assert await tts.async_support_options(hass, engine_id, "en_US") is True
assert await tts.async_support_options(hass, engine_id, "nl") is False
assert (
await tts.async_support_options(
hass, engine_id, "en_US", {"invalid_option": "yo"}
)
is False
)
with pytest.raises(HomeAssistantError):
await tts.async_support_options(hass, "non-existing")
async def test_legacy_fetching_in_async(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None: