156 lines
4.5 KiB
Python
156 lines
4.5 KiB
Python
"""Test STT component setup."""
|
|
from asyncio import StreamReader
|
|
from http import HTTPStatus
|
|
from unittest.mock import AsyncMock, Mock
|
|
|
|
import pytest
|
|
|
|
from homeassistant.components.stt import (
|
|
AudioBitRates,
|
|
AudioChannels,
|
|
AudioCodecs,
|
|
AudioFormats,
|
|
AudioSampleRates,
|
|
Provider,
|
|
SpeechMetadata,
|
|
SpeechResult,
|
|
SpeechResultState,
|
|
async_get_provider,
|
|
)
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import mock_platform
|
|
|
|
|
|
class TestProvider(Provider):
    """STT provider stub used by the tests in this module.

    Each call to ``async_process_audio_stream`` is recorded in ``calls`` so a
    test can inspect the metadata and stream that reached the provider.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it as a test class and warn because it defines __init__.
    # Opt out of collection explicitly.
    __test__ = False

    # When True, async_process_audio_stream returns an ERROR result instead of
    # a successful transcription.
    fail_process_audio = False

    def __init__(self) -> None:
        """Init test provider."""
        # One (metadata, stream) tuple per processed audio stream.
        self.calls: list[tuple[SpeechMetadata, StreamReader]] = []

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return ["en"]

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bitrates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: StreamReader
    ) -> SpeechResult:
        """Record the call and return a canned speech result.

        Returns an ERROR result with no text when ``fail_process_audio`` is
        set, otherwise a SUCCESS result with the text ``"test"``.
        """
        self.calls.append((metadata, stream))
        if self.fail_process_audio:
            return SpeechResult(None, SpeechResultState.ERROR)

        return SpeechResult("test", SpeechResultState.SUCCESS)
|
|
|
|
|
|
@pytest.fixture
def test_provider():
    """Provide a fresh TestProvider instance for each test."""
    provider = TestProvider()
    return provider
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def mock_setup(hass, test_provider):
    """Register a mocked "test" STT platform and set up the stt component."""
    # The mocked platform hands back the shared test provider as its engine.
    platform = Mock(async_get_engine=AsyncMock(return_value=test_provider))
    mock_platform(hass, "test.stt", platform)

    stt_config = {"stt": {"platform": "test"}}
    assert await async_setup_component(hass, "stt", stt_config)
|
|
|
|
|
|
async def test_get_provider_info(hass, hass_client):
    """Test getting the capabilities of a registered engine."""
    expected_info = {
        "languages": ["en"],
        "formats": ["wav", "ogg"],
        "codecs": ["pcm", "opus"],
        "sample_rates": [16000],
        "bit_rates": [16],
        "channels": [1],
    }

    api = await hass_client()
    resp = await api.get("/api/stt/test")

    assert resp.status == HTTPStatus.OK
    assert await resp.json() == expected_info
|
|
|
|
|
|
async def test_get_non_existing_provider_info(hass, hass_client):
    """Test getting info for an engine that doesn't exist."""
    api = await hass_client()
    resp = await api.get("/api/stt/not_exist")

    assert resp.status == HTTPStatus.NOT_FOUND
|
|
|
|
|
|
async def test_stream_audio(hass, hass_client):
    """Test posting audio metadata to the engine and reading the result."""
    speech_content = "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en"
    request_headers = {"X-Speech-Content": speech_content}

    api = await hass_client()
    resp = await api.post("/api/stt/test", headers=request_headers)

    assert resp.status == HTTPStatus.OK
    assert await resp.json() == {"text": "test", "result": "success"}
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("header", "status", "error"),
    [
        # No metadata header at all.
        (None, 400, "Missing X-Speech-Content header"),
        # A channel value that is not a member of AudioChannels.
        (
            "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en",
            400,
            "100 is not a valid AudioChannels",
        ),
        # Header present but required fields are left out.
        (
            "format=wav; codec=pcm; sample_rate=16000",
            400,
            "Missing language in X-Speech-Content header",
        ),
    ],
)
async def test_metadata_errors(hass, hass_client, header, status, error):
    """Test that bad or missing speech metadata is rejected with an error."""
    api = await hass_client()
    # Only send the metadata header when the case supplies one.
    request_headers = {"X-Speech-Content": header} if header else {}

    resp = await api.post("/api/stt/test", headers=request_headers)

    assert resp.status == status
    assert await resp.text() == error
|
|
|
|
|
|
async def test_get_provider(hass, test_provider):
    """Test that the registered provider can be looked up by engine name."""
    found = async_get_provider(hass, "test")
    assert test_provider == found
|