core/tests/components/assist_pipeline/test_init.py

858 lines
29 KiB
Python

"""Test Voice Assistant init."""
import asyncio
from dataclasses import asdict
import itertools as it
from pathlib import Path
import tempfile
from unittest.mock import ANY, patch
import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, media_source, stt, tts
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
from .conftest import (
MockSttProvider,
MockSttProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
)
from tests.typing import ClientSessionGenerator, WebSocketGenerator
BYTES_ONE_SECOND = 16000 * 2
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
as_dict["data"]["pipeline"] = ANY
processed.append(as_dict)
return processed
async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, no pipeline is specified.
"""
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"]
async def test_pipeline_from_audio_stream_legacy(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, a pipeline using a legacy stt engine is used.
"""
client = await hass_ws_client(hass)
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Create a pipeline using an stt entity
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
# Use the created pipeline
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"]
async def test_pipeline_from_audio_stream_entity(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider_entity: MockSttProviderEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, a pipeline using am stt entity is used.
"""
client = await hass_ws_client(hass)
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Create a pipeline using an stt entity
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": mock_stt_provider_entity.entity_id,
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
# Use the created pipeline
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
async def test_pipeline_from_audio_stream_no_stt(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, the pipeline does not support stt
"""
client = await hass_ws_client(hass)
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Create a pipeline without stt support
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": None,
"stt_language": None,
"tts_engine": "test",
"tts_language": "en-AU",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
# Try to use the created pipeline
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert not events
async def test_pipeline_from_audio_stream_unknown_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, the pipeline does not exist.
"""
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Try to use the created pipeline
with pytest.raises(assist_pipeline.PipelineNotFound):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id="blah",
)
assert not events
async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream with wake word."""
events: list[assist_pipeline.PipelineEvent] = []
# [0, 1, ...]
wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND))
# [0, 2, ...]
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
async def audio_data():
# 1 second in 10 ms chunks
i = 0
while i < len(wake_chunk_1):
yield wake_chunk_1[i : i + bytes_per_chunk]
i += bytes_per_chunk
# 1 second in 30 ms chunks
i = 0
while i < len(wake_chunk_2):
yield wake_chunk_2[i : i + bytes_per_chunk]
i += bytes_per_chunk
yield b"wake word!"
yield b"part1"
yield b"part2"
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
# 1. Half of wake_chunk_1 + all wake_chunk_2
# 2. queued audio (from mock wake word entity)
# 3. part1
# 4. part2
assert len(mock_stt_provider.received) > 3
first_chunk = bytes(
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
)
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
async def test_pipeline_save_audio(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test saving audio during a pipeline run."""
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
pipeline = assist_pipeline.async_get_pipeline(hass)
events: list[assist_pipeline.PipelineEvent] = []
# Pad out to an even number of bytes since these "samples" will be saved
# as 16-bit values.
async def audio_data():
yield b"wake word_"
# queued audio
yield b"part1_"
yield b"part2_"
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline.id,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
pipeline_dirs = list(temp_dir.iterdir())
# Only one pipeline run
# <debug_recording_dir>/<pipeline.name>/<run.id>
assert len(pipeline_dirs) == 1
assert pipeline_dirs[0].is_dir()
assert pipeline_dirs[0].name == pipeline.name
# Wake and stt files
run_dirs = list(pipeline_dirs[0].iterdir())
assert run_dirs[0].is_dir()
run_files = list(run_dirs[0].iterdir())
assert len(run_files) == 2
wake_file = run_files[0] if "wake" in run_files[0].name else run_files[1]
stt_file = run_files[0] if "stt" in run_files[0].name else run_files[1]
assert wake_file != stt_file
# Verify wake file
with wave.open(str(wake_file), "rb") as wake_wav:
wake_data = wake_wav.readframes(wake_wav.getnframes())
assert wake_data == b"wake word_"
# Verify stt file
with wave.open(str(stt_file), "rb") as stt_wav:
stt_data = stt_wav.readframes(stt_wav.getnframes())
assert stt_data == b"queued audiopart1_part2_"
async def test_pipeline_saved_audio_with_device_id(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that saved audio directory uses device id."""
device_id = "test-device-id"
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
if event.type == "run-end":
# Verify that saved audio directory is named after device id
device_dirs = list(temp_dir.iterdir())
assert device_dirs[0].name == device_id
async def audio_data():
yield b"not used"
# Force a timeout during wake word detection
with patch.object(
mock_wake_word_provider_entity,
"async_process_audio_stream",
side_effect=assist_pipeline.error.WakeWordTimeoutError(
code="timeout", message="timeout"
),
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=event_callback,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
device_id=device_id,
)
async def test_pipeline_saved_audio_write_error(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that saved audio thread closes WAV file even if there's a write error."""
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
if event.type == "run-end":
# Verify WAV file exists, but contains no data
pipeline_dirs = list(temp_dir.iterdir())
run_dirs = list(pipeline_dirs[0].iterdir())
wav_path = next(run_dirs[0].iterdir())
with wave.open(str(wav_path), "rb") as wav_file:
assert wav_file.getnframes() == 0
async def audio_data():
yield b"not used"
# Force a timeout during wake word detection
with patch("wave.Wave_write.writeframes", raises=RuntimeError()):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=event_callback,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
)
async def test_pipeline_saved_audio_empty_queue(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that saved audio thread closes WAV file even if there's an empty queue."""
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
if event.type == "run-end":
# Verify WAV file exists, but contains no data
pipeline_dirs = list(temp_dir.iterdir())
run_dirs = list(pipeline_dirs[0].iterdir())
wav_path = next(run_dirs[0].iterdir())
with wave.open(str(wav_path), "rb") as wav_file:
assert wav_file.getnframes() == 0
async def audio_data():
# Force timeout in _pipeline_debug_recording_thread_proc
await asyncio.sleep(1)
yield b"not used"
# Wrap original function to time out immediately
_pipeline_debug_recording_thread_proc = (
assist_pipeline.pipeline._pipeline_debug_recording_thread_proc
)
def proc_wrapper(run_recording_dir, queue):
_pipeline_debug_recording_thread_proc(
run_recording_dir, queue, message_timeout=0
)
with patch(
"homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc",
proc_wrapper,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=event_callback,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
)
async def test_wake_word_detection_aborted(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream with wake word."""
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"silence!"
yield b"wake word!"
yield b"part1"
yield b"part2"
yield b""
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id=None,
device_id=None,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=events.append,
tts_audio_output=None,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
),
)
await pipeline_input.validate()
updates = pipeline.to_json()
updates.pop("id")
await pipeline_store.async_update_item(
pipeline_id,
updates,
)
await pipeline_input.execute()
assert process_events(events) == snapshot
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
"""Test that pipeline run equality uses unique id."""
def event_callback(event):
pass
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
run_1 = assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.STT,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
)
run_2 = assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.STT,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
)
assert run_1 == run_1 # noqa: PLR0124
assert run_1 != run_2
assert run_1 != 1234
async def test_tts_audio_output(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test using tts_audio_output with wav sets options correctly."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})
events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=events.append,
tts_audio_output="wav",
),
)
await pipeline_input.validate()
# Verify TTS audio settings
assert pipeline_input.run.tts_options is not None
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
await pipeline_input.execute()
for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
# Ensure that no unsupported options were passed in
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
extra_options = set(options).difference(mock_tts_provider.supported_options)
assert len(extra_options) == 0, extra_options
async def test_tts_supports_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})
events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=events.append,
tts_audio_output="wav",
),
)
await pipeline_input.validate()
# Make the TTS provider support preferred format options
supported_options = list(mock_tts_provider.supported_options or [])
supported_options.extend(
[
tts.ATTR_PREFERRED_FORMAT,
tts.ATTR_PREFERRED_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
]
)
with (
patch.object(mock_tts_provider, "_supported_options", supported_options),
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
):
await pipeline_input.execute()
for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
# We should have received preferred format options in get_tts_audio
assert tts.ATTR_PREFERRED_FORMAT in options
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options