2019-11-04 12:10:42 +00:00
|
|
|
"""Test STT component setup."""
|
2023-04-13 17:58:35 +00:00
|
|
|
from collections.abc import AsyncIterable, Generator
|
2021-10-23 18:49:04 +00:00
|
|
|
from http import HTTPStatus
|
2023-04-04 12:52:36 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from unittest.mock import AsyncMock
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from homeassistant.components.stt import (
|
2023-04-13 17:58:35 +00:00
|
|
|
DOMAIN,
|
2022-09-24 07:58:01 +00:00
|
|
|
AudioBitRates,
|
|
|
|
AudioChannels,
|
|
|
|
AudioCodecs,
|
|
|
|
AudioFormats,
|
|
|
|
AudioSampleRates,
|
|
|
|
Provider,
|
|
|
|
SpeechMetadata,
|
|
|
|
SpeechResult,
|
|
|
|
SpeechResultState,
|
2023-04-13 17:58:35 +00:00
|
|
|
SpeechToTextEntity,
|
2022-09-24 07:58:01 +00:00
|
|
|
async_get_provider,
|
|
|
|
)
|
2023-04-13 17:58:35 +00:00
|
|
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
|
|
|
from homeassistant.core import HomeAssistant, State
|
|
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
2019-12-09 13:38:01 +00:00
|
|
|
from homeassistant.setup import async_setup_component
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
from .common import mock_stt_entity_platform, mock_stt_platform
|
2023-04-04 12:52:36 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
from tests.common import (
|
|
|
|
MockConfigEntry,
|
|
|
|
MockModule,
|
|
|
|
mock_config_flow,
|
|
|
|
mock_integration,
|
|
|
|
mock_platform,
|
|
|
|
mock_restore_cache,
|
|
|
|
)
|
2023-02-07 15:06:53 +00:00
|
|
|
from tests.typing import ClientSessionGenerator
|
2022-09-24 07:58:01 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
TEST_DOMAIN = "test"
|
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
class BaseProvider:
|
2023-02-07 15:06:53 +00:00
|
|
|
"""Mock provider."""
|
2022-09-24 07:58:01 +00:00
|
|
|
|
|
|
|
fail_process_audio = False
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
"""Init test provider."""
|
2023-04-04 12:52:36 +00:00
|
|
|
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
|
2022-09-24 07:58:01 +00:00
|
|
|
|
|
|
|
@property
|
2023-02-07 15:06:53 +00:00
|
|
|
def supported_languages(self) -> list[str]:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Return a list of supported languages."""
|
|
|
|
return ["en"]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def supported_formats(self) -> list[AudioFormats]:
|
|
|
|
"""Return a list of supported formats."""
|
|
|
|
return [AudioFormats.WAV, AudioFormats.OGG]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def supported_codecs(self) -> list[AudioCodecs]:
|
|
|
|
"""Return a list of supported codecs."""
|
|
|
|
return [AudioCodecs.PCM, AudioCodecs.OPUS]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
|
|
|
"""Return a list of supported bitrates."""
|
|
|
|
return [AudioBitRates.BITRATE_16]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
|
|
|
"""Return a list of supported samplerates."""
|
|
|
|
return [AudioSampleRates.SAMPLERATE_16000]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def supported_channels(self) -> list[AudioChannels]:
|
|
|
|
"""Return a list of supported channels."""
|
|
|
|
return [AudioChannels.CHANNEL_MONO]
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
async def async_process_audio_stream(
|
2023-03-23 18:44:19 +00:00
|
|
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
2022-09-24 07:58:01 +00:00
|
|
|
) -> SpeechResult:
|
|
|
|
"""Process an audio stream."""
|
|
|
|
self.calls.append((metadata, stream))
|
|
|
|
if self.fail_process_audio:
|
|
|
|
return SpeechResult(None, SpeechResultState.ERROR)
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
return SpeechResult("test_result", SpeechResultState.SUCCESS)
|
|
|
|
|
|
|
|
|
|
|
|
class MockProvider(BaseProvider, Provider):
|
|
|
|
"""Mock provider."""
|
|
|
|
|
|
|
|
|
|
|
|
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
|
|
|
"""Mock provider entity."""
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
|
|
|
|
@pytest.fixture
|
2023-02-07 15:06:53 +00:00
|
|
|
def mock_provider() -> MockProvider:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test provider fixture."""
|
2023-02-07 15:06:53 +00:00
|
|
|
return MockProvider()
|
2022-09-24 07:58:01 +00:00
|
|
|
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
@pytest.fixture
|
|
|
|
def mock_provider_entity() -> MockProviderEntity:
|
|
|
|
"""Test provider entity fixture."""
|
|
|
|
return MockProviderEntity()
|
|
|
|
|
|
|
|
|
|
|
|
class STTFlow(ConfigFlow):
|
|
|
|
"""Test flow."""
|
|
|
|
|
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
@pytest.fixture(autouse=True)
|
2023-04-13 17:58:35 +00:00
|
|
|
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
|
|
|
"""Mock config flow."""
|
|
|
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
|
|
|
|
|
|
|
with mock_config_flow(TEST_DOMAIN, STTFlow):
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(name="setup")
|
|
|
|
async def setup_fixture(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
tmp_path: Path,
|
|
|
|
request: pytest.FixtureRequest,
|
|
|
|
) -> None:
|
|
|
|
"""Set up the test environment."""
|
|
|
|
if request.param == "mock_setup":
|
|
|
|
await mock_setup(hass, tmp_path, MockProvider())
|
|
|
|
elif request.param == "mock_config_entry_setup":
|
|
|
|
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
|
|
|
|
else:
|
|
|
|
raise RuntimeError("Invalid setup fixture")
|
|
|
|
|
|
|
|
|
2023-04-04 12:52:36 +00:00
|
|
|
async def mock_setup(
|
2023-04-13 17:58:35 +00:00
|
|
|
hass: HomeAssistant,
|
|
|
|
tmp_path: Path,
|
|
|
|
mock_provider: MockProvider,
|
2023-04-04 12:52:36 +00:00
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Set up a test provider."""
|
2023-04-04 12:52:36 +00:00
|
|
|
mock_stt_platform(
|
|
|
|
hass,
|
|
|
|
tmp_path,
|
2023-04-13 17:58:35 +00:00
|
|
|
TEST_DOMAIN,
|
2023-04-04 12:52:36 +00:00
|
|
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
2022-09-24 07:58:01 +00:00
|
|
|
)
|
2023-04-13 17:58:35 +00:00
|
|
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
|
|
|
|
await hass.async_block_till_done()
|
|
|
|
|
|
|
|
|
|
|
|
async def mock_config_entry_setup(
|
|
|
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
|
|
|
) -> MockConfigEntry:
|
|
|
|
"""Set up a test provider via config entry."""
|
|
|
|
|
|
|
|
async def async_setup_entry_init(
|
|
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
|
|
) -> bool:
|
|
|
|
"""Set up test config entry."""
|
|
|
|
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
|
|
|
|
return True
|
|
|
|
|
|
|
|
async def async_unload_entry_init(
|
|
|
|
hass: HomeAssistant, config_entry: ConfigEntry
|
|
|
|
) -> bool:
|
|
|
|
"""Unload up test config entry."""
|
|
|
|
await hass.config_entries.async_forward_entry_unload(config_entry, DOMAIN)
|
|
|
|
return True
|
|
|
|
|
|
|
|
mock_integration(
|
|
|
|
hass,
|
|
|
|
MockModule(
|
|
|
|
TEST_DOMAIN,
|
|
|
|
async_setup_entry=async_setup_entry_init,
|
|
|
|
async_unload_entry=async_unload_entry_init,
|
|
|
|
),
|
|
|
|
)
|
|
|
|
|
|
|
|
async def async_setup_entry_platform(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
config_entry: ConfigEntry,
|
|
|
|
async_add_entities: AddEntitiesCallback,
|
|
|
|
) -> None:
|
|
|
|
"""Set up test stt platform via config entry."""
|
|
|
|
async_add_entities([mock_provider_entity])
|
|
|
|
|
|
|
|
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
|
|
|
|
|
|
|
|
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
|
|
|
config_entry.add_to_hass(hass)
|
|
|
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
|
|
|
await hass.async_block_till_done()
|
|
|
|
|
|
|
|
return config_entry
|
2022-09-24 07:58:01 +00:00
|
|
|
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
|
|
|
)
|
2023-02-07 15:06:53 +00:00
|
|
|
async def test_get_provider_info(
|
2023-04-13 17:58:35 +00:00
|
|
|
hass: HomeAssistant,
|
|
|
|
hass_client: ClientSessionGenerator,
|
|
|
|
setup: str,
|
2023-02-07 15:06:53 +00:00
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test engine that doesn't exist."""
|
2019-11-04 12:10:42 +00:00
|
|
|
client = await hass_client()
|
2023-04-13 17:58:35 +00:00
|
|
|
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
|
2022-09-24 07:58:01 +00:00
|
|
|
assert response.status == HTTPStatus.OK
|
|
|
|
assert await response.json() == {
|
|
|
|
"languages": ["en"],
|
|
|
|
"formats": ["wav", "ogg"],
|
|
|
|
"codecs": ["pcm", "opus"],
|
|
|
|
"sample_rates": [16000],
|
|
|
|
"bit_rates": [16],
|
|
|
|
"channels": [1],
|
|
|
|
}
|
2019-11-04 12:10:42 +00:00
|
|
|
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
|
|
|
)
|
|
|
|
async def test_non_existing_provider(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
hass_client: ClientSessionGenerator,
|
|
|
|
setup: str,
|
2023-02-07 15:06:53 +00:00
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test streaming to engine that doesn't exist."""
|
|
|
|
client = await hass_client()
|
2023-04-13 17:58:35 +00:00
|
|
|
|
2022-09-24 07:58:01 +00:00
|
|
|
response = await client.get("/api/stt/not_exist")
|
2021-10-23 18:49:04 +00:00
|
|
|
assert response.status == HTTPStatus.NOT_FOUND
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
response = await client.post(
|
|
|
|
"/api/stt/not_exist",
|
|
|
|
headers={
|
|
|
|
"X-Speech-Content": (
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
|
|
|
" language=en"
|
|
|
|
)
|
|
|
|
},
|
|
|
|
)
|
|
|
|
assert response.status == HTTPStatus.NOT_FOUND
|
|
|
|
|
2019-11-04 12:10:42 +00:00
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
|
|
|
)
|
2023-02-07 15:06:53 +00:00
|
|
|
async def test_stream_audio(
|
2023-04-13 17:58:35 +00:00
|
|
|
hass: HomeAssistant,
|
|
|
|
hass_client: ClientSessionGenerator,
|
|
|
|
setup: str,
|
2023-02-07 15:06:53 +00:00
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test streaming audio and getting response."""
|
2019-11-04 12:10:42 +00:00
|
|
|
client = await hass_client()
|
2022-09-24 07:58:01 +00:00
|
|
|
response = await client.post(
|
2023-04-13 17:58:35 +00:00
|
|
|
f"/api/stt/{TEST_DOMAIN}",
|
2022-09-24 07:58:01 +00:00
|
|
|
headers={
|
2023-01-26 00:23:53 +00:00
|
|
|
"X-Speech-Content": (
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
|
|
|
" language=en"
|
|
|
|
)
|
2022-09-24 07:58:01 +00:00
|
|
|
},
|
|
|
|
)
|
|
|
|
assert response.status == HTTPStatus.OK
|
2023-04-13 17:58:35 +00:00
|
|
|
assert await response.json() == {"text": "test_result", "result": "success"}
|
2019-11-04 12:10:42 +00:00
|
|
|
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
|
|
|
)
|
2022-09-24 07:58:01 +00:00
|
|
|
@pytest.mark.parametrize(
|
2023-02-15 13:09:50 +00:00
|
|
|
("header", "status", "error"),
|
2022-09-24 07:58:01 +00:00
|
|
|
(
|
|
|
|
(None, 400, "Missing X-Speech-Content header"),
|
2023-04-13 17:58:35 +00:00
|
|
|
(
|
|
|
|
(
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
|
|
|
" language=en; unknown=1"
|
|
|
|
),
|
|
|
|
400,
|
|
|
|
"Invalid field: unknown",
|
|
|
|
),
|
2022-09-24 07:58:01 +00:00
|
|
|
(
|
2023-01-26 00:23:53 +00:00
|
|
|
(
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
|
|
|
" language=en"
|
|
|
|
),
|
2022-09-24 07:58:01 +00:00
|
|
|
400,
|
2023-04-13 17:58:35 +00:00
|
|
|
"Wrong format of X-Speech-Content: 100 is not a valid AudioChannels",
|
|
|
|
),
|
|
|
|
(
|
|
|
|
(
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=bad channel;"
|
|
|
|
" language=en"
|
|
|
|
),
|
|
|
|
400,
|
|
|
|
"Wrong format of X-Speech-Content: invalid literal for int() with base 10: 'bad channel'",
|
2022-09-24 07:58:01 +00:00
|
|
|
),
|
|
|
|
(
|
|
|
|
"format=wav; codec=pcm; sample_rate=16000",
|
|
|
|
400,
|
|
|
|
"Missing language in X-Speech-Content header",
|
|
|
|
),
|
|
|
|
),
|
|
|
|
)
|
2023-02-07 15:06:53 +00:00
|
|
|
async def test_metadata_errors(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
hass_client: ClientSessionGenerator,
|
|
|
|
header: str | None,
|
|
|
|
status: int,
|
|
|
|
error: str,
|
2023-04-13 17:58:35 +00:00
|
|
|
setup: str,
|
2023-02-07 15:06:53 +00:00
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test metadata errors."""
|
|
|
|
client = await hass_client()
|
2023-02-07 15:06:53 +00:00
|
|
|
headers: dict[str, str] = {}
|
2022-09-24 07:58:01 +00:00
|
|
|
if header:
|
|
|
|
headers["X-Speech-Content"] = header
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
|
2022-09-24 07:58:01 +00:00
|
|
|
assert response.status == status
|
|
|
|
assert await response.text() == error
|
|
|
|
|
|
|
|
|
2023-04-13 17:58:35 +00:00
|
|
|
async def test_get_provider(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
tmp_path: Path,
|
|
|
|
mock_provider: MockProvider,
|
|
|
|
) -> None:
|
2022-09-24 07:58:01 +00:00
|
|
|
"""Test we can get STT providers."""
|
2023-04-13 17:58:35 +00:00
|
|
|
await mock_setup(hass, tmp_path, mock_provider)
|
|
|
|
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
|
|
|
|
|
|
|
|
|
|
|
async def test_config_entry_unload(
|
|
|
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
|
|
|
) -> None:
|
|
|
|
"""Test we can unload config entry."""
|
|
|
|
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
|
|
|
assert config_entry.state == ConfigEntryState.LOADED
|
|
|
|
await hass.config_entries.async_unload(config_entry.entry_id)
|
|
|
|
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
|
|
|
|
|
|
|
|
|
|
|
def test_entity_name_raises_before_addition(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
tmp_path: Path,
|
|
|
|
mock_provider_entity: MockProviderEntity,
|
|
|
|
) -> None:
|
|
|
|
"""Test entity name raises before addition to Home Assistant."""
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
mock_provider_entity.name # pylint: disable=pointless-statement
|
|
|
|
|
|
|
|
|
|
|
|
async def test_restore_state(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
tmp_path: Path,
|
|
|
|
mock_provider_entity: MockProviderEntity,
|
|
|
|
) -> None:
|
|
|
|
"""Test we restore state in the integration."""
|
|
|
|
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
|
|
|
|
timestamp = "2023-01-01T23:59:59+00:00"
|
|
|
|
mock_restore_cache(hass, (State(entity_id, timestamp),))
|
|
|
|
|
|
|
|
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
|
|
|
await hass.async_block_till_done()
|
|
|
|
|
|
|
|
assert config_entry.state == ConfigEntryState.LOADED
|
|
|
|
state = hass.states.get(entity_id)
|
|
|
|
assert state
|
|
|
|
assert state.state == timestamp
|