core/tests/components/assist_pipeline/conftest.py

266 lines
7.7 KiB
Python

"""Test fixtures for voice assistant."""
from __future__ import annotations
from collections.abc import AsyncIterable, Generator
from typing import Any
from unittest.mock import AsyncMock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
# Text that the mock STT provider fixtures below are constructed with.
_TRANSCRIPT = "test transcript"
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
    """Request the ``mock_tts_cache_dir`` fixture automatically for each test."""
    tts_cache_dir = mock_tts_cache_dir
    return tts_cache_dir
class BaseProvider:
    """Shared behaviour for the mock speech-to-text providers.

    Collects the non-empty audio chunks it is handed and always answers with
    the text supplied at construction time.
    """

    _supported_languages = ["en-US"]

    def __init__(self, text: str) -> None:
        """Store the canned transcript and start with no received audio."""
        self.received: list[bytes] = []
        self.text = text

    @property
    def supported_languages(self) -> list[str]:
        """Languages this provider reports as supported."""
        return self._supported_languages

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Audio container formats this provider reports as supported."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Audio codecs this provider reports as supported."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Bit rates this provider reports as supported."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Sample rates this provider reports as supported."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Channel layouts this provider reports as supported."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Record the streamed audio and return the canned transcript.

        Chunks are appended to ``self.received`` until the stream ends or an
        empty chunk is seen; the result always carries ``self.text`` with a
        SUCCESS state. ``metadata`` is not inspected.
        """
        async for chunk in stream:
            if chunk:
                self.received.append(chunk)
                continue
            # An empty chunk stops consumption of the stream.
            break
        result_state = stt.SpeechResultState.SUCCESS
        return stt.SpeechResult(self.text, result_state)
class MockSttProvider(BaseProvider, stt.Provider):
    """Mock STT provider: ``BaseProvider`` behaviour on an ``stt.Provider`` base."""
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
    """Mock STT entity: ``BaseProvider`` behaviour on a ``SpeechToTextEntity`` base."""

    _attr_name = "Mock STT"
class MockTTSProvider(tts.Provider):
    """TTS provider double that returns empty MP3 audio for any message."""

    name = "Test"
    _supported_languages = ["en-US"]
    _supported_voices = {
        "en-US": [
            tts.Voice("james_earl_jones", "James Earl Jones"),
            tts.Voice("fran_drescher", "Fran Drescher"),
        ]
    }

    @property
    def default_language(self) -> str:
        """Language reported as the default."""
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Languages this provider reports as supported."""
        return self._supported_languages

    @callback
    def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
        """Look up the voices for ``language``; ``None`` if it has no entry."""
        voices = self._supported_voices.get(language)
        return voices

    @property
    def supported_options(self) -> list[str]:
        """Option names this provider reports as supported."""
        return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> tts.TtsAudioType:
        """Return an ``("mp3", b"")`` pair regardless of the arguments."""
        extension, audio = "mp3", b""
        return (extension, audio)
class MockTTSPlatform(MockPlatform):
    """Mock platform module used for the ``test.tts`` platform."""

    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    def __init__(self, *, async_get_engine, **kwargs):
        """Initialise the base platform, then attach the engine factory."""
        super().__init__(**kwargs)
        engine_factory = async_get_engine
        self.async_get_engine = engine_factory
@pytest.fixture
async def mock_tts_provider(hass) -> MockTTSProvider:
    """Provide a fresh ``MockTTSProvider`` instance."""
    provider = MockTTSProvider()
    return provider
@pytest.fixture
async def mock_stt_provider() -> MockSttProvider:
    """Provide a ``MockSttProvider`` that answers with ``_TRANSCRIPT``."""
    provider = MockSttProvider(_TRANSCRIPT)
    return provider
@pytest.fixture
def mock_stt_provider_entity() -> MockSttProviderEntity:
    """Provide a ``MockSttProviderEntity`` that answers with ``_TRANSCRIPT``."""
    entity = MockSttProviderEntity(_TRANSCRIPT)
    return entity
class MockSttPlatform(MockPlatform):
    """Mock platform module used for the ``test.stt`` platform."""

    def __init__(self, *, async_get_engine, **kwargs):
        """Initialise the base platform, then attach the engine factory."""
        super().__init__(**kwargs)
        engine_factory = async_get_engine
        self.async_get_engine = engine_factory
class MockFlow(ConfigFlow):
    """Empty config flow registered for the mock ``test`` domain."""
@pytest.fixture
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
    """Mock the ``test`` config flow with ``MockFlow`` while a test runs."""
    mock_platform(hass, "test.config_flow")
    flow_patch = mock_config_flow("test", MockFlow)
    with flow_patch:
        # Hand control to the test while the mocked flow is active.
        yield
@pytest.fixture
async def init_supporting_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_stt_provider_entity: MockSttProviderEntity,
    mock_tts_provider: MockTTSProvider,
    config_flow_fixture,
):
    """Set up the components the assist pipeline tests rely on.

    Registers a mock ``test`` integration with ``tts`` and ``stt`` platforms
    backed by the mock providers, sets up the ``homeassistant``, TTS, STT and
    ``media_source`` components, and finally loads a ``test`` config entry
    whose STT platform adds ``mock_stt_provider_entity``.
    """

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up test config entry."""
        # Use the plural API: the singular ``async_forward_entry_setup`` is
        # deprecated in Home Assistant core.
        await hass.config_entries.async_forward_entry_setups(
            config_entry, [stt.DOMAIN]
        )
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload test config entry."""
        await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
        return True

    async def async_setup_entry_stt_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test stt platform via config entry."""
        async_add_entities([mock_stt_provider_entity])

    # The mock integration must be registered before its platforms are set up.
    mock_integration(
        hass,
        MockModule(
            "test",
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )
    mock_platform(
        hass,
        "test.tts",
        MockTTSPlatform(
            async_get_engine=AsyncMock(return_value=mock_tts_provider),
        ),
    )
    mock_platform(
        hass,
        "test.stt",
        MockSttPlatform(
            async_get_engine=AsyncMock(return_value=mock_stt_provider),
            async_setup_entry=async_setup_entry_stt_platform,
        ),
    )
    mock_platform(hass, "test.config_flow")

    assert await async_setup_component(hass, "homeassistant", {})
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
    assert await async_setup_component(hass, "media_source", {})

    # Loading the config entry triggers async_setup_entry_init, which forwards
    # the entry to the STT platform registered above.
    config_entry = MockConfigEntry(domain="test")
    config_entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
@pytest.fixture
async def init_components(hass: HomeAssistant, init_supporting_components):
    """Set up ``assist_pipeline`` after the supporting components are ready."""
    pipeline_ready = await async_setup_component(hass, "assist_pipeline", {})
    assert pipeline_ready
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
    """Expose the ``pipeline_store`` kept under ``hass.data[DOMAIN]``."""
    pipeline_data = hass.data[DOMAIN]
    return pipeline_data.pipeline_store