"""Websocket tests for Voice Assistant integration.""" from collections.abc import AsyncGenerator from typing import Any from unittest.mock import ANY, patch import pytest from homeassistant.components import conversation from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, STORAGE_VERSION_MINOR, Pipeline, PipelineData, PipelineStorageCollection, PipelineStore, async_create_default_pipeline, async_get_pipeline, async_get_pipelines, async_update_pipeline, ) from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from . import MANY_LANGUAGES from .conftest import MockSttProvider, MockTTSProvider from tests.common import flush_store @pytest.fixture(autouse=True) async def delay_save_fixture() -> AsyncGenerator[None, None]: """Load the homeassistant integration.""" with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0): yield @pytest.fixture(autouse=True) async def load_homeassistant(hass) -> None: """Load the homeassistant integration.""" assert await async_setup_component(hass, "homeassistant", {}) async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: """Make sure that we can load/save data correctly.""" pipelines = [ { "conversation_engine": "conversation_engine_1", "conversation_language": "language_1", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": "wakeword_entity_1", "wake_word_id": "wakeword_id_1", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_1", "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", "wake_word_entity": "wakeword_entity_2", "wake_word_id": "wakeword_id_2", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, "tts_language": None, "tts_voice": None, "wake_word_entity": "wakeword_entity_3", "wake_word_id": "wakeword_id_3", }, ] pipeline_data: PipelineData = hass.data[DOMAIN] store1 = pipeline_data.pipeline_store pipeline_ids = [ (await store1.async_create_item(pipeline)).id for pipeline in pipelines ] assert len(store1.data) == 4 # 3 manually created plus a default pipeline assert store1.async_get_preferred_item() == list(store1.data)[0] await store1.async_delete_item(pipeline_ids[1]) assert len(store1.data) == 3 store2 = PipelineStorageCollection( PipelineStore( hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR ) ) await flush_store(store1.store) await store2.async_load() assert len(store2.data) == 3 assert store1.data is not store2.data assert store1.data == store2.data assert store1.async_get_preferred_item() == store2.async_get_preferred_item() async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines on start.""" id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY" hass_storage[STORAGE_KEY] = { "version": STORAGE_VERSION, "minor_version": STORAGE_VERSION_MINOR, "key": "assist_pipeline.pipelines", "data": { "items": [ { "conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT, "conversation_language": "language_1", "id": id_1, "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", "wake_word_entity": "wakeword_entity_1", "wake_word_id": "wakeword_id_1", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_2", "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", "wake_word_entity": "wakeword_entity_2", "wake_word_id": "wakeword_id_2", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, "tts_language": None, "tts_voice": None, "wake_word_entity": "wakeword_entity_3", "wake_word_id": "wakeword_id_3", }, ], "preferred_item": id_1, }, } assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == id_1 assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT async def test_migrate_pipeline_store( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines from an older version.""" hass_storage[STORAGE_KEY] = { "version": 1, "minor_version": 1, "key": "assist_pipeline.pipelines", "data": { "items": [ { "conversation_engine": "conversation_engine_1", "conversation_language": "language_1", "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", "stt_language": "language_1", "tts_engine": "tts_engine_1", "tts_language": "language_1", "tts_voice": "Arnold Schwarzenegger", }, { "conversation_engine": "conversation_engine_2", "conversation_language": "language_2", "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", "stt_language": "language_2", "tts_engine": "tts_engine_2", "tts_language": "language_2", "tts_voice": "The Voice", }, { "conversation_engine": "conversation_engine_3", "conversation_language": "language_3", "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", "language": "language_3", "name": "name_3", "stt_engine": None, "stt_language": None, "tts_engine": None, "tts_language": None, "tts_voice": None, }, ], "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", }, } assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" async def test_create_default_pipeline( hass: HomeAssistant, init_supporting_components ) -> None: """Test async_create_default_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 assert ( await async_create_default_pipeline( hass, stt_engine_id="bla", tts_engine_id="bla", pipeline_name="Bla pipeline", ) is None ) assert await async_create_default_pipeline( hass, stt_engine_id="test", tts_engine_id="test", pipeline_name="Test pipeline", ) == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Test pipeline", stt_engine="test", stt_language="en-US", tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", wake_word_entity=None, wake_word_id=None, ) async def test_get_pipeline(hass: HomeAssistant) -> None: """Test async_get_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Test we get the preferred pipeline if none is specified pipeline = async_get_pipeline(hass, None) assert pipeline.id == store.async_get_preferred_item() # Test getting a specific pipeline assert pipeline is async_get_pipeline(hass, pipeline.id) async def test_get_pipelines(hass: HomeAssistant) -> None: """Test async_get_pipelines.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 pipelines = async_get_pipelines(hass) assert list(pipelines) == [ Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) ] @pytest.mark.parametrize( ("ha_language", "ha_country", "conv_language", "pipeline_language"), [ ("en", None, "en", "en"), ("de", "de", "de", "de"), ("de", "ch", "de-CH", "de"), ("en", "us", "en", "en"), ("en", "uk", "en", "en"), ("pt", "pt", "pt", "pt"), ("pt", "br", "pt-br", "pt"), ], ) async def test_default_pipeline_no_stt_tts( hass: HomeAssistant, ha_language: str, ha_country: str | None, conv_language: str, pipeline_language: str, ) -> None: """Test async_get_pipeline.""" hass.config.country = ha_country hass.config.language = ha_language assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) @pytest.mark.parametrize( ( "ha_language", "ha_country", "conv_language", "pipeline_language", "stt_language", "tts_language", ), [ ("en", None, "en", "en", "en", "en"), ("de", "de", "de", "de", "de", "de"), ("de", "ch", "de-CH", "de", "de-CH", "de-CH"), ("en", "us", "en", "en", "en", "en"), ("en", "uk", "en", "en", "en", "en"), ("pt", "pt", "pt", "pt", "pt", "pt"), ("pt", "br", "pt-br", "pt", "pt-br", "pt-br"), ], ) async def test_default_pipeline( hass: HomeAssistant, init_supporting_components, mock_stt_provider: MockSttProvider, mock_tts_provider: MockTTSProvider, ha_language: str, ha_country: str | None, conv_language: str, pipeline_language: str, stt_language: str, tts_language: str, ) -> None: """Test async_get_pipeline.""" hass.config.country = ha_country hass.config.language = ha_language with ( patch.object(mock_stt_provider, "_supported_languages", MANY_LANGUAGES), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES), ): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, name="Home Assistant", stt_engine="test", stt_language=stt_language, tts_engine="test", tts_language=tts_language, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) async def test_default_pipeline_unsupported_stt_language( hass: HomeAssistant, init_supporting_components, mock_stt_provider: MockSttProvider, ) -> None: """Test async_get_pipeline.""" with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine="test", tts_language="en-US", tts_voice="james_earl_jones", wake_word_entity=None, wake_word_id=None, ) async def test_default_pipeline_unsupported_tts_language( hass: HomeAssistant, init_supporting_components, mock_tts_provider: MockTTSProvider, ) -> None: """Test async_get_pipeline.""" with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]): assert await async_setup_component(hass, "assist_pipeline", {}) pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 1 # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", name="Home Assistant", stt_engine="test", stt_language="en-US", tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) async def test_update_pipeline( hass: HomeAssistant, hass_storage: dict[str, Any], ) -> None: """Test async_update_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) assert pipelines == [ Pipeline( conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", name="Home Assistant", stt_engine=None, stt_language=None, tts_engine=None, tts_language=None, tts_voice=None, wake_word_entity=None, wake_word_id=None, ) ] pipeline = pipelines[0] await async_update_pipeline( hass, pipeline, conversation_engine="homeassistant_1", conversation_language="de", language="de", name="Home Assistant 1", stt_engine="stt.test_1", stt_language="de", tts_engine="test_1", tts_language="de", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) pipeline = pipelines[0] assert pipelines == [ Pipeline( conversation_engine="homeassistant_1", conversation_language="de", id=pipeline.id, language="de", name="Home Assistant 1", stt_engine="stt.test_1", stt_language="de", tts_engine="test_1", tts_language="de", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) ] assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { "conversation_engine": "homeassistant_1", "conversation_language": "de", "id": pipeline.id, "language": "de", "name": "Home Assistant 1", "stt_engine": "stt.test_1", "stt_language": "de", "tts_engine": "test_1", "tts_language": "de", "tts_voice": "test_voice", "wake_word_entity": "wake_work.test_1", "wake_word_id": "wake_word_id_1", } await async_update_pipeline( hass, pipeline, stt_engine="stt.test_2", stt_language="en", tts_engine="test_2", tts_language="en", ) pipelines = async_get_pipelines(hass) pipelines = list(pipelines) assert pipelines == [ Pipeline( conversation_engine="homeassistant_1", conversation_language="de", id=pipeline.id, language="de", name="Home Assistant 1", stt_engine="stt.test_2", stt_language="en", tts_engine="test_2", tts_language="en", tts_voice="test_voice", wake_word_entity="wake_work.test_1", wake_word_id="wake_word_id_1", ) ] assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { "conversation_engine": "homeassistant_1", "conversation_language": "de", "id": pipeline.id, "language": "de", "name": "Home Assistant 1", "stt_engine": "stt.test_2", "stt_language": "en", "tts_engine": "test_2", "tts_language": "en", "tts_voice": "test_voice", "wake_word_entity": "wake_work.test_1", "wake_word_id": "wake_word_id_1", }