# core/tests/components/wyoming/test_tts.py
#
# 169 lines
# 5.1 KiB
# Python
# Raw Normal View History

"""Test tts."""
from __future__ import annotations
import io
from unittest.mock import patch
import wave
import pytest
from wyoming.audio import AudioChunk, AudioStop
from homeassistant.components import tts, wyoming
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_component import DATA_INSTANCES
from . import MockAsyncTcpClient
# 2023-05-24 19:00:35 +00:00
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
    """Apply the mocked (empty) TTS cache dir to every test in this module."""
    # autouse=True makes pytest request ``mock_tts_cache_dir`` for each test
    # without the test having to list it explicitly.
    cache_dir = mock_tts_cache_dir
    return cache_dir
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
    """Test the languages, options and voices reported by the TTS entity."""
    entity_id = "tts.test_tts"

    # The entity must be present in the state machine.
    assert hass.states.get(entity_id) is not None

    # Look the entity object up through the "tts" entity component.
    tts_component = hass.data[DATA_INSTANCES]["tts"]
    tts_entity = tts_component.get_entity(entity_id)
    assert tts_entity is not None

    assert tts_entity.supported_languages == ["en-US"]

    expected_options = [
        tts.ATTR_AUDIO_OUTPUT,
        tts.ATTR_VOICE,
        wyoming.ATTR_SPEAKER,
    ]
    assert tts_entity.supported_options == expected_options

    # Exactly one voice is expected for en-US; its name doubles as its id.
    english_voices = tts_entity.async_get_supported_voices("en-US")
    assert len(english_voices) == 1
    only_voice = english_voices[0]
    assert only_voice.name == "Test Voice"
    assert only_voice.voice_id == "Test Voice"

    # A language the entity does not support yields no voices.
    assert not tts_entity.async_get_supported_voices("de-DE")
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
    """Test that generated audio is returned as WAV wrapping the mocked PCM."""
    pcm = bytes(100)
    fake_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=pcm, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )

    # Swap the TCP client used by the integration for the scripted mock.
    with patch("homeassistant.components.wyoming.tts.AsyncTcpClient", fake_client):
        media_source_id = tts.generate_media_source_id(
            hass, "Hello world", "tts.test_tts", "en-US"
        )
        extension, data = await tts.async_get_media_source_audio(
            hass, media_source_id
        )

    assert extension == "wav"
    assert data is not None

    # Parse the returned bytes as a WAV file and compare format and payload
    # against what the mock client sent.
    with io.BytesIO(data) as wav_buffer:
        with wave.open(wav_buffer, "rb") as wav_reader:
            assert wav_reader.getframerate() == 16000
            assert wav_reader.getsampwidth() == 2
            assert wav_reader.getnchannels() == 1
            frame_count = wav_reader.getnframes()
            assert wav_reader.readframes(frame_count) == pcm

    # What the integration wrote to the client must match the stored snapshot.
    assert fake_client.written == snapshot
async def test_get_tts_audio_raw(
    hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
    """Test that the "raw" audio output option returns the PCM bytes as-is."""
    pcm = bytes(100)
    fake_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=pcm, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )

    # Swap the TCP client used by the integration for the scripted mock.
    with patch("homeassistant.components.wyoming.tts.AsyncTcpClient", fake_client):
        media_source_id = tts.generate_media_source_id(
            hass,
            "Hello world",
            "tts.test_tts",
            "en-US",
            options={tts.ATTR_AUDIO_OUTPUT: "raw"},
        )
        extension, data = await tts.async_get_media_source_audio(
            hass, media_source_id
        )

    # No container: the extension is "raw" and the payload equals the input.
    assert extension == "raw"
    assert data == pcm

    # What the integration wrote to the client must match the stored snapshot.
    assert fake_client.written == snapshot
async def test_get_tts_audio_connection_lost(
    hass: HomeAssistant, init_wyoming_tts
) -> None:
    """Test that a lost connection while streaming raises HomeAssistantError."""
    # The mock is scripted with a single ``None`` event -- presumably how
    # MockAsyncTcpClient models a dropped connection; confirm against the mock.
    dropped_client = MockAsyncTcpClient([None])

    with patch("homeassistant.components.wyoming.tts.AsyncTcpClient", dropped_client):
        with pytest.raises(HomeAssistantError):
            await tts.async_get_media_source_audio(
                hass,
                tts.generate_media_source_id(
                    hass, "Hello world", "tts.test_tts", "en-US"
                ),
            )
async def test_get_tts_audio_audio_oserror(
    hass: HomeAssistant, init_wyoming_tts
) -> None:
    """Test that an OSError while reading events raises HomeAssistantError."""
    pcm = bytes(100)
    chunk_event = AudioChunk(audio=pcm, rate=16000, width=2, channels=1).event()
    stop_event = AudioStop().event()
    mock_client = MockAsyncTcpClient([chunk_event, stop_event])

    # Context managers are entered in the same order as the original combined
    # ``with``: client patch, then the failing ``read_event``, then ``raises``.
    with patch("homeassistant.components.wyoming.tts.AsyncTcpClient", mock_client):
        with patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
            with pytest.raises(HomeAssistantError):
                await tts.async_get_media_source_audio(
                    hass,
                    tts.generate_media_source_id(
                        hass, "Hello world", "tts.test_tts", hass.config.language
                    ),
                )
async def test_voice_speaker(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
    """Test requesting audio with an explicit voice and speaker option."""
    pcm = bytes(100)
    fake_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=pcm, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )
    voice_options = {tts.ATTR_VOICE: "voice1", wyoming.ATTR_SPEAKER: "speaker1"}

    # Swap the TCP client used by the integration for the scripted mock.
    with patch("homeassistant.components.wyoming.tts.AsyncTcpClient", fake_client):
        media_source_id = tts.generate_media_source_id(
            hass,
            "Hello world",
            "tts.test_tts",
            "en-US",
            options=voice_options,
        )
        await tts.async_get_media_source_audio(hass, media_source_id)

    # Only the data written to the client is checked, via the stored snapshot.
    assert fake_client.written == snapshot