# core/tests/components/assist_pipeline/test_pipeline.py

"""Tests for the Assist pipeline storage and helper functions."""
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import ANY, patch
import pytest
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
STORAGE_VERSION_MINOR,
Pipeline,
PipelineData,
PipelineStorageCollection,
PipelineStore,
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
from .conftest import MockSttProvider, MockTTSProvider
from tests.common import flush_store
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
    """Patch the storage collection save delay to 0 for every test.

    Yields:
        None; the patch stays active for the duration of the test.
    """
    # Presumably lets tests read hass_storage right after mutating a pipeline
    # instead of waiting for a delayed save — confirm against
    # homeassistant.helpers.collection if this changes.
    with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
        yield
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
    """Load the homeassistant integration before each test.

    Args:
        hass: The Home Assistant test instance.
    """
    assert await async_setup_component(hass, "homeassistant", {})
async def test_load_pipelines(hass: HomeAssistant, init_components) -> None:
    """Make sure that we can load/save data correctly.

    Creates three pipelines, deletes one, then loads the persisted data into a
    fresh collection and checks both collections agree, including which
    pipelines survived and which pipeline is preferred.
    """
    pipelines = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store1 = pipeline_data.pipeline_store
    # Items are created one at a time, in list order.
    pipeline_ids = [
        (await store1.async_create_item(pipeline)).id for pipeline in pipelines
    ]
    assert len(store1.data) == 4  # 3 manually created plus a default pipeline
    # The first stored item is expected to be the preferred (default) pipeline.
    assert store1.async_get_preferred_item() == list(store1.data)[0]

    await store1.async_delete_item(pipeline_ids[1])
    assert len(store1.data) == 3
    assert pipeline_ids[1] not in store1.data

    # Reload the persisted data into an independent collection.
    store2 = PipelineStorageCollection(
        PipelineStore(
            hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR
        )
    )
    await flush_store(store1.store)
    await store2.async_load()

    assert len(store2.data) == 3
    assert store1.data is not store2.data
    assert store1.data == store2.data
    # The deleted pipeline must not come back; the other two must survive.
    assert pipeline_ids[1] not in store2.data
    assert {pipeline_ids[0], pipeline_ids[2]} <= set(store2.data)
    assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
async def test_loading_pipelines_from_storage(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test that pipelines already in storage are loaded when the integration starts."""
    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    stored_items = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "id": preferred_id,
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]
    hass_storage[STORAGE_KEY] = {
        "version": STORAGE_VERSION,
        "minor_version": STORAGE_VERSION_MINOR,
        "key": "assist_pipeline.pipelines",
        "data": {"items": stored_items, "preferred_item": preferred_id},
    }

    assert await async_setup_component(hass, "assist_pipeline", {})

    # Every stored item is loaded and the stored preference is kept.
    store = hass.data[DOMAIN].pipeline_store
    assert len(store.data) == len(stored_items)
    assert store.async_get_preferred_item() == preferred_id
async def test_migrate_pipeline_store(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test loading stored pipelines from an older version.

    The version 1.1 items below have no ``wake_word_entity`` or
    ``wake_word_id`` keys. After loading, those fields are expected to be None.
    """
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "minor_version": 1,
        "key": "assist_pipeline.pipelines",
        "data": {
            "items": [
                {
                    "conversation_engine": "conversation_engine_1",
                    "conversation_language": "language_1",
                    "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
                    "language": "language_1",
                    "name": "name_1",
                    "stt_engine": "stt_engine_1",
                    "stt_language": "language_1",
                    "tts_engine": "tts_engine_1",
                    "tts_language": "language_1",
                    "tts_voice": "Arnold Schwarzenegger",
                },
                {
                    "conversation_engine": "conversation_engine_2",
                    "conversation_language": "language_2",
                    "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
                    "language": "language_2",
                    "name": "name_2",
                    "stt_engine": "stt_engine_2",
                    "stt_language": "language_2",
                    "tts_engine": "tts_engine_2",
                    "tts_language": "language_2",
                    "tts_voice": "The Voice",
                },
                {
                    "conversation_engine": "conversation_engine_3",
                    "conversation_language": "language_3",
                    "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
                    "language": "language_3",
                    "name": "name_3",
                    "stt_engine": None,
                    "stt_language": None,
                    "tts_engine": None,
                    "tts_language": None,
                    "tts_voice": None,
                },
            ],
            "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
        },
    }
    assert await async_setup_component(hass, "assist_pipeline", {})
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 3
    assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    # Wake word fields absent from the old format must be filled in.
    for pipeline in store.data.values():
        assert pipeline.wake_word_entity is None
        assert pipeline.wake_word_id is None
async def test_create_default_pipeline(
    hass: HomeAssistant, init_supporting_components
) -> None:
    """Test async_create_default_pipeline with unknown and with known engine ids."""
    assert await async_setup_component(hass, "assist_pipeline", {})
    store = hass.data[DOMAIN].pipeline_store
    assert len(store.data) == 1

    # Engine ids "bla" are expected to yield no pipeline.
    unknown_engines_result = await async_create_default_pipeline(
        hass,
        stt_engine_id="bla",
        tts_engine_id="bla",
        pipeline_name="Bla pipeline",
    )
    assert unknown_engines_result is None

    # The "test" engines produce a fully populated pipeline.
    created = await async_create_default_pipeline(
        hass,
        stt_engine_id="test",
        tts_engine_id="test",
        pipeline_name="Test pipeline",
    )
    expected = Pipeline(
        conversation_engine="homeassistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Test pipeline",
        stt_engine="test",
        stt_language="en-US",
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert created == expected
async def test_get_pipeline(hass: HomeAssistant) -> None:
    """Test async_get_pipeline with and without an explicit pipeline id."""
    assert await async_setup_component(hass, "assist_pipeline", {})
    store = hass.data[DOMAIN].pipeline_store
    assert len(store.data) == 1

    # With no id, the preferred pipeline is returned.
    preferred = async_get_pipeline(hass, None)
    assert preferred.id == store.async_get_preferred_item()

    # Looking the same pipeline up by id returns the identical object.
    assert async_get_pipeline(hass, preferred.id) is preferred
async def test_get_pipelines(hass: HomeAssistant) -> None:
    """Test that async_get_pipelines lists only the default pipeline after setup."""
    assert await async_setup_component(hass, "assist_pipeline", {})
    store = hass.data[DOMAIN].pipeline_store
    assert len(store.data) == 1

    found = [*async_get_pipelines(hass)]
    default_pipeline = Pipeline(
        conversation_engine="homeassistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert found == [default_pipeline]
@pytest.mark.parametrize(
    ("ha_language", "ha_country", "conv_language", "pipeline_language"),
    [
        ("en", None, "en", "en"),
        ("de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de"),
        ("en", "us", "en", "en"),
        ("en", "uk", "en", "en"),
        ("pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-br", "pt"),
    ],
)
async def test_default_pipeline_no_stt_tts(
    hass: HomeAssistant,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
) -> None:
    """Test the default pipeline without the supporting STT/TTS test components.

    The conversation and pipeline languages are expected to follow the
    configured Home Assistant language and country. All STT/TTS fields are
    expected to stay None.
    """
    hass.config.country = ha_country
    hass.config.language = ha_language
    assert await async_setup_component(hass, "assist_pipeline", {})
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 1

    # Check the default pipeline
    pipeline = async_get_pipeline(hass, None)
    assert pipeline == Pipeline(
        conversation_engine="homeassistant",
        conversation_language=conv_language,
        id=pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
@pytest.mark.parametrize(
    (
        "ha_language",
        "ha_country",
        "conv_language",
        "pipeline_language",
        "stt_language",
        "tts_language",
    ),
    [
        ("en", None, "en", "en", "en", "en"),
        ("de", "de", "de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
        ("en", "us", "en", "en", "en", "en"),
        ("en", "uk", "en", "en", "en", "en"),
        ("pt", "pt", "pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
    ],
)
async def test_default_pipeline(
    hass: HomeAssistant,
    init_supporting_components,
    mock_stt_provider: MockSttProvider,
    mock_tts_provider: MockTTSProvider,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
    stt_language: str,
    tts_language: str,
) -> None:
    """Test that the default pipeline uses the mock "test" STT and TTS engines.

    Both mock providers are patched to support MANY_LANGUAGES. The STT/TTS
    languages are therefore expected to follow the configured Home Assistant
    language and country.
    """
    hass.config.country = ha_country
    hass.config.language = ha_language
    with patch.object(
        mock_stt_provider, "_supported_languages", MANY_LANGUAGES
    ), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES):
        assert await async_setup_component(hass, "assist_pipeline", {})
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 1

    # Check the default pipeline
    pipeline = async_get_pipeline(hass, None)
    assert pipeline == Pipeline(
        conversation_engine="homeassistant",
        conversation_language=conv_language,
        id=pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine="test",
        stt_language=stt_language,
        tts_engine="test",
        tts_language=tts_language,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
async def test_default_pipeline_unsupported_stt_language(
    hass: HomeAssistant,
    init_supporting_components,
    mock_stt_provider: MockSttProvider,
) -> None:
    """Test that the default pipeline skips STT when its language is unsupported.

    The mock STT provider only supports "smurfish". The default pipeline is
    therefore expected to have no STT engine, while TTS is still configured.
    """
    with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
        assert await async_setup_component(hass, "assist_pipeline", {})
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 1

    # Check the default pipeline
    pipeline = async_get_pipeline(hass, None)
    assert pipeline == Pipeline(
        conversation_engine="homeassistant",
        conversation_language="en",
        id=pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
async def test_default_pipeline_unsupported_tts_language(
    hass: HomeAssistant,
    init_supporting_components,
    mock_tts_provider: MockTTSProvider,
) -> None:
    """Test that the default pipeline skips TTS when its language is unsupported.

    The mock TTS provider only supports "smurfish". The default pipeline is
    therefore expected to have no TTS engine or voice, while STT is still
    configured.
    """
    with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]):
        assert await async_setup_component(hass, "assist_pipeline", {})
    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 1

    # Check the default pipeline
    pipeline = async_get_pipeline(hass, None)
    assert pipeline == Pipeline(
        conversation_engine="homeassistant",
        conversation_language="en",
        id=pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine="test",
        stt_language="en-US",
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
async def test_update_pipeline(
    hass: HomeAssistant,
    hass_storage: dict[str, Any],
) -> None:
    """Test async_update_pipeline.

    The default pipeline is updated twice. After each update, both the
    pipelines returned by async_get_pipelines and the single item in
    hass_storage must equal the merged field values.
    """
    assert await async_setup_component(hass, "assist_pipeline", {})

    current = list(async_get_pipelines(hass))
    assert current == [
        Pipeline(
            conversation_engine="homeassistant",
            conversation_language="en",
            id=ANY,
            language="en",
            name="Home Assistant",
            stt_engine=None,
            stt_language=None,
            tts_engine=None,
            tts_language=None,
            tts_voice=None,
            wake_word_entity=None,
            wake_word_id=None,
        )
    ]

    # First update: every field of the default pipeline changes.
    first_changes = {
        "conversation_engine": "homeassistant_1",
        "conversation_language": "de",
        "language": "de",
        "name": "Home Assistant 1",
        "stt_engine": "stt.test_1",
        "stt_language": "de",
        "tts_engine": "test_1",
        "tts_language": "de",
        "tts_voice": "test_voice",
        "wake_word_entity": "wake_work.test_1",
        "wake_word_id": "wake_word_id_1",
    }
    await async_update_pipeline(hass, current[0], **first_changes)

    current = list(async_get_pipelines(hass))
    pipeline = current[0]
    expected = {"id": pipeline.id, **first_changes}
    assert current == [Pipeline(**expected)]
    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    assert stored_items[0] == expected

    # Second update: only STT/TTS fields change; the rest must be preserved.
    second_changes = {
        "stt_engine": "stt.test_2",
        "stt_language": "en",
        "tts_engine": "test_2",
        "tts_language": "en",
    }
    await async_update_pipeline(hass, pipeline, **second_changes)

    current = list(async_get_pipelines(hass))
    expected = {**expected, **second_changes}
    assert current == [Pipeline(**expected)]
    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    assert stored_items[0] == expected