Automatically convert TTS audio to MP3 on demand (#102814)

* Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio

* Move conversion into SpeechManager

* Handle None case for expected_extension

* Only use ATTR_AUDIO_OUTPUT

* Prefer MP3 in pipelines

* Automatically convert to mp3 on demand

* Add preferred audio format

* Break out preferred format

* Add ATTR_BLOCKING to allow async fetching

* Make a copy of supported options

* Fix MaryTTS tests

* Update ESPHome to use "wav" instead of "raw"

* Clean up tests, remove blocking

* Clean up rest of TTS tests

* Fix ESPHome tests

* More test coverage
pull/103516/head
Michael Hansen 2023-11-06 14:26:00 -06:00 committed by GitHub
parent 054089291f
commit ae516ffbb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 723 additions and 241 deletions

View File

@ -971,12 +971,16 @@ class PipelineRun:
# pipeline.tts_engine can't be None or this function is not called
engine = cast(str, self.pipeline.tts_engine)
tts_options = {}
tts_options: dict[str, Any] = {}
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 Khz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
try:
options_supported = await tts.async_support_options(

View File

@ -150,4 +150,4 @@ class CloudProvider(Provider):
_LOGGER.error("Voice error: %s", err)
return (None, None)
return (str(options[ATTR_AUDIO_OUTPUT]), data)
return (str(options[ATTR_AUDIO_OUTPUT].value), data)

View File

@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
import io
import logging
import socket
from typing import cast
import wave
from aioesphomeapi import (
VoiceAssistantAudioSettings,
@ -88,6 +90,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event = handle_event
self.handle_finished = handle_finished
self._tts_done = asyncio.Event()
self._tts_task: asyncio.Task | None = None
async def start_server(self) -> int:
"""Start accepting connections."""
@ -189,7 +192,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
if self.device_info.voice_assistant_version >= 2:
media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task(
self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
)
else:
@ -228,7 +231,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
)
_LOGGER.debug("Starting pipeline")
@ -302,11 +305,32 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
)
_extension, audio_bytes = await tts.async_get_media_source_audio(
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != 16000)
or (sample_width != 2)
or (sample_channels != 1)
):
raise ValueError(
"Expected rate/width/channels as 16000/2/1,"
" got {sample_rate}/{sample_width}/{sample_channels}}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
@ -330,4 +354,5 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
)
self._tts_task = None
self._tts_done.set()

View File

@ -13,6 +13,8 @@ import logging
import mimetypes
import os
import re
import subprocess
import tempfile
from typing import Any, TypedDict, final
from aiohttp import web
@ -20,7 +22,7 @@ import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
@ -72,11 +74,15 @@ __all__ = [
"async_get_media_source_audio",
"async_support_options",
"ATTR_AUDIO_OUTPUT",
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
"PLATFORM_SCHEMA_BASE",
"PLATFORM_SCHEMA",
"SampleFormat",
"Provider",
"TtsAudioType",
"Voice",
@ -86,6 +92,9 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"
@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
return languages
async def async_convert_audio(
hass: HomeAssistant,
from_extension: str,
audio_bytes: bytes,
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
return await hass.async_add_executor_job(
lambda: _convert_audio(
ffmpeg_manager.binary,
from_extension,
audio_bytes,
to_extension,
to_sample_rate=to_sample_rate,
to_sample_channels=to_sample_channels,
)
)
def _convert_audio(
ffmpeg_binary: str,
from_extension: str,
audio_bytes: bytes,
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
# We have to use a temporary file here because some formats like WAV store
# the length of the file in the header, and therefore cannot be written in a
# streaming fashion.
with tempfile.NamedTemporaryFile(
mode="wb+", suffix=f".{to_extension}"
) as output_file:
# input
command = [
ffmpeg_binary,
"-y", # overwrite temp file
"-f",
from_extension,
"-i",
"pipe:", # input from stdin
]
# output
command.extend(["-f", to_extension])
if to_sample_rate is not None:
command.extend(["-ar", str(to_sample_rate)])
if to_sample_channels is not None:
command.extend(["-ac", str(to_sample_channels)])
if to_extension == "mp3":
# Max quality for MP3
command.extend(["-q:a", "0"])
command.append(output_file.name)
with subprocess.Popen(
command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
) as proc:
_stdout, stderr = proc.communicate(input=audio_bytes)
if proc.returncode != 0:
_LOGGER.error(stderr.decode())
raise RuntimeError(
f"Unexpected error while running ffmpeg with arguments: {command}. See log for details."
)
output_file.seek(0)
return output_file.read()
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines)
@ -482,7 +568,18 @@ class SpeechManager:
merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {})
supported_options = engine_instance.supported_options or []
supported_options = list(engine_instance.supported_options or [])
# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)
invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
@ -520,12 +617,7 @@ class SpeechManager:
# Load speech from engine into memory
else:
filename = await self._async_get_tts_audio(
engine_instance,
cache_key,
message,
use_cache,
language,
options,
engine_instance, cache_key, message, use_cache, language, options
)
return f"/api/tts_proxy/{filename}"
@ -590,10 +682,10 @@ class SpeechManager:
This method is a coroutine.
"""
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
expected_extension = None
options = options or {}
# Default to MP3 unless a different format is preferred
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
async def get_tts_data() -> str:
"""Handle data available."""
@ -614,8 +706,27 @@ class SpeechManager:
f"No TTS from {engine_instance.name} for '{message}'"
)
# Only convert if we have a preferred format different than the
# expected format from the TTS system, or if a specific sample
# rate/format/channel count is requested.
needs_conversion = (
(final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
)
if needs_conversion:
data = await async_convert_audio(
self.hass,
extension,
data,
to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
)
# Create file infos
filename = f"{cache_key}.{extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()
# Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
@ -626,10 +737,11 @@ class SpeechManager:
)
# Save to memory
if extension == "mp3":
if final_extension == "mp3":
data = self.write_tags(
filename, data, engine_instance.name, message, language, options
)
self._async_store_to_memcache(cache_key, filename, data)
if cache:
@ -641,9 +753,6 @@ class SpeechManager:
audio_task = self.hass.async_create_task(get_tts_data())
if expected_extension is None:
return await audio_task
def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
@ -651,7 +760,7 @@ class SpeechManager:
audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{expected_extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
@ -747,11 +856,12 @@ class SpeechManager:
raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key)
content, _ = mimetypes.guess_type(filename)
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
content, _ = mimetypes.guess_type(filename)
return content, cached["voice"]
@staticmethod

View File

@ -3,7 +3,7 @@
"name": "Text-to-speech (TTS)",
"after_dependencies": ["media_player"],
"codeowners": ["@home-assistant/core", "@pvizeli"],
"dependencies": ["http"],
"dependencies": ["http", "ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/tts",
"integration_type": "entity",
"loggers": ["mutagen"],

View File

@ -4,7 +4,7 @@ import io
import logging
import wave
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.audio import AudioChunk, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize, SynthesizeVoice
@ -88,12 +88,16 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
@property
def supported_options(self):
"""Return list of supported options like voice, emotion."""
return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE, ATTR_SPEAKER]
return [
tts.ATTR_AUDIO_OUTPUT,
tts.ATTR_VOICE,
ATTR_SPEAKER,
]
@property
def default_options(self):
"""Return a dict include default options."""
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
return {}
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
@ -143,27 +147,4 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError):
return (None, None)
if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
return ("wav", data)
# Raw output (convert to 16Khz, 16-bit mono)
with io.BytesIO(data) as wav_io:
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
raw_data = (
AudioChunkConverter(
rate=16000,
width=2,
channels=1,
)
.convert(
AudioChunk(
audio=wav_reader.readframes(wav_reader.getnframes()),
rate=wav_reader.getframerate(),
width=wav_reader.getsampwidth(),
channels=wav_reader.getnchannels(),
)
)
.audio
)
return ("raw", raw_data)

View File

@ -20,6 +20,7 @@ cryptography==41.0.4
dbus-fast==2.12.0
fnv-hash-fast==0.5.0
ha-av==10.1.1
ha-ffmpeg==3.1.0
hass-nabucasa==0.74.0
hassil==1.2.5
home-assistant-bluetooth==1.10.4

View File

@ -9,7 +9,7 @@ import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
@ -660,3 +660,42 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
assert run_1 == run_1
assert run_1 != run_2
assert run_1 != 1234
async def test_tts_audio_output(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test using tts_audio_output with wav sets options correctly."""
def event_callback(event):
pass
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
tts_audio_output="wav",
),
)
await pipeline_input.validate()
# Verify TTS audio settings
assert pipeline_input.run.tts_options is not None
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1

View File

@ -1,8 +1,10 @@
"""Test ESPHome voice assistant server."""
import asyncio
import io
import socket
from unittest.mock import Mock, patch
import wave
from aioesphomeapi import VoiceAssistantEventType
import pytest
@ -340,9 +342,18 @@ async def test_send_tts(
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("raw", bytes(1024)),
return_value=("wav", wav_bytes),
):
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
@ -360,6 +371,63 @@ async def test_send_tts(
voice_assistant_udp_server_v2.transport.sendto.assert_called()
async def test_send_tts_wrong_sample_rate(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050) # should be 16000
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes),
), pytest.raises(ValueError):
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
)
)
assert voice_assistant_udp_server_v2._tts_task is not None
await voice_assistant_udp_server_v2._tts_task # raises ValueError
async def test_send_tts_wrong_format(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test that only WAV audio will be streamed."""
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("raw", bytes(1024)),
), pytest.raises(ValueError):
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
)
)
assert voice_assistant_udp_server_v2._tts_task is not None
await voice_assistant_udp_server_v2._tts_task # raises ValueError
async def test_wake_word(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,

View File

@ -2,13 +2,14 @@
from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock, patch
from gtts import gTTSError
import pytest
from homeassistant.components import media_source, tts
from homeassistant.components import tts
from homeassistant.components.google_translate.const import CONF_TLD, DOMAIN
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
@ -18,10 +19,11 @@ from homeassistant.components.media_player import (
from homeassistant.config import async_process_ha_core_config
from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
@ -35,15 +37,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
@pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
"""Mock media player calls."""
@ -128,6 +121,7 @@ async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -
async def test_tts_service(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -142,9 +136,11 @@ async def test_tts_service(
)
assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -180,6 +176,7 @@ async def test_tts_service(
async def test_service_say_german_config(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -194,7 +191,10 @@ async def test_service_say_german_config(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -231,6 +231,7 @@ async def test_service_say_german_config(
async def test_service_say_german_service(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -245,7 +246,10 @@ async def test_service_say_german_service(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -281,6 +285,7 @@ async def test_service_say_german_service(
async def test_service_say_en_uk_config(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -295,7 +300,10 @@ async def test_service_say_en_uk_config(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -332,6 +340,7 @@ async def test_service_say_en_uk_config(
async def test_service_say_en_uk_service(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -346,7 +355,10 @@ async def test_service_say_en_uk_service(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -383,6 +395,7 @@ async def test_service_say_en_uk_service(
async def test_service_say_en_couk(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -397,9 +410,11 @@ async def test_service_say_en_couk(
)
assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.",
@ -434,6 +449,7 @@ async def test_service_say_en_couk(
async def test_service_say_error(
hass: HomeAssistant,
mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
setup: str,
tts_service: str,
@ -450,6 +466,8 @@ async def test_service_say_error(
)
assert len(calls) == 1
with pytest.raises(HomeAssistantError):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(mock_gtts.mock_calls) == 2

View File

@ -1,9 +1,12 @@
"""The tests for the MaryTTS speech platform."""
from http import HTTPStatus
import io
from unittest.mock import patch
import wave
import pytest
from homeassistant.components import media_source, tts
from homeassistant.components import tts
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
@ -13,15 +16,19 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
async def get_media_source_url(hass, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
def get_empty_wav() -> bytes:
"""Get bytes for empty WAV file."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
return wav_io.getvalue()
@pytest.fixture(autouse=True)
@ -39,7 +46,9 @@ async def test_setup_component(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
async def test_service_say(hass: HomeAssistant) -> None:
async def test_service_say(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -51,7 +60,7 @@ async def test_service_say(hass: HomeAssistant) -> None:
with patch(
"homeassistant.components.marytts.tts.MaryTTS.speak",
return_value=b"audio",
return_value=get_empty_wav(),
) as mock_speak:
await hass.services.async_call(
tts.DOMAIN,
@ -63,16 +72,22 @@ async def test_service_say(hass: HomeAssistant) -> None:
blocking=True,
)
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
mock_speak.assert_called_once()
mock_speak.assert_called_with("HomeAssistant", {})
assert len(calls) == 1
assert url.endswith(".wav")
async def test_service_say_with_effect(hass: HomeAssistant) -> None:
async def test_service_say_with_effect(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say with effects."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -84,7 +99,7 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
with patch(
"homeassistant.components.marytts.tts.MaryTTS.speak",
return_value=b"audio",
return_value=get_empty_wav(),
) as mock_speak:
await hass.services.async_call(
tts.DOMAIN,
@ -96,16 +111,22 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
blocking=True,
)
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
mock_speak.assert_called_once()
mock_speak.assert_called_with("HomeAssistant", {"Volume": "amount:2.0;"})
assert len(calls) == 1
assert url.endswith(".wav")
async def test_service_say_http_error(hass: HomeAssistant) -> None:
async def test_service_say_http_error(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -129,7 +150,11 @@ async def test_service_say_http_error(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
with pytest.raises(Exception):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.NOT_FOUND
)
mock_speak.assert_called_once()

View File

@ -1,10 +1,11 @@
"""Tests for Microsoft text-to-speech."""
from http import HTTPStatus
from unittest.mock import patch
from pycsspeechtts import pycsspeechtts
import pytest
from homeassistant.components import media_source, tts
from homeassistant.components import tts
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
@ -13,19 +14,12 @@ from homeassistant.components.media_player import (
from homeassistant.components.microsoft.tts import SUPPORTED_LANGUAGES
from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound
from homeassistant.exceptions import ServiceNotFound
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
async def get_media_source_url(hass: HomeAssistant, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
@ -58,7 +52,9 @@ def mock_tts():
yield mock_tts
async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say."""
await async_setup_component(
@ -76,9 +72,12 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
)
assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_tts.mock_calls[1][2] == {
"language": "en-us",
@ -93,7 +92,9 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
}
async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say_en_gb_config(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with en-gb code in the config."""
await async_setup_component(
@ -120,7 +121,11 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == {
"language": "en-gb",
@ -135,7 +140,9 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
}
async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say_en_gb_service(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with en-gb code in the service."""
await async_setup_component(
@ -157,7 +164,11 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == {
"language": "en-gb",
@ -172,7 +183,9 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
}
async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say_fa_ir_config(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with fa-ir code in the config."""
await async_setup_component(
@ -199,7 +212,11 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == {
"language": "fa-ir",
@ -214,7 +231,9 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
}
async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say_fa_ir_service(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with fa-ir code in the service."""
config = {
@ -240,7 +259,11 @@ async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == {
"language": "fa-ir",
@ -295,7 +318,9 @@ async def test_invalid_language(hass: HomeAssistant, mock_tts, calls) -> None:
assert len(mock_tts.mock_calls) == 0
async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None:
async def test_service_say_error(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with http error."""
mock_tts.return_value.speak.side_effect = pycsspeechtts.requests.HTTPError
await async_setup_component(
@ -313,6 +338,9 @@ async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None:
)
assert len(calls) == 1
with pytest.raises(HomeAssistantError):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(mock_tts.mock_calls) == 2

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock, patch
@ -32,6 +33,7 @@ from tests.common import (
mock_integration,
mock_platform,
)
from tests.typing import ClientSessionGenerator
DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
@ -103,6 +105,20 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st
return resolved.url
async def retrieve_media(
hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
) -> HTTPStatus:
"""Get the media source url."""
url = await get_media_source_url(hass, media_content_id)
# Ensure media has been generated by requesting it
await hass.async_block_till_done()
client = await hass_client()
req = await client.get(url)
return req.status
class BaseProvider:
"""Test speech API provider."""

View File

@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components import tts
from homeassistant.components import ffmpeg, tts
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
@ -15,7 +15,6 @@ from homeassistant.components.media_player import (
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.components.media_source import Unresolvable
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
@ -33,6 +32,7 @@ from .common import (
get_media_source_url,
mock_config_entry_setup,
mock_setup,
retrieve_media,
)
from tests.common import async_mock_service, mock_restore_cache
@ -75,7 +75,9 @@ async def test_default_entity_attributes() -> None:
async def test_config_entry_unload(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Test we can unload config entry."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
@ -104,7 +106,12 @@ async def test_config_entry_unload(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
await hass.async_block_till_done()
state = hass.states.get(entity_id)
@ -1159,6 +1166,7 @@ class MockEntityEmpty(MockTTSEntity):
)
async def test_service_get_tts_error(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
tts_service: str,
service_data: dict[str, Any],
@ -1173,8 +1181,10 @@ async def test_service_get_tts_error(
blocking=True,
)
assert len(calls) == 1
with pytest.raises(Unresolvable):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
async def test_load_cache_legacy_retrieve_without_mem_cache(
@ -1454,7 +1464,11 @@ async def test_legacy_fetching_in_async(
# Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id(
hass, "test message", "test", "en_US", None, None
hass,
"test message",
"test",
"en_US",
cache=None,
)
task = hass.async_create_task(
@ -1508,16 +1522,6 @@ async def test_fetching_in_async(
class EntityWithAsyncFetching(MockTTSEntity):
"""Entity that supports audio output option."""
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return [tts.ATTR_AUDIO_OUTPUT]
@property
def default_options(self) -> dict[str, str]:
"""Return a dict including the default options."""
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
@ -1527,7 +1531,11 @@ async def test_fetching_in_async(
# Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id(
hass, "test message", "tts.test", "en_US", None, None
hass,
"test message",
"tts.test",
"en_US",
cache=None,
)
task = hass.async_create_task(
@ -1751,3 +1759,12 @@ async def test_ws_list_voices(
{"voice_id": "fran_drescher", "name": "Fran Drescher"},
]
}
async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
"""Test that ffmpeg failing during audio conversion will raise an error."""
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
with pytest.raises(RuntimeError):
# Simulate a bad WAV file
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")

View File

@ -1,4 +1,5 @@
"""Tests for TTS media source."""
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
@ -14,8 +15,11 @@ from .common import (
MockTTSEntity,
mock_config_entry_setup,
mock_setup,
retrieve_media,
)
from tests.typing import ClientSessionGenerator
class MSEntity(MockTTSEntity):
"""Test speech API entity."""
@ -88,16 +92,18 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None:
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None:
async def test_legacy_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider
) -> None:
"""Test resolving legacy provider."""
await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio
media = await media_source.async_resolve_media(
hass, "media-source://tts/test?message=Hello%20World", None
)
media_id = "media-source://tts/test?message=Hello%20World"
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
@ -107,13 +113,11 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
# Pass language and options
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
)
media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus"
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
@ -123,16 +127,18 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None:
async def test_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity
) -> None:
"""Test resolving entity."""
await mock_config_entry_setup(hass, mock_tts_entity)
mock_get_tts_audio = mock_tts_entity.get_tts_audio
media = await media_source.async_resolve_media(
hass, "media-source://tts/tts.test?message=Hello%20World", None
)
media_id = "media-source://tts/tts.test?message=Hello%20World"
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
@ -142,13 +148,13 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
# Pass language and options
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
media_id = (
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus"
)
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]

View File

@ -4,18 +4,19 @@ from http import HTTPStatus
import pytest
from homeassistant.components import media_source, tts
from homeassistant.components import tts
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
URL = "https://api.voicerss.org/"
FORM_DATA = {
@ -38,15 +39,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir
async def get_media_source_url(hass, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
async def test_setup_component(hass: HomeAssistant) -> None:
"""Test setup component."""
config = {tts.DOMAIN: {"platform": "voicerss", "api_key": "1234567xx"}}
@ -66,7 +58,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
async def test_service_say(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -90,14 +84,18 @@ async def test_service_say(
await hass.async_block_till_done()
assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert url.endswith(".mp3")
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_german_config(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say with german code in the config."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -128,13 +126,18 @@ async def test_service_say_german_config(
await hass.async_block_till_done()
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == form_data
async def test_service_say_german_service(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say with german code in the service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -160,13 +163,18 @@ async def test_service_say_german_service(
await hass.async_block_till_done()
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == form_data
async def test_service_say_error(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say with http response 400."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -189,14 +197,18 @@ async def test_service_say_error(
)
await hass.async_block_till_done()
with pytest.raises(HomeAssistantError):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_timeout(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say with http timeout."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -219,14 +231,18 @@ async def test_service_say_timeout(
)
await hass.async_block_till_done()
with pytest.raises(HomeAssistantError):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_error_msg(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test service call say with http error api message."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -254,7 +270,9 @@ async def test_service_say_error_msg(
)
await hass.async_block_till_done()
with pytest.raises(media_source.Unresolvable):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA

View File

@ -10,6 +10,39 @@
}),
])
# ---
# name: test_get_tts_audio_different_formats
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_different_formats.1
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_mp3
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_raw
list([
dict({

View File

@ -51,31 +51,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
AudioStop().event(),
]
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
async def test_get_tts_audio_raw(
hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
"""Test get raw audio."""
audio = bytes(100)
# Verify audio
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
@ -92,12 +68,83 @@ async def test_get_tts_audio_raw(
"Hello world",
"tts.test_tts",
"en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
),
)
assert extension == "raw"
assert data == audio
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
async def test_get_tts_audio_different_formats(
hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
"""Test changing preferred audio format."""
audio = bytes(16000 * 2 * 1) # one second
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
# Request a different sample rate, etc.
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass,
"Hello world",
"tts.test_tts",
"en-US",
options={
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
},
),
)
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 48000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 2
assert wav_file.getnframes() == wav_file.getframerate() # one second
assert mock_client.written == snapshot
# MP3 is the default
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass,
"Hello world",
"tts.test_tts",
"en-US",
),
)
assert extension == "mp3"
assert b"ID3" in data
assert mock_client.written == snapshot

View File

@ -4,7 +4,7 @@ from http import HTTPStatus
import pytest
from homeassistant.components import media_source, tts
from homeassistant.components import tts
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
@ -14,7 +14,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
URL = "https://tts.voicetech.yandex.net/generate?"
@ -30,15 +32,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir
async def get_media_source_url(hass, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
async def test_setup_component(hass: HomeAssistant) -> None:
"""Test setup component."""
config = {tts.DOMAIN: {"platform": "yandextts", "api_key": "1234567xx"}}
@ -58,7 +51,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
async def test_service_say(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -87,12 +82,18 @@ async def test_service_say(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_russian_config(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -128,12 +129,18 @@ async def test_service_say_russian_config(
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_russian_service(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -166,12 +173,18 @@ async def test_service_say_russian_service(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_timeout(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -207,13 +220,18 @@ async def test_service_say_timeout(
await hass.async_block_till_done()
assert len(calls) == 1
with pytest.raises(media_source.Unresolvable):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_http_error(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -248,12 +266,16 @@ async def test_service_say_http_error(
)
assert len(calls) == 1
with pytest.raises(media_source.Unresolvable):
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
async def test_service_say_specified_speaker(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -288,12 +310,18 @@ async def test_service_say_specified_speaker(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_emotion(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -328,13 +356,18 @@ async def test_service_say_specified_emotion(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_low_speed(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -365,13 +398,18 @@ async def test_service_say_specified_low_speed(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_speed(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -400,13 +438,18 @@ async def test_service_say_specified_speed(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_options(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None:
"""Test service call say with options."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -438,6 +481,9 @@ async def test_service_say_specified_options(
blocking=True,
)
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1