Add support for sample bytes in preferred TTS format (#125235)
parent
892c32c8b7
commit
4ecc6555bf
homeassistant/components
assist_pipeline
tts
tests/components/assist_pipeline
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -99,7 +100,7 @@ async def async_pipeline_from_audio_stream(
|
|||
wake_word_phrase: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
tts_audio_output: str | None = None,
|
||||
tts_audio_output: str | dict[str, Any] | None = None,
|
||||
wake_word_settings: WakeWordSettings | None = None,
|
||||
audio_settings: AudioSettings | None = None,
|
||||
device_id: str | None = None,
|
||||
|
|
|
@ -538,7 +538,7 @@ class PipelineRun:
|
|||
language: str = None # type: ignore[assignment]
|
||||
runner_data: Any | None = None
|
||||
intent_agent: str | None = None
|
||||
tts_audio_output: str | None = None
|
||||
tts_audio_output: str | dict[str, Any] | None = None
|
||||
wake_word_settings: WakeWordSettings | None = None
|
||||
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
||||
|
||||
|
@ -1052,12 +1052,15 @@ class PipelineRun:
|
|||
if self.pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
|
||||
|
||||
if self.tts_audio_output is not None:
|
||||
if isinstance(self.tts_audio_output, dict):
|
||||
tts_options.update(self.tts_audio_output)
|
||||
elif isinstance(self.tts_audio_output, str):
|
||||
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
|
||||
if self.tts_audio_output == "wav":
|
||||
# 16 kHz, 16-bit mono
|
||||
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE
|
||||
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS
|
||||
tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH
|
||||
|
||||
try:
|
||||
options_supported = await tts.async_support_options(
|
||||
|
|
|
@ -77,6 +77,7 @@ __all__ = [
|
|||
"ATTR_PREFERRED_FORMAT",
|
||||
"ATTR_PREFERRED_SAMPLE_RATE",
|
||||
"ATTR_PREFERRED_SAMPLE_CHANNELS",
|
||||
"ATTR_PREFERRED_SAMPLE_BYTES",
|
||||
"CONF_LANG",
|
||||
"DEFAULT_CACHE_DIR",
|
||||
"generate_media_source_id",
|
||||
|
@ -95,6 +96,7 @@ ATTR_AUDIO_OUTPUT = "audio_output"
|
|||
ATTR_PREFERRED_FORMAT = "preferred_format"
|
||||
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
|
||||
ATTR_PREFERRED_SAMPLE_BYTES = "preferred_sample_bytes"
|
||||
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
|
||||
ATTR_VOICE = "voice"
|
||||
|
||||
|
@ -103,6 +105,7 @@ _PREFFERED_FORMAT_OPTIONS: Final[set[str]] = {
|
|||
ATTR_PREFERRED_FORMAT,
|
||||
ATTR_PREFERRED_SAMPLE_RATE,
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
}
|
||||
|
||||
CONF_LANG = "language"
|
||||
|
@ -223,6 +226,7 @@ async def async_convert_audio(
|
|||
to_extension: str,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
to_sample_bytes: int | None = None,
|
||||
) -> bytes:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||
|
@ -234,6 +238,7 @@ async def async_convert_audio(
|
|||
to_extension,
|
||||
to_sample_rate=to_sample_rate,
|
||||
to_sample_channels=to_sample_channels,
|
||||
to_sample_bytes=to_sample_bytes,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -245,6 +250,7 @@ def _convert_audio(
|
|||
to_extension: str,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
to_sample_bytes: int | None = None,
|
||||
) -> bytes:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
|
||||
|
@ -277,6 +283,10 @@ def _convert_audio(
|
|||
# Max quality for MP3
|
||||
command.extend(["-q:a", "0"])
|
||||
|
||||
if to_sample_bytes == 2:
|
||||
# 16-bit samples
|
||||
command.extend(["-sample_fmt", "s16"])
|
||||
|
||||
command.append(output_file.name)
|
||||
|
||||
with subprocess.Popen(
|
||||
|
@ -738,11 +748,25 @@ class SpeechManager:
|
|||
else:
|
||||
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
|
||||
|
||||
if sample_rate is not None:
|
||||
sample_rate = int(sample_rate)
|
||||
|
||||
if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
|
||||
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
else:
|
||||
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
|
||||
|
||||
if sample_channels is not None:
|
||||
sample_channels = int(sample_channels)
|
||||
|
||||
if ATTR_PREFERRED_SAMPLE_BYTES in supported_options:
|
||||
sample_bytes = options.get(ATTR_PREFERRED_SAMPLE_BYTES)
|
||||
else:
|
||||
sample_bytes = options.pop(ATTR_PREFERRED_SAMPLE_BYTES, None)
|
||||
|
||||
if sample_bytes is not None:
|
||||
sample_bytes = int(sample_bytes)
|
||||
|
||||
async def get_tts_data() -> str:
|
||||
"""Handle data available."""
|
||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||
|
@ -769,6 +793,7 @@ class SpeechManager:
|
|||
(final_extension != extension)
|
||||
or (sample_rate is not None)
|
||||
or (sample_channels is not None)
|
||||
or (sample_bytes is not None)
|
||||
)
|
||||
|
||||
if needs_conversion:
|
||||
|
@ -779,6 +804,7 @@ class SpeechManager:
|
|||
to_extension=final_extension,
|
||||
to_sample_rate=sample_rate,
|
||||
to_sample_channels=sample_channels,
|
||||
to_sample_bytes=sample_bytes,
|
||||
)
|
||||
|
||||
# Create file infos
|
||||
|
|
|
@ -788,13 +788,12 @@ async def test_tts_audio_output(
|
|||
assert len(extra_options) == 0, extra_options
|
||||
|
||||
|
||||
async def test_tts_supports_preferred_format(
|
||||
async def test_tts_wav_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
|
@ -829,6 +828,7 @@ async def test_tts_supports_preferred_format(
|
|||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -850,6 +850,80 @@ async def test_tts_supports_preferred_format(
|
|||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert tts.ATTR_PREFERRED_FORMAT in options
|
||||
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
|
||||
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options
|
||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
||||
|
||||
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a dict of preferred format options is passed to the TTS provider."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    preferred_pipeline = assist_pipeline.pipeline.async_get_pipeline(
        hass, preferred_id
    )

    # Run only the TTS stage, requesting the output format with a dict
    # instead of a plain format string.
    tts_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=preferred_pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=events.append,
        tts_audio_output={
            tts.ATTR_PREFERRED_FORMAT: "flac",
            tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        },
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        conversation_id=None,
        device_id=None,
        run=tts_run,
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    supported_options = [
        *(mock_tts_provider.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue

            # We must fetch the media URL to trigger the TTS
            assert event.data
            media_id = event.data["tts_output"]["media_id"]
            resolved = await media_source.async_resolve_media(hass, media_id, None)
            await client.get(resolved.url)

        assert mock_get_tts_audio.called
        received = mock_get_tts_audio.call_args_list[0].kwargs["options"]

        # We should have received preferred format options in get_tts_audio
        assert received.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
|
Loading…
Reference in New Issue