Automatically convert TTS audio to MP3 on demand (#102814)
* Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio * Move conversion into SpeechManager * Handle None case for expected_extension * Only use ATTR_AUDIO_OUTPUT * Prefer MP3 in pipelines * Automatically convert to mp3 on demand * Add preferred audio format * Break out preferred format * Add ATTR_BLOCKING to allow async fetching * Make a copy of supported options * Fix MaryTTS tests * Update ESPHome to use "wav" instead of "raw" * Clean up tests, remove blocking * Clean up rest of TTS tests * Fix ESPHome tests * More test coveragepull/103516/head
parent
054089291f
commit
ae516ffbb5
|
@ -971,12 +971,16 @@ class PipelineRun:
|
|||
# pipeline.tts_engine can't be None or this function is not called
|
||||
engine = cast(str, self.pipeline.tts_engine)
|
||||
|
||||
tts_options = {}
|
||||
tts_options: dict[str, Any] = {}
|
||||
if self.pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
|
||||
|
||||
if self.tts_audio_output is not None:
|
||||
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
||||
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
|
||||
if self.tts_audio_output == "wav":
|
||||
# 16 Khz, 16-bit mono
|
||||
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
|
||||
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
|
||||
|
||||
try:
|
||||
options_supported = await tts.async_support_options(
|
||||
|
|
|
@ -150,4 +150,4 @@ class CloudProvider(Provider):
|
|||
_LOGGER.error("Voice error: %s", err)
|
||||
return (None, None)
|
||||
|
||||
return (str(options[ATTR_AUDIO_OUTPUT]), data)
|
||||
return (str(options[ATTR_AUDIO_OUTPUT].value), data)
|
||||
|
|
|
@ -3,9 +3,11 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
from typing import cast
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
VoiceAssistantAudioSettings,
|
||||
|
@ -88,6 +90,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
self.handle_event = handle_event
|
||||
self.handle_finished = handle_finished
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_task: asyncio.Task | None = None
|
||||
|
||||
async def start_server(self) -> int:
|
||||
"""Start accepting connections."""
|
||||
|
@ -189,7 +192,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
|
||||
if self.device_info.voice_assistant_version >= 2:
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._tts_task = self.hass.async_create_background_task(
|
||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||
)
|
||||
else:
|
||||
|
@ -228,7 +231,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
audio_settings = VoiceAssistantAudioSettings()
|
||||
|
||||
tts_audio_output = (
|
||||
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
||||
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
||||
)
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
|
@ -302,11 +305,32 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||
)
|
||||
|
||||
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != 16000)
|
||||
or (sample_width != 2)
|
||||
or (sample_channels != 1)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected rate/width/channels as 16000/2/1,"
|
||||
" got {sample_rate}/{sample_width}/{sample_channels}}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
|
||||
|
||||
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
|
||||
|
@ -330,4 +354,5 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||
)
|
||||
self._tts_task = None
|
||||
self._tts_done.set()
|
||||
|
|
|
@ -13,6 +13,8 @@ import logging
|
|||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, TypedDict, final
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -20,7 +22,7 @@ import mutagen
|
|||
from mutagen.id3 import ID3, TextFrame as ID3Text
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components import ffmpeg, websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
|
@ -72,11 +74,15 @@ __all__ = [
|
|||
"async_get_media_source_audio",
|
||||
"async_support_options",
|
||||
"ATTR_AUDIO_OUTPUT",
|
||||
"ATTR_PREFERRED_FORMAT",
|
||||
"ATTR_PREFERRED_SAMPLE_RATE",
|
||||
"ATTR_PREFERRED_SAMPLE_CHANNELS",
|
||||
"CONF_LANG",
|
||||
"DEFAULT_CACHE_DIR",
|
||||
"generate_media_source_id",
|
||||
"PLATFORM_SCHEMA_BASE",
|
||||
"PLATFORM_SCHEMA",
|
||||
"SampleFormat",
|
||||
"Provider",
|
||||
"TtsAudioType",
|
||||
"Voice",
|
||||
|
@ -86,6 +92,9 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
ATTR_PLATFORM = "platform"
|
||||
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||
ATTR_PREFERRED_FORMAT = "preferred_format"
|
||||
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
|
||||
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
|
||||
ATTR_VOICE = "voice"
|
||||
|
||||
|
@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
|||
return languages
|
||||
|
||||
|
||||
async def async_convert_audio(
|
||||
hass: HomeAssistant,
|
||||
from_extension: str,
|
||||
audio_bytes: bytes,
|
||||
to_extension: str,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
) -> bytes:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||
return await hass.async_add_executor_job(
|
||||
lambda: _convert_audio(
|
||||
ffmpeg_manager.binary,
|
||||
from_extension,
|
||||
audio_bytes,
|
||||
to_extension,
|
||||
to_sample_rate=to_sample_rate,
|
||||
to_sample_channels=to_sample_channels,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_audio(
|
||||
ffmpeg_binary: str,
|
||||
from_extension: str,
|
||||
audio_bytes: bytes,
|
||||
to_extension: str,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
) -> bytes:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
|
||||
# We have to use a temporary file here because some formats like WAV store
|
||||
# the length of the file in the header, and therefore cannot be written in a
|
||||
# streaming fashion.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="wb+", suffix=f".{to_extension}"
|
||||
) as output_file:
|
||||
# input
|
||||
command = [
|
||||
ffmpeg_binary,
|
||||
"-y", # overwrite temp file
|
||||
"-f",
|
||||
from_extension,
|
||||
"-i",
|
||||
"pipe:", # input from stdin
|
||||
]
|
||||
|
||||
# output
|
||||
command.extend(["-f", to_extension])
|
||||
|
||||
if to_sample_rate is not None:
|
||||
command.extend(["-ar", str(to_sample_rate)])
|
||||
|
||||
if to_sample_channels is not None:
|
||||
command.extend(["-ac", str(to_sample_channels)])
|
||||
|
||||
if to_extension == "mp3":
|
||||
# Max quality for MP3
|
||||
command.extend(["-q:a", "0"])
|
||||
|
||||
command.append(output_file.name)
|
||||
|
||||
with subprocess.Popen(
|
||||
command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
) as proc:
|
||||
_stdout, stderr = proc.communicate(input=audio_bytes)
|
||||
if proc.returncode != 0:
|
||||
_LOGGER.error(stderr.decode())
|
||||
raise RuntimeError(
|
||||
f"Unexpected error while running ffmpeg with arguments: {command}. See log for details."
|
||||
)
|
||||
|
||||
output_file.seek(0)
|
||||
return output_file.read()
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up TTS."""
|
||||
websocket_api.async_register_command(hass, websocket_list_engines)
|
||||
|
@ -482,7 +568,18 @@ class SpeechManager:
|
|||
merged_options = dict(engine_instance.default_options or {})
|
||||
merged_options.update(options or {})
|
||||
|
||||
supported_options = engine_instance.supported_options or []
|
||||
supported_options = list(engine_instance.supported_options or [])
|
||||
|
||||
# ATTR_PREFERRED_* options are always "supported" since they're used to
|
||||
# convert audio after the TTS has run (if necessary).
|
||||
supported_options.extend(
|
||||
(
|
||||
ATTR_PREFERRED_FORMAT,
|
||||
ATTR_PREFERRED_SAMPLE_RATE,
|
||||
ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
)
|
||||
)
|
||||
|
||||
invalid_opts = [
|
||||
opt_name for opt_name in merged_options if opt_name not in supported_options
|
||||
]
|
||||
|
@ -520,12 +617,7 @@ class SpeechManager:
|
|||
# Load speech from engine into memory
|
||||
else:
|
||||
filename = await self._async_get_tts_audio(
|
||||
engine_instance,
|
||||
cache_key,
|
||||
message,
|
||||
use_cache,
|
||||
language,
|
||||
options,
|
||||
engine_instance, cache_key, message, use_cache, language, options
|
||||
)
|
||||
|
||||
return f"/api/tts_proxy/{filename}"
|
||||
|
@ -590,10 +682,10 @@ class SpeechManager:
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if options is not None and ATTR_AUDIO_OUTPUT in options:
|
||||
expected_extension = options[ATTR_AUDIO_OUTPUT]
|
||||
else:
|
||||
expected_extension = None
|
||||
options = options or {}
|
||||
|
||||
# Default to MP3 unless a different format is preferred
|
||||
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
|
||||
|
||||
async def get_tts_data() -> str:
|
||||
"""Handle data available."""
|
||||
|
@ -614,8 +706,27 @@ class SpeechManager:
|
|||
f"No TTS from {engine_instance.name} for '{message}'"
|
||||
)
|
||||
|
||||
# Only convert if we have a preferred format different than the
|
||||
# expected format from the TTS system, or if a specific sample
|
||||
# rate/format/channel count is requested.
|
||||
needs_conversion = (
|
||||
(final_extension != extension)
|
||||
or (ATTR_PREFERRED_SAMPLE_RATE in options)
|
||||
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
|
||||
)
|
||||
|
||||
if needs_conversion:
|
||||
data = await async_convert_audio(
|
||||
self.hass,
|
||||
extension,
|
||||
data,
|
||||
to_extension=final_extension,
|
||||
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
|
||||
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
|
||||
)
|
||||
|
||||
# Create file infos
|
||||
filename = f"{cache_key}.{extension}".lower()
|
||||
filename = f"{cache_key}.{final_extension}".lower()
|
||||
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||
|
@ -626,10 +737,11 @@ class SpeechManager:
|
|||
)
|
||||
|
||||
# Save to memory
|
||||
if extension == "mp3":
|
||||
if final_extension == "mp3":
|
||||
data = self.write_tags(
|
||||
filename, data, engine_instance.name, message, language, options
|
||||
)
|
||||
|
||||
self._async_store_to_memcache(cache_key, filename, data)
|
||||
|
||||
if cache:
|
||||
|
@ -641,9 +753,6 @@ class SpeechManager:
|
|||
|
||||
audio_task = self.hass.async_create_task(get_tts_data())
|
||||
|
||||
if expected_extension is None:
|
||||
return await audio_task
|
||||
|
||||
def handle_error(_future: asyncio.Future) -> None:
|
||||
"""Handle error."""
|
||||
if audio_task.exception():
|
||||
|
@ -651,7 +760,7 @@ class SpeechManager:
|
|||
|
||||
audio_task.add_done_callback(handle_error)
|
||||
|
||||
filename = f"{cache_key}.{expected_extension}".lower()
|
||||
filename = f"{cache_key}.{final_extension}".lower()
|
||||
self.mem_cache[cache_key] = {
|
||||
"filename": filename,
|
||||
"voice": b"",
|
||||
|
@ -747,11 +856,12 @@ class SpeechManager:
|
|||
raise HomeAssistantError(f"{cache_key} not in cache!")
|
||||
await self._async_file_to_mem(cache_key)
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
cached = self.mem_cache[cache_key]
|
||||
if pending := cached.get("pending"):
|
||||
await pending
|
||||
cached = self.mem_cache[cache_key]
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
return content, cached["voice"]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Text-to-speech (TTS)",
|
||||
"after_dependencies": ["media_player"],
|
||||
"codeowners": ["@home-assistant/core", "@pvizeli"],
|
||||
"dependencies": ["http"],
|
||||
"dependencies": ["http", "ffmpeg"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/tts",
|
||||
"integration_type": "entity",
|
||||
"loggers": ["mutagen"],
|
||||
|
|
|
@ -4,7 +4,7 @@ import io
|
|||
import logging
|
||||
import wave
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
|
||||
|
@ -88,12 +88,16 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
|||
@property
|
||||
def supported_options(self):
|
||||
"""Return list of supported options like voice, emotion."""
|
||||
return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE, ATTR_SPEAKER]
|
||||
return [
|
||||
tts.ATTR_AUDIO_OUTPUT,
|
||||
tts.ATTR_VOICE,
|
||||
ATTR_SPEAKER,
|
||||
]
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
"""Return a dict include default options."""
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
|
||||
return {}
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||
|
@ -143,27 +147,4 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
|||
except (OSError, WyomingError):
|
||||
return (None, None)
|
||||
|
||||
if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
|
||||
return ("wav", data)
|
||||
|
||||
# Raw output (convert to 16Khz, 16-bit mono)
|
||||
with io.BytesIO(data) as wav_io:
|
||||
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
|
||||
raw_data = (
|
||||
AudioChunkConverter(
|
||||
rate=16000,
|
||||
width=2,
|
||||
channels=1,
|
||||
)
|
||||
.convert(
|
||||
AudioChunk(
|
||||
audio=wav_reader.readframes(wav_reader.getnframes()),
|
||||
rate=wav_reader.getframerate(),
|
||||
width=wav_reader.getsampwidth(),
|
||||
channels=wav_reader.getnchannels(),
|
||||
)
|
||||
)
|
||||
.audio
|
||||
)
|
||||
|
||||
return ("raw", raw_data)
|
||||
return ("wav", data)
|
||||
|
|
|
@ -20,6 +20,7 @@ cryptography==41.0.4
|
|||
dbus-fast==2.12.0
|
||||
fnv-hash-fast==0.5.0
|
||||
ha-av==10.1.1
|
||||
ha-ffmpeg==3.1.0
|
||||
hass-nabucasa==0.74.0
|
||||
hassil==1.2.5
|
||||
home-assistant-bluetooth==1.10.4
|
||||
|
|
|
@ -9,7 +9,7 @@ import wave
|
|||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DOMAIN,
|
||||
|
@ -660,3 +660,42 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
|||
assert run_1 == run_1
|
||||
assert run_1 != run_2
|
||||
assert run_1 != 1234
|
||||
|
||||
|
||||
async def test_tts_audio_output(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test using tts_audio_output with wav sets options correctly."""
|
||||
|
||||
def event_callback(event):
|
||||
pass
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id=None,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Verify TTS audio settings
|
||||
assert pipeline_input.run.tts_options is not None
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""Test ESPHome voice assistant server."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import socket
|
||||
from unittest.mock import Mock, patch
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import VoiceAssistantEventType
|
||||
import pytest
|
||||
|
@ -340,9 +342,18 @@ async def test_send_tts(
|
|||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
wav_bytes = wav_io.getvalue()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("raw", bytes(1024)),
|
||||
return_value=("wav", wav_bytes),
|
||||
):
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
|
@ -360,6 +371,63 @@ async def test_send_tts(
|
|||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
||||
|
||||
|
||||
async def test_send_tts_wrong_sample_rate(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(22050) # should be 16000
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
wav_bytes = wav_io.getvalue()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", wav_bytes),
|
||||
), pytest.raises(ValueError):
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert voice_assistant_udp_server_v2._tts_task is not None
|
||||
await voice_assistant_udp_server_v2._tts_task # raises ValueError
|
||||
|
||||
|
||||
async def test_send_tts_wrong_format(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test that only WAV audio will be streamed."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("raw", bytes(1024)),
|
||||
), pytest.raises(ValueError):
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert voice_assistant_udp_server_v2._tts_task is not None
|
||||
await voice_assistant_udp_server_v2._tts_task # raises ValueError
|
||||
|
||||
|
||||
async def test_wake_word(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
|
|
|
@ -2,13 +2,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from gtts import gTTSError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source, tts
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.google_translate.const import CONF_TLD, DOMAIN
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
|
@ -18,10 +19,11 @@ from homeassistant.components.media_player import (
|
|||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -35,15 +37,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
|
|||
return mock_tts_cache_dir
|
||||
|
||||
|
||||
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
|
||||
"""Mock media player calls."""
|
||||
|
@ -128,6 +121,7 @@ async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -
|
|||
async def test_tts_service(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -142,9 +136,11 @@ async def test_tts_service(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert url.endswith(".mp3")
|
||||
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -180,6 +176,7 @@ async def test_tts_service(
|
|||
async def test_service_say_german_config(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -194,7 +191,10 @@ async def test_service_say_german_config(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -231,6 +231,7 @@ async def test_service_say_german_config(
|
|||
async def test_service_say_german_service(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -245,7 +246,10 @@ async def test_service_say_german_service(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -281,6 +285,7 @@ async def test_service_say_german_service(
|
|||
async def test_service_say_en_uk_config(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -295,7 +300,10 @@ async def test_service_say_en_uk_config(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -332,6 +340,7 @@ async def test_service_say_en_uk_config(
|
|||
async def test_service_say_en_uk_service(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -346,7 +355,10 @@ async def test_service_say_en_uk_service(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -383,6 +395,7 @@ async def test_service_say_en_uk_service(
|
|||
async def test_service_say_en_couk(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -397,9 +410,11 @@ async def test_service_say_en_couk(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
assert url.endswith(".mp3")
|
||||
|
||||
assert mock_gtts.mock_calls[0][2] == {
|
||||
"text": "There is a person at the front door.",
|
||||
|
@ -434,6 +449,7 @@ async def test_service_say_en_couk(
|
|||
async def test_service_say_error(
|
||||
hass: HomeAssistant,
|
||||
mock_gtts: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
|
@ -450,6 +466,8 @@ async def test_service_say_error(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
assert len(mock_gtts.mock_calls) == 2
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
"""The tests for the MaryTTS speech platform."""
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
from unittest.mock import patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source, tts
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
|
@ -13,15 +16,19 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
async def get_media_source_url(hass, media_content_id):
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
def get_empty_wav() -> bytes:
|
||||
"""Get bytes for empty WAV file."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -39,7 +46,9 @@ async def test_setup_component(hass: HomeAssistant) -> None:
|
|||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_service_say(hass: HomeAssistant) -> None:
|
||||
async def test_service_say(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
|
@ -51,7 +60,7 @@ async def test_service_say(hass: HomeAssistant) -> None:
|
|||
|
||||
with patch(
|
||||
"homeassistant.components.marytts.tts.MaryTTS.speak",
|
||||
return_value=b"audio",
|
||||
return_value=get_empty_wav(),
|
||||
) as mock_speak:
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
|
@ -63,16 +72,22 @@ async def test_service_say(hass: HomeAssistant) -> None:
|
|||
blocking=True,
|
||||
)
|
||||
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(
|
||||
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
)
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
mock_speak.assert_called_once()
|
||||
mock_speak.assert_called_with("HomeAssistant", {})
|
||||
|
||||
assert len(calls) == 1
|
||||
assert url.endswith(".wav")
|
||||
|
||||
|
||||
async def test_service_say_with_effect(hass: HomeAssistant) -> None:
|
||||
async def test_service_say_with_effect(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test service call say with effects."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
|
@ -84,7 +99,7 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
|
|||
|
||||
with patch(
|
||||
"homeassistant.components.marytts.tts.MaryTTS.speak",
|
||||
return_value=b"audio",
|
||||
return_value=get_empty_wav(),
|
||||
) as mock_speak:
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
|
@ -96,16 +111,22 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
|
|||
blocking=True,
|
||||
)
|
||||
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(
|
||||
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
)
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
mock_speak.assert_called_once()
|
||||
mock_speak.assert_called_with("HomeAssistant", {"Volume": "amount:2.0;"})
|
||||
|
||||
assert len(calls) == 1
|
||||
assert url.endswith(".wav")
|
||||
|
||||
|
||||
async def test_service_say_http_error(hass: HomeAssistant) -> None:
|
||||
async def test_service_say_http_error(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
|
@ -129,7 +150,11 @@ async def test_service_say_http_error(hass: HomeAssistant) -> None:
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(
|
||||
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
)
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
mock_speak.assert_called_once()
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""Tests for Microsoft text-to-speech."""
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import patch
|
||||
|
||||
from pycsspeechtts import pycsspeechtts
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source, tts
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
|
@ -13,19 +14,12 @@ from homeassistant.components.media_player import (
|
|||
from homeassistant.components.microsoft.tts import SUPPORTED_LANGUAGES
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
|
||||
async def get_media_source_url(hass: HomeAssistant, media_content_id):
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -58,7 +52,9 @@ def mock_tts():
|
|||
yield mock_tts
|
||||
|
||||
|
||||
async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
|
||||
await async_setup_component(
|
||||
|
@ -76,9 +72,12 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
assert url.endswith(".mp3")
|
||||
|
||||
assert mock_tts.mock_calls[1][2] == {
|
||||
"language": "en-us",
|
||||
|
@ -93,7 +92,9 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
|
|||
}
|
||||
|
||||
|
||||
async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say_en_gb_config(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say with en-gb code in the config."""
|
||||
|
||||
await async_setup_component(
|
||||
|
@ -120,7 +121,11 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
assert mock_tts.mock_calls[1][2] == {
|
||||
"language": "en-gb",
|
||||
|
@ -135,7 +140,9 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
|
|||
}
|
||||
|
||||
|
||||
async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say_en_gb_service(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say with en-gb code in the service."""
|
||||
|
||||
await async_setup_component(
|
||||
|
@ -157,7 +164,11 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
assert mock_tts.mock_calls[1][2] == {
|
||||
"language": "en-gb",
|
||||
|
@ -172,7 +183,9 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
|
|||
}
|
||||
|
||||
|
||||
async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say_fa_ir_config(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say with fa-ir code in the config."""
|
||||
|
||||
await async_setup_component(
|
||||
|
@ -199,7 +212,11 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
assert mock_tts.mock_calls[1][2] == {
|
||||
"language": "fa-ir",
|
||||
|
@ -214,7 +231,9 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
|
|||
}
|
||||
|
||||
|
||||
async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say_fa_ir_service(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say with fa-ir code in the service."""
|
||||
|
||||
config = {
|
||||
|
@ -240,7 +259,11 @@ async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
assert mock_tts.mock_calls[1][2] == {
|
||||
"language": "fa-ir",
|
||||
|
@ -295,7 +318,9 @@ async def test_invalid_language(hass: HomeAssistant, mock_tts, calls) -> None:
|
|||
assert len(mock_tts.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None:
|
||||
async def test_service_say_error(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
|
||||
) -> None:
|
||||
"""Test service call say with http error."""
|
||||
mock_tts.return_value.speak.side_effect = pycsspeechtts.requests.HTTPError
|
||||
await async_setup_component(
|
||||
|
@ -313,6 +338,9 @@ async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None:
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
assert len(mock_tts.mock_calls) == 2
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -32,6 +33,7 @@ from tests.common import (
|
|||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
DEFAULT_LANG = "en_US"
|
||||
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
||||
|
@ -103,6 +105,20 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st
|
|||
return resolved.url
|
||||
|
||||
|
||||
async def retrieve_media(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
|
||||
) -> HTTPStatus:
|
||||
"""Get the media source url."""
|
||||
url = await get_media_source_url(hass, media_content_id)
|
||||
|
||||
# Ensure media has been generated by requesting it
|
||||
await hass.async_block_till_done()
|
||||
client = await hass_client()
|
||||
req = await client.get(url)
|
||||
|
||||
return req.status
|
||||
|
||||
|
||||
class BaseProvider:
|
||||
"""Test speech API provider."""
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components import ffmpeg, tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
|
@ -15,7 +15,6 @@ from homeassistant.components.media_player import (
|
|||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.components.media_source import Unresolvable
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
|
@ -33,6 +32,7 @@ from .common import (
|
|||
get_media_source_url,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
retrieve_media,
|
||||
)
|
||||
|
||||
from tests.common import async_mock_service, mock_restore_cache
|
||||
|
@ -75,7 +75,9 @@ async def test_default_entity_attributes() -> None:
|
|||
|
||||
|
||||
async def test_config_entry_unload(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
) -> None:
|
||||
"""Test we can unload config entry."""
|
||||
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
|
||||
|
@ -104,7 +106,12 @@ async def test_config_entry_unload(
|
|||
)
|
||||
assert len(calls) == 1
|
||||
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(
|
||||
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
)
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
|
@ -1159,6 +1166,7 @@ class MockEntityEmpty(MockTTSEntity):
|
|||
)
|
||||
async def test_service_get_tts_error(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
|
@ -1173,8 +1181,10 @@ async def test_service_get_tts_error(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(Unresolvable):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
async def test_load_cache_legacy_retrieve_without_mem_cache(
|
||||
|
@ -1454,7 +1464,11 @@ async def test_legacy_fetching_in_async(
|
|||
|
||||
# Test async_get_media_source_audio
|
||||
media_source_id = tts.generate_media_source_id(
|
||||
hass, "test message", "test", "en_US", None, None
|
||||
hass,
|
||||
"test message",
|
||||
"test",
|
||||
"en_US",
|
||||
cache=None,
|
||||
)
|
||||
|
||||
task = hass.async_create_task(
|
||||
|
@ -1508,16 +1522,6 @@ async def test_fetching_in_async(
|
|||
class EntityWithAsyncFetching(MockTTSEntity):
|
||||
"""Entity that supports audio output option."""
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return [tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
@property
|
||||
def default_options(self) -> dict[str, str]:
|
||||
"""Return a dict including the default options."""
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
|
@ -1527,7 +1531,11 @@ async def test_fetching_in_async(
|
|||
|
||||
# Test async_get_media_source_audio
|
||||
media_source_id = tts.generate_media_source_id(
|
||||
hass, "test message", "tts.test", "en_US", None, None
|
||||
hass,
|
||||
"test message",
|
||||
"tts.test",
|
||||
"en_US",
|
||||
cache=None,
|
||||
)
|
||||
|
||||
task = hass.async_create_task(
|
||||
|
@ -1751,3 +1759,12 @@ async def test_ws_list_voices(
|
|||
{"voice_id": "fran_drescher", "name": "Fran Drescher"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
|
||||
"""Test that ffmpeg failing during audio conversion will raise an error."""
|
||||
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Simulate a bad WAV file
|
||||
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Tests for TTS media source."""
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -14,8 +15,11 @@ from .common import (
|
|||
MockTTSEntity,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
retrieve_media,
|
||||
)
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
class MSEntity(MockTTSEntity):
|
||||
"""Test speech API entity."""
|
||||
|
@ -88,16 +92,18 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None:
|
|||
|
||||
|
||||
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
|
||||
async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None:
|
||||
async def test_legacy_resolving(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider
|
||||
) -> None:
|
||||
"""Test resolving legacy provider."""
|
||||
await mock_setup(hass, mock_provider)
|
||||
mock_get_tts_audio = mock_provider.get_tts_audio
|
||||
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/test?message=Hello%20World", None
|
||||
)
|
||||
media_id = "media-source://tts/test?message=Hello%20World"
|
||||
media = await media_source.async_resolve_media(hass, media_id, None)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
|
@ -107,13 +113,11 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
|
|||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
media = await media_source.async_resolve_media(
|
||||
hass,
|
||||
"media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus",
|
||||
None,
|
||||
)
|
||||
media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus"
|
||||
media = await media_source.async_resolve_media(hass, media_id, None)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
|
@ -123,16 +127,18 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
|
|||
|
||||
|
||||
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
|
||||
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None:
|
||||
async def test_resolving(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity
|
||||
) -> None:
|
||||
"""Test resolving entity."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
mock_get_tts_audio = mock_tts_entity.get_tts_audio
|
||||
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/tts.test?message=Hello%20World", None
|
||||
)
|
||||
media_id = "media-source://tts/tts.test?message=Hello%20World"
|
||||
media = await media_source.async_resolve_media(hass, media_id, None)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
|
@ -142,13 +148,13 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
|
|||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
media = await media_source.async_resolve_media(
|
||||
hass,
|
||||
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus",
|
||||
None,
|
||||
media_id = (
|
||||
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus"
|
||||
)
|
||||
media = await media_source.async_resolve_media(hass, media_id, None)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
|
|
|
@ -4,18 +4,19 @@ from http import HTTPStatus
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source, tts
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
URL = "https://api.voicerss.org/"
|
||||
FORM_DATA = {
|
||||
|
@ -38,15 +39,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
|
|||
return mock_tts_cache_dir
|
||||
|
||||
|
||||
async def get_media_source_url(hass, media_content_id):
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
|
||||
|
||||
async def test_setup_component(hass: HomeAssistant) -> None:
|
||||
"""Test setup component."""
|
||||
config = {tts.DOMAIN: {"platform": "voicerss", "api_key": "1234567xx"}}
|
||||
|
@ -66,7 +58,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
|
|||
|
||||
|
||||
async def test_service_say(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -90,14 +84,18 @@ async def test_service_say(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert url.endswith(".mp3")
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||
|
||||
|
||||
async def test_service_say_german_config(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say with german code in the config."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -128,13 +126,18 @@ async def test_service_say_german_config(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == form_data
|
||||
|
||||
|
||||
async def test_service_say_german_service(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say with german code in the service."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -160,13 +163,18 @@ async def test_service_say_german_service(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == form_data
|
||||
|
||||
|
||||
async def test_service_say_error(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say with http response 400."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -189,14 +197,18 @@ async def test_service_say_error(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||
|
||||
|
||||
async def test_service_say_timeout(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say with http timeout."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -219,14 +231,18 @@ async def test_service_say_timeout(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||
|
||||
|
||||
async def test_service_say_error_msg(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test service call say with http error api message."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -254,7 +270,9 @@ async def test_service_say_error_msg(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||
|
|
|
@ -10,6 +10,39 @@
|
|||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_different_formats
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_different_formats.1
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_mp3
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_raw
|
||||
list([
|
||||
dict({
|
||||
|
|
|
@ -51,31 +51,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
|
|||
AudioStop().event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
||||
)
|
||||
|
||||
assert extension == "wav"
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 16000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.readframes(wav_file.getnframes()) == audio
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_raw(
|
||||
hass: HomeAssistant, init_wyoming_tts, snapshot
|
||||
) -> None:
|
||||
"""Test get raw audio."""
|
||||
audio = bytes(100)
|
||||
# Verify audio
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
|
@ -92,12 +68,83 @@ async def test_get_tts_audio_raw(
|
|||
"Hello world",
|
||||
"tts.test_tts",
|
||||
"en-US",
|
||||
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
||||
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
|
||||
),
|
||||
)
|
||||
|
||||
assert extension == "raw"
|
||||
assert data == audio
|
||||
assert extension == "wav"
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 16000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.readframes(wav_file.getnframes()) == audio
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_different_formats(
|
||||
hass: HomeAssistant, init_wyoming_tts, snapshot
|
||||
) -> None:
|
||||
"""Test changing preferred audio format."""
|
||||
audio = bytes(16000 * 2 * 1) # one second
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
# Request a different sample rate, etc.
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass,
|
||||
"Hello world",
|
||||
"tts.test_tts",
|
||||
"en-US",
|
||||
options={
|
||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
assert extension == "wav"
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 48000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 2
|
||||
assert wav_file.getnframes() == wav_file.getframerate() # one second
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
# MP3 is the default
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass,
|
||||
"Hello world",
|
||||
"tts.test_tts",
|
||||
"en-US",
|
||||
),
|
||||
)
|
||||
|
||||
assert extension == "mp3"
|
||||
assert b"ID3" in data
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from http import HTTPStatus
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source, tts
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
|
@ -14,7 +14,9 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
URL = "https://tts.voicetech.yandex.net/generate?"
|
||||
|
||||
|
@ -30,15 +32,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
|
|||
return mock_tts_cache_dir
|
||||
|
||||
|
||||
async def get_media_source_url(hass, media_content_id):
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
|
||||
|
||||
async def test_setup_component(hass: HomeAssistant) -> None:
|
||||
"""Test setup component."""
|
||||
config = {tts.DOMAIN: {"platform": "yandextts", "api_key": "1234567xx"}}
|
||||
|
@ -58,7 +51,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
|
|||
|
||||
|
||||
async def test_service_say(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -87,12 +82,18 @@ async def test_service_say(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_russian_config(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -128,12 +129,18 @@ async def test_service_say_russian_config(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_russian_service(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -166,12 +173,18 @@ async def test_service_say_russian_service(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_timeout(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -207,13 +220,18 @@ async def test_service_say_timeout(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_http_error(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -248,12 +266,16 @@ async def test_service_say_http_error(
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
async def test_service_say_specified_speaker(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -288,12 +310,18 @@ async def test_service_say_specified_speaker(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_specified_emotion(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -328,13 +356,18 @@ async def test_service_say_specified_emotion(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_specified_low_speed(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -365,13 +398,18 @@ async def test_service_say_specified_low_speed(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_specified_speed(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -400,13 +438,18 @@ async def test_service_say_specified_speed(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_service_say_specified_options(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test service call say with options."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -438,6 +481,9 @@ async def test_service_say_specified_options(
|
|||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
|
Loading…
Reference in New Issue