Filter preferred TTS format options if not supported (#114392)
Filter preferred format options if not supported
pull/114764/head
parent
8cd8718855
commit
c81e9447f9
|
@ -16,7 +16,7 @@ import os
|
|||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, TypedDict, final
|
||||
from typing import Any, Final, TypedDict, final
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
|
@ -99,6 +99,13 @@ ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
|
|||
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
|
||||
ATTR_VOICE = "voice"
|
||||
|
||||
_DEFAULT_FORMAT = "mp3"
|
||||
_PREFFERED_FORMAT_OPTIONS: Final[set[str]] = {
|
||||
ATTR_PREFERRED_FORMAT,
|
||||
ATTR_PREFERRED_SAMPLE_RATE,
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
}
|
||||
|
||||
CONF_LANG = "language"
|
||||
|
||||
SERVICE_CLEAR_CACHE = "clear_cache"
|
||||
|
@ -569,25 +576,23 @@ class SpeechManager:
|
|||
):
|
||||
raise HomeAssistantError(f"Language '{language}' not supported")
|
||||
|
||||
options = options or {}
|
||||
supported_options = engine_instance.supported_options or []
|
||||
|
||||
# Update default options with provided options
|
||||
invalid_opts: list[str] = []
|
||||
merged_options = dict(engine_instance.default_options or {})
|
||||
merged_options.update(options or {})
|
||||
for option_name, option_value in options.items():
|
||||
# Only count an option as invalid if it's not a "preferred format"
|
||||
# option. These are used as hints to the TTS system if supported,
|
||||
# and otherwise as parameters to ffmpeg conversion.
|
||||
if (option_name in supported_options) or (
|
||||
option_name in _PREFFERED_FORMAT_OPTIONS
|
||||
):
|
||||
merged_options[option_name] = option_value
|
||||
else:
|
||||
invalid_opts.append(option_name)
|
||||
|
||||
supported_options = list(engine_instance.supported_options or [])
|
||||
|
||||
# ATTR_PREFERRED_* options are always "supported" since they're used to
|
||||
# convert audio after the TTS has run (if necessary).
|
||||
supported_options.extend(
|
||||
(
|
||||
ATTR_PREFERRED_FORMAT,
|
||||
ATTR_PREFERRED_SAMPLE_RATE,
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
)
|
||||
)
|
||||
|
||||
invalid_opts = [
|
||||
opt_name for opt_name in merged_options if opt_name not in supported_options
|
||||
]
|
||||
if invalid_opts:
|
||||
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
|
||||
|
||||
|
@ -687,10 +692,31 @@ class SpeechManager:
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
options = options or {}
|
||||
options = dict(options or {})
|
||||
supported_options = engine_instance.supported_options or []
|
||||
|
||||
# Default to MP3 unless a different format is preferred
|
||||
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
|
||||
# Extract preferred format options.
|
||||
#
|
||||
# These options are used by Assist pipelines, etc. to get a format that
|
||||
# the voice satellite will support.
|
||||
#
|
||||
# The TTS system ideally supports options directly so we won't have
|
||||
# to convert with ffmpeg later. If not, we pop the options here and
|
||||
# perform the conversion after receiving the audio.
|
||||
if ATTR_PREFERRED_FORMAT in supported_options:
|
||||
final_extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||
else:
|
||||
final_extension = options.pop(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||
|
||||
if ATTR_PREFERRED_SAMPLE_RATE in supported_options:
|
||||
sample_rate = options.get(ATTR_PREFERRED_SAMPLE_RATE)
|
||||
else:
|
||||
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
|
||||
|
||||
if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
|
||||
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
else:
|
||||
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
|
||||
|
||||
async def get_tts_data() -> str:
|
||||
"""Handle data available."""
|
||||
|
@ -716,8 +742,8 @@ class SpeechManager:
|
|||
# rate/format/channel count is requested.
|
||||
needs_conversion = (
|
||||
(final_extension != extension)
|
||||
or (ATTR_PREFERRED_SAMPLE_RATE in options)
|
||||
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
|
||||
or (sample_rate is not None)
|
||||
or (sample_channels is not None)
|
||||
)
|
||||
|
||||
if needs_conversion:
|
||||
|
@ -726,8 +752,8 @@ class SpeechManager:
|
|||
extension,
|
||||
data,
|
||||
to_extension=final_extension,
|
||||
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
|
||||
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
|
||||
to_sample_rate=sample_rate,
|
||||
to_sample_channels=sample_channels,
|
||||
)
|
||||
|
||||
# Create file infos
|
||||
|
|
|
@ -111,6 +111,7 @@ class MockTTSProvider(tts.Provider):
|
|||
tts.Voice("fran_drescher", "Fran Drescher"),
|
||||
]
|
||||
}
|
||||
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
|
@ -130,7 +131,7 @@ class MockTTSProvider(tts.Provider):
|
|||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
|
||||
return self._supported_options
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
|
|
|
@ -11,7 +11,7 @@ import wave
|
|||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components import assist_pipeline, media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DOMAIN,
|
||||
|
@ -19,9 +19,14 @@ from homeassistant.components.assist_pipeline.const import (
|
|||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
|
||||
from .conftest import (
|
||||
MockSttProvider,
|
||||
MockSttProviderEntity,
|
||||
MockTTSProvider,
|
||||
MockWakeWordEntity,
|
||||
)
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
BYTES_ONE_SECOND = 16000 * 2
|
||||
|
||||
|
@ -729,15 +734,17 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
|||
|
||||
async def test_tts_audio_output(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test using tts_audio_output with wav sets options correctly."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
def event_callback(event):
|
||||
pass
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
|
@ -753,7 +760,7 @@ async def test_tts_audio_output(
|
|||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
|
@ -764,3 +771,87 @@ async def test_tts_audio_output(
|
|||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
|
||||
|
||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
resolved = await media_source.async_resolve_media(hass, media_id, None)
|
||||
await client.get(resolved.url)
|
||||
|
||||
# Ensure that no unsupported options were passed in
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
||||
assert len(extra_options) == 0, extra_options
|
||||
|
||||
|
||||
async def test_tts_supports_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id=None,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS provider support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
resolved = await media_source.async_resolve_media(hass, media_id, None)
|
||||
await client.get(resolved.url)
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert tts.ATTR_PREFERRED_FORMAT in options
|
||||
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
|
||||
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options
|
||||
|
|
Loading…
Reference in New Issue