Filter preferred TTS format options if not supported (#114392)

Filter preferred format options if not supported
pull/114764/head
Michael Hansen 2024-03-28 11:09:15 -05:00 committed by Franck Nijhof
parent 8cd8718855
commit c81e9447f9
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
3 changed files with 150 additions and 32 deletions

View File

@ -16,7 +16,7 @@ import os
import re
import subprocess
import tempfile
from typing import Any, TypedDict, final
from typing import Any, Final, TypedDict, final
from aiohttp import web
import mutagen
@ -99,6 +99,13 @@ ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"
_DEFAULT_FORMAT = "mp3"
_PREFFERED_FORMAT_OPTIONS: Final[set[str]] = {
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
}
CONF_LANG = "language"
SERVICE_CLEAR_CACHE = "clear_cache"
@ -569,25 +576,23 @@ class SpeechManager:
):
raise HomeAssistantError(f"Language '{language}' not supported")
options = options or {}
supported_options = engine_instance.supported_options or []
# Update default options with provided options
invalid_opts: list[str] = []
merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {})
for option_name, option_value in options.items():
# Only count an option as invalid if it's not a "preferred format"
# option. These are used as hints to the TTS system if supported,
# and otherwise as parameters to ffmpeg conversion.
if (option_name in supported_options) or (
option_name in _PREFFERED_FORMAT_OPTIONS
):
merged_options[option_name] = option_value
else:
invalid_opts.append(option_name)
supported_options = list(engine_instance.supported_options or [])
# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)
invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
if invalid_opts:
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
@ -687,10 +692,31 @@ class SpeechManager:
This method is a coroutine.
"""
options = options or {}
options = dict(options or {})
supported_options = engine_instance.supported_options or []
# Default to MP3 unless a different format is preferred
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
# Extract preferred format options.
#
# These options are used by Assist pipelines, etc. to get a format that
# the voice satellite will support.
#
# The TTS system ideally supports options directly so we won't have
# to convert with ffmpeg later. If not, we pop the options here and
# perform the conversion after receiving the audio.
if ATTR_PREFERRED_FORMAT in supported_options:
final_extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
else:
final_extension = options.pop(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
if ATTR_PREFERRED_SAMPLE_RATE in supported_options:
sample_rate = options.get(ATTR_PREFERRED_SAMPLE_RATE)
else:
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
else:
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
async def get_tts_data() -> str:
"""Handle data available."""
@ -716,8 +742,8 @@ class SpeechManager:
# rate/format/channel count is requested.
needs_conversion = (
(final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
or (sample_rate is not None)
or (sample_channels is not None)
)
if needs_conversion:
@ -726,8 +752,8 @@ class SpeechManager:
extension,
data,
to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
to_sample_rate=sample_rate,
to_sample_channels=sample_channels,
)
# Create file infos

View File

@ -111,6 +111,7 @@ class MockTTSProvider(tts.Provider):
tts.Voice("fran_drescher", "Fran Drescher"),
]
}
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
@property
def default_language(self) -> str:
@ -130,7 +131,7 @@ class MockTTSProvider(tts.Provider):
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
return self._supported_options
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]

View File

@ -11,7 +11,7 @@ import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components import assist_pipeline, media_source, stt, tts
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
@ -19,9 +19,14 @@ from homeassistant.components.assist_pipeline.const import (
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
from .conftest import (
MockSttProvider,
MockSttProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
)
from tests.typing import WebSocketGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator
BYTES_ONE_SECOND = 16000 * 2
@ -729,15 +734,17 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
async def test_tts_audio_output(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test using tts_audio_output with wav sets options correctly."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})
def event_callback(event):
pass
events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
@ -753,7 +760,7 @@ async def test_tts_audio_output(
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
event_callback=events.append,
tts_audio_output="wav",
),
)
@ -764,3 +771,87 @@ async def test_tts_audio_output(
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
await pipeline_input.execute()
for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
# Ensure that no unsupported options were passed in
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
extra_options = set(options).difference(mock_tts_provider.supported_options)
assert len(extra_options) == 0, extra_options
async def test_tts_supports_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    # Collect pipeline events so we can find the TTS_END event afterwards.
    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # Run only the TTS stage; "wav" output makes the pipeline request
    # preferred format/rate/channel options from the TTS system.
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        conversation_id=None,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    # (patched via the provider's private _supported_options backing list).
    supported_options = list(mock_tts_provider.supported_options or [])
    supported_options.extend(
        [
            tts.ATTR_PREFERRED_FORMAT,
            tts.ATTR_PREFERRED_SAMPLE_RATE,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        ]
    )

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                media_id = event.data["tts_output"]["media_id"]
                resolved = await media_source.async_resolve_media(hass, media_id, None)
                await client.get(resolved.url)

    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # We should have received preferred format options in get_tts_audio
    assert tts.ATTR_PREFERRED_FORMAT in options
    assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
    assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options