1788 lines
59 KiB
Python
1788 lines
59 KiB
Python
"""Websocket tests for Voice Assistant integration."""
|
|
|
|
from collections.abc import AsyncGenerator, Generator
|
|
from typing import Any
|
|
from unittest.mock import ANY, AsyncMock, Mock, patch
|
|
|
|
from hassil.recognize import Intent, IntentData, RecognizeResult
|
|
import pytest
|
|
from syrupy.assertion import SnapshotAssertion
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import (
|
|
assist_pipeline,
|
|
conversation,
|
|
media_source,
|
|
stt,
|
|
tts,
|
|
)
|
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
|
from homeassistant.components.assist_pipeline.pipeline import (
|
|
STORAGE_KEY,
|
|
STORAGE_VERSION,
|
|
STORAGE_VERSION_MINOR,
|
|
Pipeline,
|
|
PipelineData,
|
|
PipelineStorageCollection,
|
|
PipelineStore,
|
|
_async_local_fallback_intent_filter,
|
|
async_create_default_pipeline,
|
|
async_get_pipeline,
|
|
async_get_pipelines,
|
|
async_update_pipeline,
|
|
)
|
|
from homeassistant.const import MATCH_ALL
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.helpers import chat_session, intent, llm
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from . import MANY_LANGUAGES, process_events
|
|
from .conftest import (
|
|
MockSTTProvider,
|
|
MockSTTProviderEntity,
|
|
MockTTSEntity,
|
|
MockTTSProvider,
|
|
MockWakeWordEntity,
|
|
make_10ms_chunk,
|
|
)
|
|
|
|
from tests.common import flush_store
|
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None]:
    """Set the collection helper's SAVE_DELAY to 0 while a test runs."""
    # The previous docstring was copied from load_homeassistant; this fixture
    # only patches the save delay constant and restores it afterwards.
    with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
        yield
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
    """Set up the homeassistant integration before every test."""
    setup_ok = await async_setup_component(hass, "homeassistant", {})
    assert setup_ok
|
|
|
|
|
|
@pytest.fixture
async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None:
    """Turn off the registry "enabled by default" flag on the mock TTS entity."""
    # Tests request this fixture via usefixtures when they do not want the
    # TTS entity to be enabled by default.
    mock_tts_entity._attr_entity_registry_enabled_default = False
|
|
|
|
|
|
@pytest.mark.usefixtures("init_components")
async def test_load_pipelines(hass: HomeAssistant) -> None:
    """Check that pipelines saved by one store are read back by a second store."""
    new_pipelines = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]

    pipeline_data: PipelineData = hass.data[DOMAIN]
    source_store = pipeline_data.pipeline_store

    # Create the pipelines one after another and remember their generated ids
    created_ids = []
    for new_pipeline in new_pipelines:
        created = await source_store.async_create_item(new_pipeline)
        created_ids.append(created.id)

    # 3 manually created plus a default pipeline
    assert len(source_store.data) == 4
    first_stored_id = next(iter(source_store.data))
    assert source_store.async_get_preferred_item() == first_stored_id

    # Remove the second manually created pipeline
    await source_store.async_delete_item(created_ids[1])
    assert len(source_store.data) == 3

    # A second, independent collection reading the same storage key
    reloaded_store = PipelineStorageCollection(
        PipelineStore(
            hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR
        )
    )
    await flush_store(source_store.store)
    await reloaded_store.async_load()

    assert len(reloaded_store.data) == 3

    # Same content, but not the same dict object
    assert source_store.data is not reloaded_store.data
    assert source_store.data == reloaded_store.data
    preferred_before = source_store.async_get_preferred_item()
    assert preferred_before == reloaded_store.async_get_preferred_item()
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_chat_session_id() -> Generator[Mock]:
    """Make chat_session.ulid_now return the fixed value "mock-ulid"."""
    ulid_patcher = patch(
        "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
    )
    with ulid_patcher as mock_ulid_now:
        # Hand the mock to the test so it can be inspected or reconfigured
        yield mock_ulid_now
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_tts_token() -> Generator[None]:
    """Make secrets.token_urlsafe return "mocked-token" so TTS URLs are stable."""
    token_patcher = patch("secrets.token_urlsafe", return_value="mocked-token")
    with token_patcher:
        yield
|
|
|
|
|
|
async def test_loading_pipelines_from_storage(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Check that pipelines already present in storage are loaded at setup."""
    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    stored_items = [
        {
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": "language_1",
            "id": preferred_id,
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]
    # Seed the storage at the current version before the integration starts
    hass_storage[STORAGE_KEY] = {
        "version": STORAGE_VERSION,
        "minor_version": STORAGE_VERSION_MINOR,
        "key": "assist_pipeline.pipelines",
        "data": {"items": stored_items, "preferred_item": preferred_id},
    }

    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    loaded = pipeline_data.pipeline_store
    assert len(loaded.data) == 3
    assert loaded.async_get_preferred_item() == preferred_id
    preferred_pipeline = loaded.data[preferred_id]
    assert preferred_pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
|
|
|
|
|
|
async def test_migrate_pipeline_store(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Check that pipelines stored as version 1.1 are loaded at setup."""
    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    # These items carry no wake word fields
    old_items = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "id": preferred_id,
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
        },
    ]
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "minor_version": 1,
        "key": "assist_pipeline.pipelines",
        "data": {"items": old_items, "preferred_item": preferred_id},
    }

    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    migrated = pipeline_data.pipeline_store
    assert len(migrated.data) == 3
    assert migrated.async_get_preferred_item() == preferred_id
|
|
|
|
|
|
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
    """Exercise async_create_default_pipeline with two sets of engine ids."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    # With the "bla" engine ids no pipeline is returned
    rejected = await async_create_default_pipeline(
        hass,
        stt_engine_id="bla",
        tts_engine_id="bla",
        pipeline_name="Bla pipeline",
    )
    assert rejected is None

    # With the "test" engine ids a fully populated pipeline is returned
    created = await async_create_default_pipeline(
        hass,
        stt_engine_id="test",
        tts_engine_id="test",
        pipeline_name="Test pipeline",
    )
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Test pipeline",
        stt_engine="test",
        stt_language="en-US",
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert created == expected
|
|
|
|
|
|
async def test_get_pipeline(hass: HomeAssistant) -> None:
    """Check async_get_pipeline with and without an explicit pipeline id."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    pipeline_store = pipeline_data.pipeline_store
    assert len(pipeline_store.data) == 1

    # Passing None yields the preferred pipeline
    preferred = async_get_pipeline(hass, None)
    assert preferred.id == pipeline_store.async_get_preferred_item()

    # Passing its id yields the very same object
    fetched_by_id = async_get_pipeline(hass, preferred.id)
    assert preferred is fetched_by_id
|
|
|
|
|
|
async def test_get_pipelines(hass: HomeAssistant) -> None:
    """Check that async_get_pipelines yields only the default pipeline after setup."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    default_pipeline = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    found = list(async_get_pipelines(hass))
    assert found == [default_pipeline]
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("ha_language", "ha_country", "conv_language", "pipeline_language"),
    [
        ("en", None, "en", "en"),
        ("de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de"),
        ("en", "us", "en", "en"),
        ("en", "uk", "en", "en"),
        ("pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-BR", "pt"),
    ],
)
async def test_default_pipeline_no_stt_tts(
    hass: HomeAssistant,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
) -> None:
    """Check the default pipeline's languages when it has no STT or TTS engine."""
    # Configure language and country before the integration is set up
    hass.config.country = ha_country
    hass.config.language = ha_language
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    # The preferred pipeline is the default one created during setup
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    (
        "ha_language",
        "ha_country",
        "conv_language",
        "pipeline_language",
        "stt_language",
        "tts_language",
    ),
    [
        ("en", None, "en", "en", "en", "en"),
        ("de", "de", "de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
        ("en", "us", "en", "en", "en", "en"),
        ("en", "uk", "en", "en", "en", "en"),
        ("pt", "pt", "pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-BR", "pt", "pt-br", "pt-br"),
    ],
)
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline(
    hass: HomeAssistant,
    mock_stt_provider_entity: MockSTTProviderEntity,
    mock_tts_provider: MockTTSProvider,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
    stt_language: str,
    tts_language: str,
) -> None:
    """Check the default pipeline's languages when STT and TTS support many languages."""
    hass.config.country = ha_country
    hass.config.language = ha_language

    # Both speech engines advertise MANY_LANGUAGES while the integration is set up
    stt_languages_patch = patch.object(
        mock_stt_provider_entity, "_supported_languages", MANY_LANGUAGES
    )
    tts_languages_patch = patch.object(
        mock_tts_provider, "_supported_languages", MANY_LANGUAGES
    )
    with stt_languages_patch, tts_languages_patch:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    # The preferred pipeline is the default one created during setup
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine="stt.mock_stt",
        stt_language=stt_language,
        tts_engine="test",
        tts_language=tts_language,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
|
|
|
|
|
|
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_stt_language(
    hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
) -> None:
    """Check that the default pipeline has no STT when STT only supports "smurfish"."""
    stt_languages_patch = patch.object(
        mock_stt_provider_entity, "_supported_languages", ["smurfish"]
    )
    with stt_languages_patch:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    # STT fields are empty, TTS fields are still filled in
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
|
|
|
|
|
|
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_tts_language(
    hass: HomeAssistant, mock_tts_provider: MockTTSProvider
) -> None:
    """Check that the default pipeline has no TTS when TTS only supports "smurfish"."""
    tts_languages_patch = patch.object(
        mock_tts_provider, "_supported_languages", ["smurfish"]
    )
    with tts_languages_patch:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    pipeline_data: PipelineData = hass.data[DOMAIN]
    assert len(pipeline_data.pipeline_store.data) == 1

    # TTS fields are empty, STT fields are still filled in
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine="stt.mock_stt",
        stt_language="en-US",
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
|
|
|
|
|
|
async def test_update_pipeline(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Check async_update_pipeline for a full update followed by a partial one."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    # Only the default pipeline exists right after setup
    before = list(async_get_pipelines(hass))
    assert before == [
        Pipeline(
            conversation_engine="conversation.home_assistant",
            conversation_language="en",
            id=ANY,
            language="en",
            name="Home Assistant",
            stt_engine=None,
            stt_language=None,
            tts_engine=None,
            tts_language=None,
            tts_voice=None,
            wake_word_entity=None,
            wake_word_id=None,
        )
    ]

    # First update: every field except the id is replaced
    full_changes = {
        "conversation_engine": "homeassistant_1",
        "conversation_language": "de",
        "language": "de",
        "name": "Home Assistant 1",
        "stt_engine": "stt.test_1",
        "stt_language": "de",
        "tts_engine": "test_1",
        "tts_language": "de",
        "tts_voice": "test_voice",
        "wake_word_entity": "wake_work.test_1",
        "wake_word_id": "wake_word_id_1",
    }
    await async_update_pipeline(hass, before[0], **full_changes)

    after_full = list(async_get_pipelines(hass))
    updated = after_full[0]
    assert after_full == [Pipeline(id=updated.id, **full_changes)]

    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    # The stored item additionally carries the id and prefer_local_intents
    assert stored_items[0] == {
        **full_changes,
        "id": updated.id,
        "prefer_local_intents": False,
    }

    # Second update: only the speech engines and their languages change
    partial_changes = {
        "stt_engine": "stt.test_2",
        "stt_language": "en",
        "tts_engine": "test_2",
        "tts_language": "en",
    }
    await async_update_pipeline(hass, updated, **partial_changes)

    merged_changes = full_changes | partial_changes
    after_partial = list(async_get_pipelines(hass))
    assert after_partial == [Pipeline(id=updated.id, **merged_changes)]

    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    assert stored_items[0] == {
        **merged_changes,
        "id": updated.id,
        "prefer_local_intents": False,
    }
|
|
|
|
|
|
def test_fallback_intent_filter() -> None:
    """Check the local fallback intent filter's verdict for three intents."""
    # (intent name, value the filter must return)
    expectations = (
        (intent.INTENT_GET_STATE, True),
        (intent.INTENT_NEVERMIND, False),
        (intent.INTENT_TURN_ON, False),
    )
    for intent_name, expected_verdict in expectations:
        recognized = RecognizeResult(
            intent=Intent(intent_name),
            intent_data=IntentData([]),
            entities={},
            entities_list=[],
        )
        verdict = _async_local_fallback_intent_filter(recognized)
        assert verdict is expected_verdict, intent_name
|
|
|
|
|
|
async def test_wake_word_detection_aborted(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Run a wake-word pipeline whose pipeline is updated before execution.

    The events emitted by the run are compared against the stored snapshot.
    """
    recorded_events: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        # Four 10 ms chunks followed by an empty chunk
        for payload in (b"silence!", b"wake word!", b"part1", b"part2"):
            yield make_10ms_chunk(payload)
        yield b""

    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    speech_metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    pipeline_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=recorded_events.append,
        tts_audio_output=None,
        wake_word_settings=assist_pipeline.WakeWordSettings(
            audio_seconds_to_buffer=1.5
        ),
        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        session=mock_chat_session,
        device_id=None,
        stt_metadata=speech_metadata,
        stt_stream=audio_data(),
        run=pipeline_run,
    )
    await pipeline_input.validate()

    # Re-save the pipeline with its own data between validate() and execute().
    # NOTE(review): presumably this update is what aborts the run - confirm.
    unchanged_fields = pipeline.to_json()
    unchanged_fields.pop("id")
    await store.async_update_item(preferred_id, unchanged_fields)
    await pipeline_input.execute()

    assert process_events(recorded_events) == snapshot
|
|
|
|
|
|
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
    """Check that a pipeline run compares equal only to itself."""

    def event_callback(event):
        pass

    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)

    def build_run() -> assist_pipeline.pipeline.PipelineRun:
        # Every call uses the same arguments apart from a fresh Context
        return assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.STT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=event_callback,
        )

    first_run = build_run()
    second_run = build_run()

    assert first_run == first_run  # noqa: PLR0124
    # Identically configured runs are still different runs
    assert first_run != second_run
    # Comparing with an unrelated type is not equal either
    assert first_run != 1234
|
|
|
|
|
|
async def test_tts_audio_output(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test using tts_audio_output with wav sets options correctly.

    The mock_tts_entity fixture is annotated as MockTTSEntity to match the
    other TTS tests in this module that use the same fixture.
    """
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # A TTS-only run (start and end stage are both TTS) asking for wav output
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Verify TTS audio settings
    assert pipeline_input.run.tts_stream.options is not None
    assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert (
        pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
        == 16000
    )
    assert (
        pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
        == 1
    )

    with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

        # Ensure that no unsupported options were passed in
        assert mock_get_tts_audio.called
        options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
        extra_options = set(options).difference(mock_tts_entity.supported_options)
        assert len(extra_options) == 0, extra_options
|
|
|
|
|
|
async def test_tts_wav_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Check that tts_audio_output="wav" reaches get_tts_audio as preferred options."""
    client = await hass_client()
    setup_ok = await async_setup_component(hass, media_source.DOMAIN, {})
    assert setup_ok

    recorded_events: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    tts_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=recorded_events.append,
        tts_audio_output="wav",
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=tts_only_run,
    )
    await pipeline_input.validate()

    # Let the TTS entity advertise the preferred format options as supported
    supported_options = [
        *(mock_tts_entity.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]
    options_patch = patch.object(
        mock_tts_entity, "_supported_options", supported_options
    )
    audio_patch = patch.object(mock_tts_entity, "get_tts_audio")

    with options_patch, audio_patch as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in recorded_events:
            if event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue
            # The media URL has to be fetched for the TTS call to happen
            assert event.data
            await client.get(event.data["tts_output"]["url"])

        assert mock_get_tts_audio.called
        options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

        # get_tts_audio must have been given the preferred format options
        assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
        expected_numbers = {
            tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        }
        for option_key, expected_number in expected_numbers.items():
            assert int(options.get(option_key)) == expected_number
|
|
|
|
|
|
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Check that a dict tts_audio_output reaches get_tts_audio as preferred options."""
    client = await hass_client()
    setup_ok = await async_setup_component(hass, media_source.DOMAIN, {})
    assert setup_ok

    recorded_events: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    # Requested output format, given as a dict instead of a format name
    requested_output = {
        tts.ATTR_PREFERRED_FORMAT: "flac",
        tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
    }
    tts_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=recorded_events.append,
        tts_audio_output=requested_output,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=tts_only_run,
    )
    await pipeline_input.validate()

    # Let the TTS entity advertise the preferred format options as supported
    supported_options = [
        *(mock_tts_entity.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]
    options_patch = patch.object(
        mock_tts_entity, "_supported_options", supported_options
    )
    audio_patch = patch.object(mock_tts_entity, "get_tts_audio")

    with options_patch, audio_patch as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in recorded_events:
            if event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue
            # The media URL has to be fetched for the TTS call to happen
            assert event.data
            await client.get(event.data["tts_output"]["url"])

        assert mock_get_tts_audio.called
        options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

        # get_tts_audio must have been given the preferred format options
        assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
        expected_numbers = {
            tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        }
        for option_key, expected_number in expected_numbers.items():
            assert int(options.get(option_key)) == expected_number
|
|
|
|
|
|
async def test_sentence_trigger_overrides_conversation_agent(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Check that a matching sentence trigger answers before a non-default agent."""
    # Automation that responds to one fixed sentence
    automation_config = {
        "automation": {
            "trigger": {
                "platform": "conversation",
                "command": [
                    "test trigger sentence",
                ],
            },
            "action": {
                "set_conversation_response": "test trigger response",
            },
        }
    }
    setup_ok = await async_setup_component(hass, "automation", automation_config)
    assert setup_ok

    recorded_events: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    intent_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=recorded_events.append,
        intent_agent="test-agent",  # not the default agent
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test trigger sentence",
        session=mock_chat_session,
        run=intent_only_run,
    )

    # validate() looks up the agent, so return info for the fake "test-agent"
    fake_agent_info = conversation.AgentInfo(
        id="test-agent",
        name="Test Agent",
        supports_streaming=False,
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=fake_agent_info,
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # The sentence trigger handled the input, so the agent was never asked
        mock_async_converse.assert_not_called()

        # Pick the first intent-end event, if there is one
        intent_end_events = [
            event
            for event in recorded_events
            if event.type == assist_pipeline.PipelineEventType.INTENT_END
        ]
        intent_end_event = intent_end_events[0] if intent_end_events else None
        assert intent_end_event is not None
        assert intent_end_event.data
        assert intent_end_event.data["processed_locally"] is True

        plain_speech = intent_end_event.data["intent_output"]["response"]["speech"][
            "plain"
        ]
        assert plain_speech["speech"] == "test trigger response"
|
|
|
|
|
|
async def test_prefer_local_intents(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a matching local intent wins over the configured agent.

    The preferred pipeline is switched to "test-agent" with
    prefer_local_intents enabled. Running the intent stage with an
    OrderBeer sentence must produce the local handler's speech, flag the
    result as processed locally, and never call async_converse.
    """
    pipeline_module = assist_pipeline.pipeline
    captured: list[assist_pipeline.PipelineEvent] = []

    # Handler for the OrderBeer intent; relies on the custom sentences
    # that are part of the test config.
    class OrderBeerIntentHandler(intent.IntentHandler):
        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            reply = intent_obj.create_response()
            reply.async_set_speech("Order confirmed")
            return reply

    intent.async_register(hass, OrderBeerIntentHandler())

    # Point the preferred pipeline at a fake agent and prefer local intents
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    await pipeline_module.async_update_pipeline(
        hass,
        pipeline_module.async_get_pipeline(hass, preferred_id),
        conversation_engine="test-agent",
        prefer_local_intents=True,
    )
    pipeline = pipeline_module.async_get_pipeline(hass, preferred_id)

    run = pipeline_module.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=captured.append,
    )
    pipeline_input = pipeline_module.PipelineInput(
        intent_input="I'd like to order a stout please",
        session=mock_chat_session,
        run=run,
    )

    # Make validation succeed for the fake agent
    fake_agent = conversation.AgentInfo(
        id="test-agent",
        name="Test Agent",
        supports_streaming=False,
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=fake_agent,
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

    # The fake agent must have been bypassed entirely
    mock_async_converse.assert_not_called()

    # The intent-end event must carry the local handler's answer
    intent_end_events = [
        event
        for event in captured
        if event.type == assist_pipeline.PipelineEventType.INTENT_END
    ]
    assert intent_end_events
    end_data = intent_end_events[0].data
    assert end_data
    assert end_data["processed_locally"] is True
    spoken = end_data["intent_output"]["response"]["speech"]["plain"]["speech"]
    assert spoken == "Order confirmed"
|
|
|
|
|
|
async def test_intent_continue_conversation(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a follow-up goes back to the agent that asked to continue.

    First run: "test-agent" answers with continue_conversation=True.
    Second run: the pipeline is reset to the default agent and a sentence
    trigger matching the follow-up text is registered, yet the follow-up
    must still be validated against and answered by "test-agent".
    """
    pipeline_module = assist_pipeline.pipeline
    agent_info_target = "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info"
    converse_target = "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    captured: list[assist_pipeline.PipelineEvent] = []

    def build_input(text, pipeline):
        """Create an intent-only pipeline input that records events."""
        return pipeline_module.PipelineInput(
            intent_input=text,
            session=mock_chat_session,
            run=pipeline_module.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.INTENT,
                event_callback=captured.append,
            ),
        )

    def fake_agent_info():
        """Return the agent info reported for the fake test agent."""
        return conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        )

    def intent_event_data():
        """Return the data of the intent start/end events, in order."""
        wanted = (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
        return [event.data for event in captured if event.type in wanted]

    # Use a fake test agent as the conversation engine
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    await pipeline_module.async_update_pipeline(
        hass,
        pipeline_module.async_get_pipeline(hass, preferred_id),
        conversation_engine="test-agent",
    )
    pipeline = pipeline_module.async_get_pipeline(hass, preferred_id)

    pipeline_input = build_input("Set a timer", pipeline)

    # Make validation succeed for the fake agent
    with patch(agent_info_target, return_value=fake_agent_info()):
        await pipeline_input.validate()

    question = intent.IntentResponse("en")
    question.async_set_speech("For how long?")

    with patch(
        converse_target,
        return_value=conversation.ConversationResult(
            response=question,
            conversation_id=mock_chat_session.conversation_id,
            continue_conversation=True,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    first_run = intent_event_data()
    assert first_run[1]["intent_output"]["continue_conversation"] is True

    # Switch back to the default agent and register a sentence trigger
    # that must NOT be the one answering the follow-up.
    await pipeline_module.async_update_pipeline(
        hass, pipeline, conversation_engine=None
    )
    pipeline = pipeline_module.async_get_pipeline(hass, preferred_id)
    automation_config = {
        "automation": {
            "trigger": {
                "platform": "conversation",
                "command": ["Hello"],
            },
            "action": {
                "set_conversation_response": "test trigger response",
            },
        }
    }
    assert await async_setup_component(hass, "automation", automation_config)

    # Because the previous turn asked to continue, the follow-up should be
    # routed to the test agent again.
    captured.clear()

    pipeline_input = build_input("Hello", pipeline)

    # Make validation succeed for the fake agent
    with patch(agent_info_target, return_value=fake_agent_info()) as mock_prepare:
        await pipeline_input.validate()

    # The test agent was requested even though it is no longer the
    # pipeline's configured engine.
    assert mock_prepare.mock_calls[0].args[1] == "test-agent"

    answer = intent.IntentResponse("en")
    answer.async_set_speech("Timer set for 20 minutes")

    with patch(
        converse_target,
        return_value=conversation.ConversationResult(
            response=answer,
            conversation_id=mock_chat_session.conversation_id,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    # The events show the run was handled by the test agent, not the default
    second_run = intent_event_data()
    assert second_run[0]["engine"] == "test-agent"
    assert second_run[1]["intent_output"]["continue_conversation"] is False
|
|
|
|
|
|
async def test_stt_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test the STT language replaces a '*' (all languages) conversation language.

    A pipeline is created over the websocket API with conversation_language
    set to MATCH_ALL and an STT language configured; both the intent-start
    event and the call to async_converse must carry the STT language.
    """
    ws_client = await hass_ws_client(hass)
    captured: list[assist_pipeline.PipelineEvent] = []

    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
        "conversation_language": MATCH_ALL,
        "language": "en",
        "name": "test_name",
        "stt_engine": "test",
        "stt_language": "en-US",
        "tts_engine": "test",
        "tts_language": "en-US",
        "tts_voice": "Arnold Schwarzenegger",
        "wake_word_entity": None,
        "wake_word_id": None,
    }
    await ws_client.send_json_auto_id(create_request)
    reply = await ws_client.receive_json()
    assert reply["success"]
    pipeline = assist_pipeline.async_get_pipeline(hass, reply["result"]["id"])

    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=captured.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=run,
    )
    await pipeline_input.validate()

    canned_result = conversation.ConversationResult(
        intent.IntentResponse(pipeline.language)
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=canned_result,
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Whole event stream first, then the intent-start event specifically
    assert process_events(captured) == snapshot
    intent_start = next(
        (
            event
            for event in captured
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # The STT language (en-US) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.stt_language

    # async_converse must have received the same language
    mock_async_converse.assert_called_once()
    converse_kwargs = mock_async_converse.call_args_list[0].kwargs
    assert converse_kwargs.get("language") == pipeline.stt_language
|
|
|
|
|
|
async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test the TTS language replaces a '*' conversation language without STT.

    A pipeline is created over the websocket API with conversation_language
    set to MATCH_ALL, no STT language, and a TTS language configured; both
    the intent-start event and the call to async_converse must carry the
    TTS language.
    """
    ws_client = await hass_ws_client(hass)
    captured: list[assist_pipeline.PipelineEvent] = []

    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
        "conversation_language": MATCH_ALL,
        "language": "en",
        "name": "test_name",
        "stt_engine": None,
        "stt_language": None,
        "tts_engine": None,
        "tts_language": "en-us",
        "tts_voice": "Arnold Schwarzenegger",
        "wake_word_entity": None,
        "wake_word_id": None,
    }
    await ws_client.send_json_auto_id(create_request)
    reply = await ws_client.receive_json()
    assert reply["success"]
    pipeline = assist_pipeline.async_get_pipeline(hass, reply["result"]["id"])

    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=captured.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=run,
    )
    await pipeline_input.validate()

    canned_result = conversation.ConversationResult(
        intent.IntentResponse(pipeline.language)
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=canned_result,
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Whole event stream first, then the intent-start event specifically
    assert process_events(captured) == snapshot
    intent_start = next(
        (
            event
            for event in captured
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # The TTS language (en-us) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.tts_language

    # async_converse must have received the same language
    mock_async_converse.assert_called_once()
    converse_kwargs = mock_async_converse.call_args_list[0].kwargs
    assert converse_kwargs.get("language") == pipeline.tts_language
|
|
|
|
|
|
async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test the pipeline language is the last fallback for a '*' conversation language.

    A pipeline is created over the websocket API with conversation_language
    set to MATCH_ALL and neither an STT nor a TTS language; both the
    intent-start event and the call to async_converse must carry the
    pipeline's own language.
    """
    ws_client = await hass_ws_client(hass)
    captured: list[assist_pipeline.PipelineEvent] = []

    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
        "conversation_language": MATCH_ALL,
        "language": "en",
        "name": "test_name",
        "stt_engine": None,
        "stt_language": None,
        "tts_engine": None,
        "tts_language": None,
        "tts_voice": None,
        "wake_word_entity": None,
        "wake_word_id": None,
    }
    await ws_client.send_json_auto_id(create_request)
    reply = await ws_client.receive_json()
    assert reply["success"]
    pipeline = assist_pipeline.async_get_pipeline(hass, reply["result"]["id"])

    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=captured.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=run,
    )
    await pipeline_input.validate()

    canned_result = conversation.ConversationResult(
        intent.IntentResponse(pipeline.language)
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=canned_result,
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Whole event stream first, then the intent-start event specifically
    assert process_events(captured) == snapshot
    intent_start = next(
        (
            event
            for event in captured
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # The pipeline language (en) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.language

    # async_converse must have received the same language
    mock_async_converse.assert_called_once()
    converse_kwargs = mock_async_converse.call_args_list[0].kwargs
    assert converse_kwargs.get("language") == pipeline.language
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("to_stream_deltas", "expected_chunks", "chunk_text"),
    [
        # Size below STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello,",
                    " ",
                    "how",
                    " ",
                    "are",
                    " ",
                    "you",
                    "?",
                ],
            ),
            # We always stream when possible, so 1 chunk via streaming method
            1,
            "hello, how are you?",
        ),
        # Size above STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ". ",
                    "What ",
                    "about ",
                    "you",
                    "?",
                    "!",
                ],
            ),
            # We are streamed. First 15 chunks are grouped into 1 chunk
            # and the rest are streamed
            3,
            "hello, how are you? I'm doing well, thank you. What about you?!",
        ),
        # Stream a bit, then a tool call, then stream some more
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                ],
                {
                    "tool_calls": [
                        llm.ToolInput(
                            tool_name="test_tool",
                            tool_args={},
                            id="test_tool_id",
                        )
                    ],
                },
                [
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ".",
                ],
            ),
            # 1 chunk before tool call, then 7 after
            8,
            "hello, how are you? I'm doing well, thank you.",
        ),
    ],
)
async def test_chat_log_tts_streaming(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
    mock_tts_entity: MockTTSEntity,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_deltas: tuple[dict | list[str], ...],
    expected_chunks: int,
    chunk_text: str,
) -> None:
    """Test that chat log events are streamed to the TTS entity.

    Parametrized inputs:
    - to_stream_deltas: items fed to the chat log in order. A list is a run
      of assistant text deltas; a dict is yielded to the chat log unchanged
      (used here for a tool call).
    - expected_chunks: number of messages the mocked streaming TTS method
      is expected to receive.
    - chunk_text: the expected concatenation of those messages.
    """
    # Flatten only the text runs; dict deltas carry no spoken text.
    text_deltas = [
        delta
        for deltas in to_stream_deltas
        if isinstance(deltas, list)
        for delta in deltas
    ]

    events: list[assist_pipeline.PipelineEvent] = []

    # Use a fake test agent as the conversation engine
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # Run from the intent stage through TTS so the TTS entity gets exercised
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
        ),
    )

    # Every message handed to the streaming TTS mock, in arrival order
    received_tts = []

    async def async_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Mock stream TTS audio."""

        async def gen_data():
            # Record each incoming message and echo it back as "audio" bytes
            async for msg in request.message_gen:
                received_tts.append(msg)
                yield msg.encode()

        return tts.TTSAudioResponse(
            extension="mp3",
            data_gen=gen_data(),
        )

    async def async_get_tts_audio(
        message: str,
        language: str,
        options: dict[str, Any] | None = None,
    ) -> tts.TtsAudioType:
        """Mock get TTS audio."""
        # Ignores its arguments and always returns the full expected text
        return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))

    mock_tts_entity.async_get_tts_audio = async_get_tts_audio
    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)

    # Make validation succeed for a fake agent that supports streaming
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=True,
        ),
    ):
        await pipeline_input.validate()

    async def mock_converse(
        hass: HomeAssistant,
        text: str,
        conversation_id: str | None,
        context: Context,
        language: str | None = None,
        agent_id: str | None = None,
        device_id: str | None = None,
        extra_system_prompt: str | None = None,
    ):
        """Mock converse."""
        conversation_input = conversation.ConversationInput(
            text=text,
            context=context,
            conversation_id=conversation_id,
            device_id=device_id,
            language=language,
            agent_id=agent_id,
            extra_system_prompt=extra_system_prompt,
        )

        async def stream_llm_response():
            for deltas in to_stream_deltas:
                if isinstance(deltas, dict):
                    # Non-text delta (e.g. tool calls): pass through as-is
                    yield deltas
                else:
                    # Start a new assistant message, then emit its text
                    yield {"role": "assistant"}
                    for chunk in deltas:
                        yield {"content": chunk}

        with (
            chat_session.async_get_chat_session(hass, conversation_id) as session,
            conversation.async_get_chat_log(
                hass,
                session,
                conversation_input,
            ) as chat_log,
        ):
            await chat_log.async_provide_llm_data(
                conversation_input.as_llm_context("test"),
                user_llm_hass_api="assist",
                user_llm_prompt=None,
                user_extra_system_prompt=conversation_input.extra_system_prompt,
            )
            # Drain the stream; only the chat log side effects matter here
            async for _content in chat_log.async_add_delta_content_stream(
                agent_id, stream_llm_response()
            ):
                pass
            # The final speech is the last text run of the parametrized input
            intent_response = intent.IntentResponse(language)
            intent_response.async_set_speech("".join(to_stream_deltas[-1]))
            return conversation.ConversationResult(
                response=intent_response,
                conversation_id=chat_log.conversation_id,
                continue_conversation=chat_log.continue_conversation,
            )

    # Tool returned for the "test_tool" call in the parametrized deltas
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema({})
    mock_tool.async_call.return_value = "Test response"

    with (
        patch(
            "homeassistant.helpers.llm.AssistAPI._async_get_tools",
            return_value=[mock_tool],
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
            mock_converse,
        ),
    ):
        await pipeline_input.execute()

    # The first event carries the TTS stream token for this run
    stream = tts.async_get_stream(hass, events[0].data["tts_output"]["token"])
    assert stream is not None
    tts_result = "".join(
        [chunk.decode() async for chunk in stream.async_stream_result()]
    )

    streamed_text = "".join(text_deltas)
    assert tts_result == streamed_text
    assert len(received_tts) == expected_chunks
    assert "".join(received_tts) == chunk_text

    assert process_events(events) == snapshot
|