core/tests/components/tts/common.py

260 lines
7.5 KiB
Python

"""Provide common tests tools for tts."""
from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.tts import (
CONF_LANG,
DOMAIN as TTS_DOMAIN,
PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TtsAudioType,
Voice,
_get_cache_files,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
from tests.typing import ClientSessionGenerator
DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test"
def mock_tts_get_cache_files_fixture_helper():
"""Mock the list TTS cache function."""
with patch(
"homeassistant.components.tts._get_cache_files", return_value={}
) as mock_cache_files:
yield mock_cache_files
def mock_tts_init_cache_dir_fixture_helper(
init_tts_cache_dir_side_effect: Any,
) -> Generator[MagicMock, None, None]:
"""Mock the TTS cache dir in memory."""
with patch(
"homeassistant.components.tts._init_tts_cache_dir",
side_effect=init_tts_cache_dir_side_effect,
) as mock_cache_dir:
yield mock_cache_dir
def init_tts_cache_dir_side_effect_fixture_helper() -> Any:
"""Return the cache dir."""
return None
def mock_tts_cache_dir_fixture_helper(
tmp_path, mock_tts_init_cache_dir, mock_tts_get_cache_files, request
):
"""Mock the TTS cache dir with empty dir."""
mock_tts_init_cache_dir.return_value = str(tmp_path)
# Restore original get cache files behavior, we're working with a real dir.
mock_tts_get_cache_files.side_effect = _get_cache_files
yield tmp_path
if not hasattr(request.node, "rep_call") or request.node.rep_call.passed:
return
# Print contents of dir if failed
print("Content of dir for", request.node.nodeid) # noqa: T201
for fil in tmp_path.iterdir():
print(fil.relative_to(tmp_path)) # noqa: T201
# To show the log.
pytest.fail("Test failed, see log for details")
def tts_mutagen_mock_fixture_helper():
"""Mock writing tags."""
with patch(
"homeassistant.components.tts.SpeechManager.write_tags",
side_effect=lambda *args: args[1],
) as mock_write_tags:
yield mock_write_tags
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
async def retrieve_media(
hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
) -> HTTPStatus:
"""Get the media source url."""
url = await get_media_source_url(hass, media_content_id)
# Ensure media has been generated by requesting it
await hass.async_block_till_done()
client = await hass_client()
req = await client.get(url)
return req.status
class BaseProvider:
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
self._lang = lang
@property
def default_language(self) -> str:
"""Return the default language."""
return self._lang
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return list of supported languages."""
if language == "en-US":
return [
Voice("james_earl_jones", "James Earl Jones"),
Voice("fran_drescher", "Fran Drescher"),
]
return None
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS dat."""
return ("mp3", b"")
class MockProvider(BaseProvider, Provider):
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
super().__init__(lang)
self.name = "Test"
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
"""Test speech API provider."""
@property
def name(self) -> str:
"""Return the name of the entity."""
return "Test"
class MockTTS(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
)
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
"""Initialize."""
super().__init__(**kwargs)
self._provider = provider
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Set up a mock speech component."""
return self._provider
async def mock_setup(
hass: HomeAssistant,
mock_provider: MockProvider,
) -> None:
"""Set up a test provider."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", MockTTS(mock_provider))
await async_setup_component(
hass, TTS_DOMAIN, {TTS_DOMAIN: {"platform": TEST_DOMAIN}}
)
await hass.async_block_till_done()
async def mock_config_entry_setup(
hass: HomeAssistant, tts_entity: MockTTSEntity
) -> MockConfigEntry:
"""Set up a test tts platform via config entry."""
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, TTS_DOMAIN)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
async_add_entities([tts_entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry