Add support for sample bytes in preferred TTS format

pull/125245/head
Michael Hansen 2024-09-04 12:42:41 -05:00 committed by GitHub
parent 892c32c8b7
commit 4ecc6555bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 112 additions and 8 deletions
homeassistant/components
tests/components/assist_pipeline

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterable
from typing import Any
import voluptuous as vol
@ -99,7 +100,7 @@ async def async_pipeline_from_audio_stream(
wake_word_phrase: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
tts_audio_output: str | dict[str, Any] | None = None,
wake_word_settings: WakeWordSettings | None = None,
audio_settings: AudioSettings | None = None,
device_id: str | None = None,

View File

@ -538,7 +538,7 @@ class PipelineRun:
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
intent_agent: str | None = None
tts_audio_output: str | None = None
tts_audio_output: str | dict[str, Any] | None = None
wake_word_settings: WakeWordSettings | None = None
audio_settings: AudioSettings = field(default_factory=AudioSettings)
@ -1052,12 +1052,15 @@ class PipelineRun:
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
if self.tts_audio_output is not None:
if isinstance(self.tts_audio_output, dict):
tts_options.update(self.tts_audio_output)
elif isinstance(self.tts_audio_output, str):
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 kHz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS
tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH
try:
options_supported = await tts.async_support_options(

View File

@ -77,6 +77,7 @@ __all__ = [
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"ATTR_PREFERRED_SAMPLE_BYTES",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
@ -95,6 +96,7 @@ ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_PREFERRED_SAMPLE_BYTES = "preferred_sample_bytes"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"
@ -103,6 +105,7 @@ _PREFFERED_FORMAT_OPTIONS: Final[set[str]] = {
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
ATTR_PREFERRED_SAMPLE_BYTES,
}
CONF_LANG = "language"
@ -223,6 +226,7 @@ async def async_convert_audio(
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
@ -234,6 +238,7 @@ async def async_convert_audio(
to_extension,
to_sample_rate=to_sample_rate,
to_sample_channels=to_sample_channels,
to_sample_bytes=to_sample_bytes,
)
)
@ -245,6 +250,7 @@ def _convert_audio(
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
@ -277,6 +283,10 @@ def _convert_audio(
# Max quality for MP3
command.extend(["-q:a", "0"])
if to_sample_bytes == 2:
# 16-bit samples
command.extend(["-sample_fmt", "s16"])
command.append(output_file.name)
with subprocess.Popen(
@ -738,11 +748,25 @@ class SpeechManager:
else:
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
if sample_rate is not None:
sample_rate = int(sample_rate)
if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
else:
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
if sample_channels is not None:
sample_channels = int(sample_channels)
if ATTR_PREFERRED_SAMPLE_BYTES in supported_options:
sample_bytes = options.get(ATTR_PREFERRED_SAMPLE_BYTES)
else:
sample_bytes = options.pop(ATTR_PREFERRED_SAMPLE_BYTES, None)
if sample_bytes is not None:
sample_bytes = int(sample_bytes)
async def get_tts_data() -> str:
"""Handle data available."""
if engine_instance.name is None or engine_instance.name is UNDEFINED:
@ -769,6 +793,7 @@ class SpeechManager:
(final_extension != extension)
or (sample_rate is not None)
or (sample_channels is not None)
or (sample_bytes is not None)
)
if needs_conversion:
@ -779,6 +804,7 @@ class SpeechManager:
to_extension=final_extension,
to_sample_rate=sample_rate,
to_sample_channels=sample_channels,
to_sample_bytes=sample_bytes,
)
# Create file infos

View File

@ -788,13 +788,12 @@ async def test_tts_audio_output(
assert len(extra_options) == 0, extra_options
async def test_tts_supports_preferred_format(
async def test_tts_wav_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
client = await hass_client()
@ -829,6 +828,7 @@ async def test_tts_supports_preferred_format(
tts.ATTR_PREFERRED_FORMAT,
tts.ATTR_PREFERRED_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
tts.ATTR_PREFERRED_SAMPLE_BYTES,
]
)
@ -850,6 +850,80 @@ async def test_tts_supports_preferred_format(
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
# We should have received preferred format options in get_tts_audio
assert tts.ATTR_PREFERRED_FORMAT in options
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a dict tts_audio_output reaches the TTS provider as options.

    A TTS-only pipeline run is given a dict of preferred format settings
    (flac, 48000 Hz, 2 channels, 2 sample bytes). The mock provider is patched
    to advertise support for all four preferred-format options, and the
    options passed to its get_tts_audio call are then checked.
    """
    http_client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    captured_events: list[assist_pipeline.PipelineEvent] = []

    # Look up the preferred pipeline and run only its TTS stage.
    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    preferred_pipeline = assist_pipeline.pipeline.async_get_pipeline(
        hass, preferred_id
    )
    tts_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=preferred_pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=captured_events.append,
        tts_audio_output={
            tts.ATTR_PREFERRED_FORMAT: "flac",
            tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        },
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        conversation_id=None,
        device_id=None,
        run=tts_only_run,
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    options_with_preferred = [
        *(mock_tts_provider.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]

    with (
        patch.object(mock_tts_provider, "_supported_options", options_with_preferred),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for pipeline_event in captured_events:
            if pipeline_event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue

            # We must fetch the media URL to trigger the TTS
            assert pipeline_event.data
            tts_media_id = pipeline_event.data["tts_output"]["media_id"]
            resolved_media = await media_source.async_resolve_media(
                hass, tts_media_id, None
            )
            await http_client.get(resolved_media.url)

        assert mock_get_tts_audio.called
        first_call = mock_get_tts_audio.call_args_list[0]
        received = first_call.kwargs["options"]

        # We should have received preferred format options in get_tts_audio
        assert received.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
        assert int(received.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2