From ae516ffbb570aeca27c75aacc0f6dcd72d98efc7 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 6 Nov 2023 14:26:00 -0600 Subject: [PATCH] Automatically convert TTS audio to MP3 on demand (#102814) * Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio * Move conversion into SpeechManager * Handle None case for expected_extension * Only use ATTR_AUDIO_OUTPUT * Prefer MP3 in pipelines * Automatically convert to mp3 on demand * Add preferred audio format * Break out preferred format * Add ATTR_BLOCKING to allow async fetching * Make a copy of supported options * Fix MaryTTS tests * Update ESPHome to use "wav" instead of "raw" * Clean up tests, remove blocking * Clean up rest of TTS tests * Fix ESPHome tests * More test coverage --- .../components/assist_pipeline/pipeline.py | 8 +- homeassistant/components/cloud/tts.py | 2 +- .../components/esphome/voice_assistant.py | 31 +++- homeassistant/components/tts/__init__.py | 148 +++++++++++++++--- homeassistant/components/tts/manifest.json | 2 +- homeassistant/components/wyoming/tts.py | 35 +---- homeassistant/package_constraints.txt | 1 + tests/components/assist_pipeline/test_init.py | 41 ++++- .../esphome/test_voice_assistant.py | 70 ++++++++- tests/components/google_translate/test_tts.py | 60 ++++--- tests/components/marytts/test_tts.py | 61 +++++--- tests/components/microsoft/test_tts.py | 78 ++++++--- tests/components/tts/common.py | 16 ++ tests/components/tts/test_init.py | 53 ++++--- tests/components/tts/test_media_source.py | 40 +++-- tests/components/voicerss/test_tts.py | 72 +++++---- .../wyoming/snapshots/test_tts.ambr | 33 ++++ tests/components/wyoming/test_tts.py | 103 ++++++++---- tests/components/yandextts/test_tts.py | 110 +++++++++---- 19 files changed, 723 insertions(+), 241 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 1e1c0b6f495..c6d0f6c5435 100644 --- 
a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -971,12 +971,16 @@ class PipelineRun: # pipeline.tts_engine can't be None or this function is not called engine = cast(str, self.pipeline.tts_engine) - tts_options = {} + tts_options: dict[str, Any] = {} if self.pipeline.tts_voice is not None: tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice if self.tts_audio_output is not None: - tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output + tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output + if self.tts_audio_output == "wav": + # 16 kHz, 16-bit mono + tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000 + tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1 try: options_supported = await tts.async_support_options( diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 88f24d1290f..f8152243bf5 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -150,4 +150,4 @@ class CloudProvider(Provider): _LOGGER.error("Voice error: %s", err) return (None, None) - return (str(options[ATTR_AUDIO_OUTPUT]), data) + return (str(options[ATTR_AUDIO_OUTPUT].value), data) diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index 26c0780d735..bb62d495076 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -3,9 +3,11 @@ from __future__ import annotations import asyncio from collections.abc import AsyncIterable, Callable +import io import logging import socket from typing import cast +import wave from aioesphomeapi import ( VoiceAssistantAudioSettings, @@ -88,6 +90,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): self.handle_event = handle_event self.handle_finished = handle_finished self._tts_done = asyncio.Event() + self._tts_task: asyncio.Task | None = None async def 
start_server(self) -> int: """Start accepting connections.""" @@ -189,7 +192,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): if self.device_info.voice_assistant_version >= 2: media_id = event.data["tts_output"]["media_id"] - self.hass.async_create_background_task( + self._tts_task = self.hass.async_create_background_task( self._send_tts(media_id), "esphome_voice_assistant_tts" ) else: @@ -228,7 +231,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): audio_settings = VoiceAssistantAudioSettings() tts_audio_output = ( - "raw" if self.device_info.voice_assistant_version >= 2 else "mp3" + "wav" if self.device_info.voice_assistant_version >= 2 else "mp3" ) _LOGGER.debug("Starting pipeline") @@ -302,11 +305,32 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {} ) - _extension, audio_bytes = await tts.async_get_media_source_audio( + extension, data = await tts.async_get_media_source_audio( self.hass, media_id, ) + if extension != "wav": + raise ValueError(f"Only WAV audio can be streamed, got {extension}") + + with io.BytesIO(data) as wav_io: + with wave.open(wav_io, "rb") as wav_file: + sample_rate = wav_file.getframerate() + sample_width = wav_file.getsampwidth() + sample_channels = wav_file.getnchannels() + + if ( + (sample_rate != 16000) + or (sample_width != 2) + or (sample_channels != 1) + ): + raise ValueError( + "Expected rate/width/channels as 16000/2/1," + f" got {sample_rate}/{sample_width}/{sample_channels}" + ) + + audio_bytes = wav_file.readframes(wav_file.getnframes()) + _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes)) bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 @@ -330,4 +354,5 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): self.handle_event( VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} ) + self._tts_task = None self._tts_done.set() diff --git a/homeassistant/components/tts/__init__.py 
b/homeassistant/components/tts/__init__.py index 4402722e37f..f84c819e739 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -13,6 +13,8 @@ import logging import mimetypes import os import re +import subprocess +import tempfile from typing import Any, TypedDict, final from aiohttp import web @@ -20,7 +22,7 @@ import mutagen from mutagen.id3 import ID3, TextFrame as ID3Text import voluptuous as vol -from homeassistant.components import websocket_api +from homeassistant.components import ffmpeg, websocket_api from homeassistant.components.http import HomeAssistantView from homeassistant.components.media_player import ( ATTR_MEDIA_ANNOUNCE, @@ -72,11 +74,15 @@ __all__ = [ "async_get_media_source_audio", "async_support_options", "ATTR_AUDIO_OUTPUT", + "ATTR_PREFERRED_FORMAT", + "ATTR_PREFERRED_SAMPLE_RATE", + "ATTR_PREFERRED_SAMPLE_CHANNELS", "CONF_LANG", "DEFAULT_CACHE_DIR", "generate_media_source_id", "PLATFORM_SCHEMA_BASE", "PLATFORM_SCHEMA", + "SampleFormat", "Provider", "TtsAudioType", "Voice", @@ -86,6 +92,9 @@ _LOGGER = logging.getLogger(__name__) ATTR_PLATFORM = "platform" ATTR_AUDIO_OUTPUT = "audio_output" +ATTR_PREFERRED_FORMAT = "preferred_format" +ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate" +ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels" ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id" ATTR_VOICE = "voice" @@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]: return languages +async def async_convert_audio( + hass: HomeAssistant, + from_extension: str, + audio_bytes: bytes, + to_extension: str, + to_sample_rate: int | None = None, + to_sample_channels: int | None = None, +) -> bytes: + """Convert audio to a preferred format using ffmpeg.""" + ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass) + return await hass.async_add_executor_job( + lambda: _convert_audio( + ffmpeg_manager.binary, + from_extension, + audio_bytes, + to_extension, + 
to_sample_rate=to_sample_rate, + to_sample_channels=to_sample_channels, + ) + ) + + +def _convert_audio( + ffmpeg_binary: str, + from_extension: str, + audio_bytes: bytes, + to_extension: str, + to_sample_rate: int | None = None, + to_sample_channels: int | None = None, +) -> bytes: + """Convert audio to a preferred format using ffmpeg.""" + + # We have to use a temporary file here because some formats like WAV store + # the length of the file in the header, and therefore cannot be written in a + # streaming fashion. + with tempfile.NamedTemporaryFile( + mode="wb+", suffix=f".{to_extension}" + ) as output_file: + # input + command = [ + ffmpeg_binary, + "-y", # overwrite temp file + "-f", + from_extension, + "-i", + "pipe:", # input from stdin + ] + + # output + command.extend(["-f", to_extension]) + + if to_sample_rate is not None: + command.extend(["-ar", str(to_sample_rate)]) + + if to_sample_channels is not None: + command.extend(["-ac", str(to_sample_channels)]) + + if to_extension == "mp3": + # Max quality for MP3 + command.extend(["-q:a", "0"]) + + command.append(output_file.name) + + with subprocess.Popen( + command, stdin=subprocess.PIPE, stderr=subprocess.PIPE + ) as proc: + _stdout, stderr = proc.communicate(input=audio_bytes) + if proc.returncode != 0: + _LOGGER.error(stderr.decode()) + raise RuntimeError( + f"Unexpected error while running ffmpeg with arguments: {command}. See log for details." 
+ ) + + output_file.seek(0) + return output_file.read() + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up TTS.""" websocket_api.async_register_command(hass, websocket_list_engines) @@ -482,7 +568,18 @@ class SpeechManager: merged_options = dict(engine_instance.default_options or {}) merged_options.update(options or {}) - supported_options = engine_instance.supported_options or [] + supported_options = list(engine_instance.supported_options or []) + + # ATTR_PREFERRED_* options are always "supported" since they're used to + # convert audio after the TTS has run (if necessary). + supported_options.extend( + ( + ATTR_PREFERRED_FORMAT, + ATTR_PREFERRED_SAMPLE_RATE, + ATTR_PREFERRED_SAMPLE_CHANNELS, + ) + ) + invalid_opts = [ opt_name for opt_name in merged_options if opt_name not in supported_options ] @@ -520,12 +617,7 @@ class SpeechManager: # Load speech from engine into memory else: filename = await self._async_get_tts_audio( - engine_instance, - cache_key, - message, - use_cache, - language, - options, + engine_instance, cache_key, message, use_cache, language, options ) return f"/api/tts_proxy/{filename}" @@ -590,10 +682,10 @@ class SpeechManager: This method is a coroutine. """ - if options is not None and ATTR_AUDIO_OUTPUT in options: - expected_extension = options[ATTR_AUDIO_OUTPUT] - else: - expected_extension = None + options = options or {} + + # Default to MP3 unless a different format is preferred + final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3") async def get_tts_data() -> str: """Handle data available.""" @@ -614,8 +706,27 @@ class SpeechManager: f"No TTS from {engine_instance.name} for '{message}'" ) + # Only convert if we have a preferred format different than the + # expected format from the TTS system, or if a specific sample + # rate/format/channel count is requested. 
+ needs_conversion = ( + (final_extension != extension) + or (ATTR_PREFERRED_SAMPLE_RATE in options) + or (ATTR_PREFERRED_SAMPLE_CHANNELS in options) + ) + + if needs_conversion: + data = await async_convert_audio( + self.hass, + extension, + data, + to_extension=final_extension, + to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE), + to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS), + ) + # Create file infos - filename = f"{cache_key}.{extension}".lower() + filename = f"{cache_key}.{final_extension}".lower() # Validate filename if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match( @@ -626,10 +737,11 @@ class SpeechManager: ) # Save to memory - if extension == "mp3": + if final_extension == "mp3": data = self.write_tags( filename, data, engine_instance.name, message, language, options ) + self._async_store_to_memcache(cache_key, filename, data) if cache: @@ -641,9 +753,6 @@ class SpeechManager: audio_task = self.hass.async_create_task(get_tts_data()) - if expected_extension is None: - return await audio_task - def handle_error(_future: asyncio.Future) -> None: """Handle error.""" if audio_task.exception(): @@ -651,7 +760,7 @@ class SpeechManager: audio_task.add_done_callback(handle_error) - filename = f"{cache_key}.{expected_extension}".lower() + filename = f"{cache_key}.{final_extension}".lower() self.mem_cache[cache_key] = { "filename": filename, "voice": b"", @@ -747,11 +856,12 @@ class SpeechManager: raise HomeAssistantError(f"{cache_key} not in cache!") await self._async_file_to_mem(cache_key) - content, _ = mimetypes.guess_type(filename) cached = self.mem_cache[cache_key] if pending := cached.get("pending"): await pending cached = self.mem_cache[cache_key] + + content, _ = mimetypes.guess_type(filename) return content, cached["voice"] @staticmethod diff --git a/homeassistant/components/tts/manifest.json b/homeassistant/components/tts/manifest.json index f1120ed2750..338a8c35003 100644 --- 
a/homeassistant/components/tts/manifest.json +++ b/homeassistant/components/tts/manifest.json @@ -3,7 +3,7 @@ "name": "Text-to-speech (TTS)", "after_dependencies": ["media_player"], "codeowners": ["@home-assistant/core", "@pvizeli"], - "dependencies": ["http"], + "dependencies": ["http", "ffmpeg"], "documentation": "https://www.home-assistant.io/integrations/tts", "integration_type": "entity", "loggers": ["mutagen"], diff --git a/homeassistant/components/wyoming/tts.py b/homeassistant/components/wyoming/tts.py index 6510fd8c761..cde771cd330 100644 --- a/homeassistant/components/wyoming/tts.py +++ b/homeassistant/components/wyoming/tts.py @@ -4,7 +4,7 @@ import io import logging import wave -from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop +from wyoming.audio import AudioChunk, AudioStop from wyoming.client import AsyncTcpClient from wyoming.tts import Synthesize, SynthesizeVoice @@ -88,12 +88,16 @@ class WyomingTtsProvider(tts.TextToSpeechEntity): @property def supported_options(self): """Return list of supported options like voice, emotion.""" - return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE, ATTR_SPEAKER] + return [ + tts.ATTR_AUDIO_OUTPUT, + tts.ATTR_VOICE, + ATTR_SPEAKER, + ] @property def default_options(self): """Return a dict include default options.""" - return {tts.ATTR_AUDIO_OUTPUT: "wav"} + return {} @callback def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: @@ -143,27 +147,4 @@ class WyomingTtsProvider(tts.TextToSpeechEntity): except (OSError, WyomingError): return (None, None) - if options[tts.ATTR_AUDIO_OUTPUT] == "wav": - return ("wav", data) - - # Raw output (convert to 16Khz, 16-bit mono) - with io.BytesIO(data) as wav_io: - wav_reader: wave.Wave_read = wave.open(wav_io, "rb") - raw_data = ( - AudioChunkConverter( - rate=16000, - width=2, - channels=1, - ) - .convert( - AudioChunk( - audio=wav_reader.readframes(wav_reader.getnframes()), - rate=wav_reader.getframerate(), - 
width=wav_reader.getsampwidth(), - channels=wav_reader.getnchannels(), - ) - ) - .audio - ) - - return ("raw", raw_data) + return ("wav", data) diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 2c38dc8f153..fac2abb7df1 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -20,6 +20,7 @@ cryptography==41.0.4 dbus-fast==2.12.0 fnv-hash-fast==0.5.0 ha-av==10.1.1 +ha-ffmpeg==3.1.0 hass-nabucasa==0.74.0 hassil==1.2.5 home-assistant-bluetooth==1.10.4 diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index a98858a1bce..24a4a92536d 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -9,7 +9,7 @@ import wave import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components import assist_pipeline, stt +from homeassistant.components import assist_pipeline, stt, tts from homeassistant.components.assist_pipeline.const import ( CONF_DEBUG_RECORDING_DIR, DOMAIN, @@ -660,3 +660,42 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: assert run_1 == run_1 assert run_1 != run_2 assert run_1 != 1234 + + +async def test_tts_audio_output( + hass: HomeAssistant, + mock_stt_provider: MockSttProvider, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, + snapshot: SnapshotAssertion, +) -> None: + """Test using tts_audio_output with wav sets options correctly.""" + + def event_callback(event): + pass + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + tts_input="This is a test.", + conversation_id=None, + device_id=None, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + 
start_stage=assist_pipeline.PipelineStage.TTS, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=event_callback, + tts_audio_output="wav", + ), + ) + await pipeline_input.validate() + + # Verify TTS audio settings + assert pipeline_input.run.tts_options is not None + assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" + assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000 + assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1 diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index 9b6bcf1c6c7..ca74c99f0cd 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -1,8 +1,10 @@ """Test ESPHome voice assistant server.""" import asyncio +import io import socket from unittest.mock import Mock, patch +import wave from aioesphomeapi import VoiceAssistantEventType import pytest @@ -340,9 +342,18 @@ async def test_send_tts( voice_assistant_udp_server_v2: VoiceAssistantUDPServer, ) -> None: """Test the UDP server calls sendto to transmit audio data to device.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(_ONE_SECOND)) + + wav_bytes = wav_io.getvalue() + with patch( "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("raw", bytes(1024)), + return_value=("wav", wav_bytes), ): voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) @@ -360,6 +371,63 @@ async def test_send_tts( voice_assistant_udp_server_v2.transport.sendto.assert_called() +async def test_send_tts_wrong_sample_rate( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server calls sendto to transmit audio data to 
device.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(22050) # should be 16000 + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(_ONE_SECOND)) + + wav_bytes = wav_io.getvalue() + + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("wav", wav_bytes), + ), pytest.raises(ValueError): + voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + assert voice_assistant_udp_server_v2._tts_task is not None + await voice_assistant_udp_server_v2._tts_task # raises ValueError + + +async def test_send_tts_wrong_format( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test that only WAV audio will be streamed.""" + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("raw", bytes(1024)), + ), pytest.raises(ValueError): + voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + assert voice_assistant_udp_server_v2._tts_task is not None + await voice_assistant_udp_server_v2._tts_task # raises ValueError + + async def test_wake_word( hass: HomeAssistant, voice_assistant_udp_server_v2: VoiceAssistantUDPServer, diff --git a/tests/components/google_translate/test_tts.py b/tests/components/google_translate/test_tts.py index d6669ee3c5f..fd1ddd8a4f2 100644 --- a/tests/components/google_translate/test_tts.py +++ b/tests/components/google_translate/test_tts.py @@ -2,13 +2,14 @@ from __future__ 
import annotations from collections.abc import Generator +from http import HTTPStatus from typing import Any from unittest.mock import MagicMock, patch from gtts import gTTSError import pytest -from homeassistant.components import media_source, tts +from homeassistant.components import tts from homeassistant.components.google_translate.const import CONF_TLD, DOMAIN from homeassistant.components.media_player import ( ATTR_MEDIA_CONTENT_ID, @@ -18,10 +19,11 @@ from homeassistant.components.media_player import ( from homeassistant.config import async_process_ha_core_config from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM from homeassistant.core import HomeAssistant, ServiceCall -from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry, async_mock_service +from tests.components.tts.common import retrieve_media +from tests.typing import ClientSessionGenerator @pytest.fixture(autouse=True) @@ -35,15 +37,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir): return mock_tts_cache_dir -async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str: - """Get the media source url.""" - if media_source.DOMAIN not in hass.config.components: - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - resolved = await media_source.async_resolve_media(hass, media_content_id, None) - return resolved.url - - @pytest.fixture async def calls(hass: HomeAssistant) -> list[ServiceCall]: """Mock media player calls.""" @@ -128,6 +121,7 @@ async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) - async def test_tts_service( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -142,9 +136,11 @@ async def test_tts_service( ) assert len(calls) == 1 - url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await 
retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 - assert url.endswith(".mp3") assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -180,6 +176,7 @@ async def test_tts_service( async def test_service_say_german_config( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -194,7 +191,10 @@ async def test_service_say_german_config( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -231,6 +231,7 @@ async def test_service_say_german_config( async def test_service_say_german_service( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -245,7 +246,10 @@ async def test_service_say_german_service( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -281,6 +285,7 @@ async def test_service_say_german_service( async def test_service_say_en_uk_config( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -295,7 +300,10 @@ async def test_service_say_en_uk_config( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, 
calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -332,6 +340,7 @@ async def test_service_say_en_uk_config( async def test_service_say_en_uk_service( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -346,7 +355,10 @@ async def test_service_say_en_uk_service( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -383,6 +395,7 @@ async def test_service_say_en_uk_service( async def test_service_say_en_couk( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -397,9 +410,11 @@ async def test_service_say_en_couk( ) assert len(calls) == 1 - url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(mock_gtts.mock_calls) == 2 - assert url.endswith(".mp3") assert mock_gtts.mock_calls[0][2] == { "text": "There is a person at the front door.", @@ -434,6 +449,7 @@ async def test_service_say_en_couk( async def test_service_say_error( hass: HomeAssistant, mock_gtts: MagicMock, + hass_client: ClientSessionGenerator, calls: list[ServiceCall], setup: str, tts_service: str, @@ -450,6 +466,8 @@ async def test_service_say_error( ) assert len(calls) == 1 - with pytest.raises(HomeAssistantError): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, 
calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) assert len(mock_gtts.mock_calls) == 2 diff --git a/tests/components/marytts/test_tts.py b/tests/components/marytts/test_tts.py index 4282b86ec2e..474d2f19faf 100644 --- a/tests/components/marytts/test_tts.py +++ b/tests/components/marytts/test_tts.py @@ -1,9 +1,12 @@ """The tests for the MaryTTS speech platform.""" +from http import HTTPStatus +import io from unittest.mock import patch +import wave import pytest -from homeassistant.components import media_source, tts +from homeassistant.components import tts from homeassistant.components.media_player import ( ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP, @@ -13,15 +16,19 @@ from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import assert_setup_component, async_mock_service +from tests.components.tts.common import retrieve_media +from tests.typing import ClientSessionGenerator -async def get_media_source_url(hass, media_content_id): - """Get the media source url.""" - if media_source.DOMAIN not in hass.config.components: - assert await async_setup_component(hass, media_source.DOMAIN, {}) +def get_empty_wav() -> bytes: + """Get bytes for empty WAV file.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(22050) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) - resolved = await media_source.async_resolve_media(hass, media_content_id, None) - return resolved.url + return wav_io.getvalue() @pytest.fixture(autouse=True) @@ -39,7 +46,9 @@ async def test_setup_component(hass: HomeAssistant) -> None: await hass.async_block_till_done() -async def test_service_say(hass: HomeAssistant) -> None: +async def test_service_say( + hass: HomeAssistant, hass_client: ClientSessionGenerator +) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -51,7 +60,7 @@ async def test_service_say(hass: 
HomeAssistant) -> None: with patch( "homeassistant.components.marytts.tts.MaryTTS.speak", - return_value=b"audio", + return_value=get_empty_wav(), ) as mock_speak: await hass.services.async_call( tts.DOMAIN, @@ -63,16 +72,22 @@ async def test_service_say(hass: HomeAssistant) -> None: blocking=True, ) - url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media( + hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID] + ) + == HTTPStatus.OK + ) mock_speak.assert_called_once() mock_speak.assert_called_with("HomeAssistant", {}) assert len(calls) == 1 - assert url.endswith(".wav") -async def test_service_say_with_effect(hass: HomeAssistant) -> None: +async def test_service_say_with_effect( + hass: HomeAssistant, hass_client: ClientSessionGenerator +) -> None: """Test service call say with effects.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -84,7 +99,7 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None: with patch( "homeassistant.components.marytts.tts.MaryTTS.speak", - return_value=b"audio", + return_value=get_empty_wav(), ) as mock_speak: await hass.services.async_call( tts.DOMAIN, @@ -96,16 +111,22 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None: blocking=True, ) - url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media( + hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID] + ) + == HTTPStatus.OK + ) mock_speak.assert_called_once() mock_speak.assert_called_with("HomeAssistant", {"Volume": "amount:2.0;"}) assert len(calls) == 1 - assert url.endswith(".wav") -async def test_service_say_http_error(hass: HomeAssistant) -> None: +async def test_service_say_http_error( + hass: HomeAssistant, hass_client: ClientSessionGenerator +) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -129,7 +150,11 @@ async def test_service_say_http_error(hass: 
HomeAssistant) -> None: ) await hass.async_block_till_done() - with pytest.raises(Exception): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media( + hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID] + ) + == HTTPStatus.NOT_FOUND + ) mock_speak.assert_called_once() diff --git a/tests/components/microsoft/test_tts.py b/tests/components/microsoft/test_tts.py index 9684d1aa7d5..bc6a3ac7dd7 100644 --- a/tests/components/microsoft/test_tts.py +++ b/tests/components/microsoft/test_tts.py @@ -1,10 +1,11 @@ """Tests for Microsoft text-to-speech.""" +from http import HTTPStatus from unittest.mock import patch from pycsspeechtts import pycsspeechtts import pytest -from homeassistant.components import media_source, tts +from homeassistant.components import tts from homeassistant.components.media_player import ( ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP, @@ -13,19 +14,12 @@ from homeassistant.components.media_player import ( from homeassistant.components.microsoft.tts import SUPPORTED_LANGUAGES from homeassistant.config import async_process_ha_core_config from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError, ServiceNotFound +from homeassistant.exceptions import ServiceNotFound from homeassistant.setup import async_setup_component from tests.common import async_mock_service - - -async def get_media_source_url(hass: HomeAssistant, media_content_id): - """Get the media source url.""" - if media_source.DOMAIN not in hass.config.components: - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - resolved = await media_source.async_resolve_media(hass, media_content_id, None) - return resolved.url +from tests.components.tts.common import retrieve_media +from tests.typing import ClientSessionGenerator @pytest.fixture(autouse=True) @@ -58,7 +52,9 @@ def mock_tts(): yield mock_tts -async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None: +async 
def test_service_say( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say.""" await async_setup_component( @@ -76,9 +72,12 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None: ) assert len(calls) == 1 - url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(mock_tts.mock_calls) == 2 - assert url.endswith(".mp3") assert mock_tts.mock_calls[1][2] == { "language": "en-us", @@ -93,7 +92,9 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None: } -async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> None: +async def test_service_say_en_gb_config( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say with en-gb code in the config.""" await async_setup_component( @@ -120,7 +121,11 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(mock_tts.mock_calls) == 2 assert mock_tts.mock_calls[1][2] == { "language": "en-gb", @@ -135,7 +140,9 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> } -async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -> None: +async def test_service_say_en_gb_service( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say with en-gb code in the service.""" await async_setup_component( @@ -157,7 +164,11 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) - ) assert len(calls) == 1 - await get_media_source_url(hass, 
calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(mock_tts.mock_calls) == 2 assert mock_tts.mock_calls[1][2] == { "language": "en-gb", @@ -172,7 +183,9 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) - } -async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> None: +async def test_service_say_fa_ir_config( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say with fa-ir code in the config.""" await async_setup_component( @@ -199,7 +212,11 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(mock_tts.mock_calls) == 2 assert mock_tts.mock_calls[1][2] == { "language": "fa-ir", @@ -214,7 +231,9 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> } -async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -> None: +async def test_service_say_fa_ir_service( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say with fa-ir code in the service.""" config = { @@ -240,7 +259,11 @@ async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) - ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(mock_tts.mock_calls) == 2 assert mock_tts.mock_calls[1][2] == { "language": "fa-ir", @@ -295,7 +318,9 @@ async def test_invalid_language(hass: HomeAssistant, mock_tts, calls) -> None: assert 
len(mock_tts.mock_calls) == 0 -async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None: +async def test_service_say_error( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls +) -> None: """Test service call say with http error.""" mock_tts.return_value.speak.side_effect = pycsspeechtts.requests.HTTPError await async_setup_component( @@ -313,6 +338,9 @@ async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None: ) assert len(calls) == 1 - with pytest.raises(HomeAssistantError): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) + assert len(mock_tts.mock_calls) == 2 diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index a9a95eae2f4..0c3642df6fe 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Generator +from http import HTTPStatus from typing import Any from unittest.mock import MagicMock, patch @@ -32,6 +33,7 @@ from tests.common import ( mock_integration, mock_platform, ) +from tests.typing import ClientSessionGenerator DEFAULT_LANG = "en_US" SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] @@ -103,6 +105,20 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st return resolved.url +async def retrieve_media( + hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str +) -> HTTPStatus: + """Get the media source url.""" + url = await get_media_source_url(hass, media_content_id) + + # Ensure media has been generated by requesting it + await hass.async_block_till_done() + client = await hass_client() + req = await client.get(url) + + return req.status + + class BaseProvider: """Test speech API provider.""" diff --git a/tests/components/tts/test_init.py 
b/tests/components/tts/test_init.py index 2656beba236..71be6b3bb11 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch import pytest -from homeassistant.components import tts +from homeassistant.components import ffmpeg, tts from homeassistant.components.media_player import ( ATTR_MEDIA_ANNOUNCE, ATTR_MEDIA_CONTENT_ID, @@ -15,7 +15,6 @@ from homeassistant.components.media_player import ( SERVICE_PLAY_MEDIA, MediaType, ) -from homeassistant.components.media_source import Unresolvable from homeassistant.config_entries import ConfigEntryState from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN from homeassistant.core import HomeAssistant, State @@ -33,6 +32,7 @@ from .common import ( get_media_source_url, mock_config_entry_setup, mock_setup, + retrieve_media, ) from tests.common import async_mock_service, mock_restore_cache @@ -75,7 +75,9 @@ async def test_default_entity_attributes() -> None: async def test_config_entry_unload( - hass: HomeAssistant, mock_tts_entity: MockTTSEntity + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_tts_entity: MockTTSEntity, ) -> None: """Test we can unload config entry.""" entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}" @@ -104,7 +106,12 @@ async def test_config_entry_unload( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media( + hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID] + ) + == HTTPStatus.OK + ) await hass.async_block_till_done() state = hass.states.get(entity_id) @@ -1159,6 +1166,7 @@ class MockEntityEmpty(MockTTSEntity): ) async def test_service_get_tts_error( hass: HomeAssistant, + hass_client: ClientSessionGenerator, setup: str, tts_service: str, service_data: dict[str, Any], @@ -1173,8 +1181,10 @@ async def test_service_get_tts_error( blocking=True, ) assert len(calls) == 1 - with pytest.raises(Unresolvable): - await 
get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) async def test_load_cache_legacy_retrieve_without_mem_cache( @@ -1454,7 +1464,11 @@ async def test_legacy_fetching_in_async( # Test async_get_media_source_audio media_source_id = tts.generate_media_source_id( - hass, "test message", "test", "en_US", None, None + hass, + "test message", + "test", + "en_US", + cache=None, ) task = hass.async_create_task( @@ -1508,16 +1522,6 @@ async def test_fetching_in_async( class EntityWithAsyncFetching(MockTTSEntity): """Entity that supports audio output option.""" - @property - def supported_options(self) -> list[str]: - """Return list of supported options like voice, emotions.""" - return [tts.ATTR_AUDIO_OUTPUT] - - @property - def default_options(self) -> dict[str, str]: - """Return a dict including the default options.""" - return {tts.ATTR_AUDIO_OUTPUT: "mp3"} - async def async_get_tts_audio( self, message: str, language: str, options: dict[str, Any] ) -> tts.TtsAudioType: @@ -1527,7 +1531,11 @@ async def test_fetching_in_async( # Test async_get_media_source_audio media_source_id = tts.generate_media_source_id( - hass, "test message", "tts.test", "en_US", None, None + hass, + "test message", + "tts.test", + "en_US", + cache=None, ) task = hass.async_create_task( @@ -1751,3 +1759,12 @@ async def test_ws_list_voices( {"voice_id": "fran_drescher", "name": "Fran Drescher"}, ] } + + +async def test_async_convert_audio_error(hass: HomeAssistant) -> None: + """Test that ffmpeg failing during audio conversion will raise an error.""" + assert await async_setup_component(hass, ffmpeg.DOMAIN, {}) + + with pytest.raises(RuntimeError): + # Simulate a bad WAV file + await tts.async_convert_audio(hass, "wav", bytes(0), "mp3") diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index 86f1a3bcf3e..641c02064ec 
100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -1,4 +1,5 @@ """Tests for TTS media source.""" +from http import HTTPStatus from unittest.mock import MagicMock import pytest @@ -14,8 +15,11 @@ from .common import ( MockTTSEntity, mock_config_entry_setup, mock_setup, + retrieve_media, ) +from tests.typing import ClientSessionGenerator + class MSEntity(MockTTSEntity): """Test speech API entity.""" @@ -88,16 +92,18 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None: @pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)]) -async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None: +async def test_legacy_resolving( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider +) -> None: """Test resolving legacy provider.""" await mock_setup(hass, mock_provider) mock_get_tts_audio = mock_provider.get_tts_audio - media = await media_source.async_resolve_media( - hass, "media-source://tts/test?message=Hello%20World", None - ) + media_id = "media-source://tts/test?message=Hello%20World" + media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") assert media.mime_type == "audio/mpeg" + assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK assert len(mock_get_tts_audio.mock_calls) == 1 message, language = mock_get_tts_audio.mock_calls[0][1] @@ -107,13 +113,11 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) # Pass language and options mock_get_tts_audio.reset_mock() - media = await media_source.async_resolve_media( - hass, - "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus", - None, - ) + media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus" + media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") assert 
media.mime_type == "audio/mpeg" + assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK assert len(mock_get_tts_audio.mock_calls) == 1 message, language = mock_get_tts_audio.mock_calls[0][1] @@ -123,16 +127,18 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) @pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)]) -async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None: +async def test_resolving( + hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity +) -> None: """Test resolving entity.""" await mock_config_entry_setup(hass, mock_tts_entity) mock_get_tts_audio = mock_tts_entity.get_tts_audio - media = await media_source.async_resolve_media( - hass, "media-source://tts/tts.test?message=Hello%20World", None - ) + media_id = "media-source://tts/tts.test?message=Hello%20World" + media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") assert media.mime_type == "audio/mpeg" + assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK assert len(mock_get_tts_audio.mock_calls) == 1 message, language = mock_get_tts_audio.mock_calls[0][1] @@ -142,13 +148,13 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None # Pass language and options mock_get_tts_audio.reset_mock() - media = await media_source.async_resolve_media( - hass, - "media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus", - None, + media_id = ( + "media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus" ) + media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") assert media.mime_type == "audio/mpeg" + assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK assert len(mock_get_tts_audio.mock_calls) == 1 message, language = mock_get_tts_audio.mock_calls[0][1] diff --git 
a/tests/components/voicerss/test_tts.py b/tests/components/voicerss/test_tts.py index 57a5b298162..24997c9d459 100644 --- a/tests/components/voicerss/test_tts.py +++ b/tests/components/voicerss/test_tts.py @@ -4,18 +4,19 @@ from http import HTTPStatus import pytest -from homeassistant.components import media_source, tts +from homeassistant.components import tts from homeassistant.components.media_player import ( ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP, SERVICE_PLAY_MEDIA, ) from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component from tests.common import assert_setup_component, async_mock_service +from tests.components.tts.common import retrieve_media from tests.test_util.aiohttp import AiohttpClientMocker +from tests.typing import ClientSessionGenerator URL = "https://api.voicerss.org/" FORM_DATA = { @@ -38,15 +39,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir): return mock_tts_cache_dir -async def get_media_source_url(hass, media_content_id): - """Get the media source url.""" - if media_source.DOMAIN not in hass.config.components: - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - resolved = await media_source.async_resolve_media(hass, media_content_id, None) - return resolved.url - - async def test_setup_component(hass: HomeAssistant) -> None: """Test setup component.""" config = {tts.DOMAIN: {"platform": "voicerss", "api_key": "1234567xx"}} @@ -66,7 +58,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None: async def test_service_say( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -90,14 +84,18 @@ async def test_service_say( await hass.async_block_till_done() assert len(calls) == 1 - url = 
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - assert url.endswith(".mp3") + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == FORM_DATA async def test_service_say_german_config( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: """Test service call say with german code in the config.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -128,13 +126,18 @@ async def test_service_say_german_config( await hass.async_block_till_done() assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == form_data async def test_service_say_german_service( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: """Test service call say with german code in the service.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -160,13 +163,18 @@ async def test_service_say_german_service( await hass.async_block_till_done() assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == form_data async def test_service_say_error( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: 
"""Test service call say with http response 400.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -189,14 +197,18 @@ async def test_service_say_error( ) await hass.async_block_till_done() - with pytest.raises(HomeAssistantError): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == FORM_DATA async def test_service_say_timeout( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: """Test service call say with http timeout.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -219,14 +231,18 @@ async def test_service_say_timeout( ) await hass.async_block_till_done() - with pytest.raises(HomeAssistantError): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == FORM_DATA async def test_service_say_error_msg( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, ) -> None: """Test service call say with http error api message.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -254,7 +270,9 @@ async def test_service_say_error_msg( ) await hass.async_block_till_done() - with pytest.raises(media_source.Unresolvable): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) assert len(aioclient_mock.mock_calls) == 1 
assert aioclient_mock.mock_calls[0][2] == FORM_DATA diff --git a/tests/components/wyoming/snapshots/test_tts.ambr b/tests/components/wyoming/snapshots/test_tts.ambr index 1cb5a6cb874..299bddb07e5 100644 --- a/tests/components/wyoming/snapshots/test_tts.ambr +++ b/tests/components/wyoming/snapshots/test_tts.ambr @@ -10,6 +10,39 @@ }), ]) # --- +# name: test_get_tts_audio_different_formats + list([ + dict({ + 'data': dict({ + 'text': 'Hello world', + }), + 'payload': None, + 'type': 'synthesize', + }), + ]) +# --- +# name: test_get_tts_audio_different_formats.1 + list([ + dict({ + 'data': dict({ + 'text': 'Hello world', + }), + 'payload': None, + 'type': 'synthesize', + }), + ]) +# --- +# name: test_get_tts_audio_mp3 + list([ + dict({ + 'data': dict({ + 'text': 'Hello world', + }), + 'payload': None, + 'type': 'synthesize', + }), + ]) +# --- # name: test_get_tts_audio_raw list([ dict({ diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py index 51a684bc4fd..68b7b2b62bc 100644 --- a/tests/components/wyoming/test_tts.py +++ b/tests/components/wyoming/test_tts.py @@ -51,31 +51,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> AudioStop().event(), ] - with patch( - "homeassistant.components.wyoming.tts.AsyncTcpClient", - MockAsyncTcpClient(audio_events), - ) as mock_client: - extension, data = await tts.async_get_media_source_audio( - hass, - tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"), - ) - - assert extension == "wav" - assert data is not None - with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: - assert wav_file.getframerate() == 16000 - assert wav_file.getsampwidth() == 2 - assert wav_file.getnchannels() == 1 - assert wav_file.readframes(wav_file.getnframes()) == audio - - assert mock_client.written == snapshot - - -async def test_get_tts_audio_raw( - hass: HomeAssistant, init_wyoming_tts, snapshot -) -> None: - """Test get raw audio.""" - audio = 
bytes(100) + # Verify audio audio_events = [ AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioStop().event(), @@ -92,12 +68,83 @@ async def test_get_tts_audio_raw( "Hello world", "tts.test_tts", "en-US", - options={tts.ATTR_AUDIO_OUTPUT: "raw"}, + options={tts.ATTR_PREFERRED_FORMAT: "wav"}, ), ) - assert extension == "raw" - assert data == audio + assert extension == "wav" + assert data is not None + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + assert wav_file.getframerate() == 16000 + assert wav_file.getsampwidth() == 2 + assert wav_file.getnchannels() == 1 + assert wav_file.readframes(wav_file.getnframes()) == audio + + assert mock_client.written == snapshot + + +async def test_get_tts_audio_different_formats( + hass: HomeAssistant, init_wyoming_tts, snapshot +) -> None: + """Test changing preferred audio format.""" + audio = bytes(16000 * 2 * 1) # one second + audio_events = [ + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + ] + + # Request a different sample rate, etc. 
+ with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + extension, data = await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, + "Hello world", + "tts.test_tts", + "en-US", + options={ + tts.ATTR_PREFERRED_FORMAT: "wav", + tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, + }, + ), + ) + + assert extension == "wav" + assert data is not None + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + assert wav_file.getframerate() == 48000 + assert wav_file.getsampwidth() == 2 + assert wav_file.getnchannels() == 2 + assert wav_file.getnframes() == wav_file.getframerate() # one second + + assert mock_client.written == snapshot + + # MP3 is the default + audio_events = [ + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + ] + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + extension, data = await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, + "Hello world", + "tts.test_tts", + "en-US", + ), + ) + + assert extension == "mp3" + assert b"ID3" in data assert mock_client.written == snapshot diff --git a/tests/components/yandextts/test_tts.py b/tests/components/yandextts/test_tts.py index d04aef6b16b..a8052e45047 100644 --- a/tests/components/yandextts/test_tts.py +++ b/tests/components/yandextts/test_tts.py @@ -4,7 +4,7 @@ from http import HTTPStatus import pytest -from homeassistant.components import media_source, tts +from homeassistant.components import tts from homeassistant.components.media_player import ( ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP, @@ -14,7 +14,9 @@ from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import assert_setup_component, async_mock_service +from 
tests.components.tts.common import retrieve_media from tests.test_util.aiohttp import AiohttpClientMocker +from tests.typing import ClientSessionGenerator URL = "https://tts.voicetech.yandex.net/generate?" @@ -30,15 +32,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir): return mock_tts_cache_dir -async def get_media_source_url(hass, media_content_id): - """Get the media source url.""" - if media_source.DOMAIN not in hass.config.components: - assert await async_setup_component(hass, media_source.DOMAIN, {}) - - resolved = await media_source.async_resolve_media(hass, media_content_id, None) - return resolved.url - - async def test_setup_component(hass: HomeAssistant) -> None: """Test setup component.""" config = {tts.DOMAIN: {"platform": "yandextts", "api_key": "1234567xx"}} @@ -58,7 +51,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None: async def test_service_say( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -87,12 +82,18 @@ async def test_service_say( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_russian_config( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -128,12 +129,18 @@ async def test_service_say_russian_config( ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await 
retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_russian_service( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -166,12 +173,18 @@ async def test_service_say_russian_service( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_timeout( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -207,13 +220,18 @@ async def test_service_say_timeout( await hass.async_block_till_done() assert len(calls) == 1 - with pytest.raises(media_source.Unresolvable): - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) + assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_http_error( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -248,12 +266,16 @@ async def test_service_say_http_error( ) assert len(calls) == 1 - with pytest.raises(media_source.Unresolvable): - await get_media_source_url(hass, 
calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.NOT_FOUND + ) async def test_service_say_specified_speaker( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -288,12 +310,18 @@ async def test_service_say_specified_speaker( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) + assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_specified_emotion( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -328,13 +356,18 @@ async def test_service_say_specified_emotion( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_specified_low_speed( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -365,13 +398,18 @@ async def test_service_say_specified_low_speed( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, 
hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_specified_speed( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -400,13 +438,18 @@ async def test_service_say_specified_speed( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1 async def test_service_say_specified_options( - hass: HomeAssistant, aioclient_mock: AiohttpClientMocker + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_client: ClientSessionGenerator, ) -> None: """Test service call say with options.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -438,6 +481,9 @@ async def test_service_say_specified_options( blocking=True, ) assert len(calls) == 1 - await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + assert ( + await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) + == HTTPStatus.OK + ) assert len(aioclient_mock.mock_calls) == 1