core/tests/components/assist_pipeline/test_pipeline.py

1788 lines
59 KiB
Python

"""Websocket tests for Voice Assistant integration."""
from collections.abc import AsyncGenerator, Generator
from typing import Any
from unittest.mock import ANY, AsyncMock, Mock, patch
from hassil.recognize import Intent, IntentData, RecognizeResult
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import (
assist_pipeline,
conversation,
media_source,
stt,
tts,
)
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
STORAGE_VERSION_MINOR,
Pipeline,
PipelineData,
PipelineStorageCollection,
PipelineStore,
_async_local_fallback_intent_filter,
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
async_update_pipeline,
)
from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES, process_events
from .conftest import (
MockSTTProvider,
MockSTTProviderEntity,
MockTTSEntity,
MockTTSProvider,
MockWakeWordEntity,
make_10ms_chunk,
)
from tests.common import flush_store
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None]:
    """Disable the storage-collection save delay for every test.

    Patches ``homeassistant.helpers.collection.SAVE_DELAY`` to 0 while the
    test runs, so collection writes are not postponed by a save delay.
    """
    with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
        yield
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
    """Set up the core ``homeassistant`` integration before each test."""
    setup_ok = await async_setup_component(hass, "homeassistant", {})
    assert setup_ok
@pytest.fixture
async def disable_tts_entity(mock_tts_entity: tts.TextToSpeechEntity) -> None:
    """Mark the mock TTS entity as disabled by default in the entity registry."""
    entity = mock_tts_entity
    entity._attr_entity_registry_enabled_default = False
@pytest.mark.usefixtures("init_components")
async def test_load_pipelines(hass: HomeAssistant) -> None:
    """Create pipelines, delete one, and check a second store reloads the same data."""
    pipeline_configs = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]

    data: PipelineData = hass.data[DOMAIN]
    first_store = data.pipeline_store

    created_ids = []
    for config in pipeline_configs:
        created = await first_store.async_create_item(config)
        created_ids.append(created.id)

    # Three pipelines created above, plus the default one created during setup
    assert len(first_store.data) == 4
    assert first_store.async_get_preferred_item() == next(iter(first_store.data))

    await first_store.async_delete_item(created_ids[1])
    assert len(first_store.data) == 3

    # A second, independent collection backed by the same storage key
    second_store = PipelineStorageCollection(
        PipelineStore(
            hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR
        )
    )
    await flush_store(first_store.store)
    await second_store.async_load()

    assert len(second_store.data) == 3
    assert second_store.data is not first_store.data
    assert second_store.data == first_store.data
    assert (
        second_store.async_get_preferred_item()
        == first_store.async_get_preferred_item()
    )
@pytest.fixture(autouse=True)
def mock_chat_session_id() -> Generator[Mock]:
    """Pin the ULID used for chat session conversation IDs to ``"mock-ulid"``."""
    ulid_patch = patch(
        "homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
    )
    with ulid_patch as mock_ulid_now:
        yield mock_ulid_now
@pytest.fixture(autouse=True)
def mock_tts_token() -> Generator[None]:
    """Make ``secrets.token_urlsafe`` deterministic for TTS URL tokens."""
    token_patch = patch("secrets.token_urlsafe", return_value="mocked-token")
    with token_patch:
        yield
async def test_loading_pipelines_from_storage(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Pipelines already present in storage are loaded on integration setup."""
    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    stored_items = [
        {
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": "language_1",
            "id": preferred_id,
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]
    hass_storage[STORAGE_KEY] = {
        "version": STORAGE_VERSION,
        "minor_version": STORAGE_VERSION_MINOR,
        "key": "assist_pipeline.pipelines",
        "data": {"items": stored_items, "preferred_item": preferred_id},
    }

    assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    loaded = data.pipeline_store
    assert len(loaded.data) == 3
    assert loaded.async_get_preferred_item() == preferred_id
    preferred_pipeline = loaded.data[preferred_id]
    assert preferred_pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
async def test_migrate_pipeline_store(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test loading stored pipelines from an older storage version.

    The stored data uses storage version 1.1, whose items carry no wake word
    keys. After setup, every item should be loaded, the preferred pipeline
    kept, and each migrated pipeline should have no wake word configuration.
    """
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "minor_version": 1,
        "key": "assist_pipeline.pipelines",
        "data": {
            "items": [
                {
                    "conversation_engine": "conversation_engine_1",
                    "conversation_language": "language_1",
                    "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
                    "language": "language_1",
                    "name": "name_1",
                    "stt_engine": "stt_engine_1",
                    "stt_language": "language_1",
                    "tts_engine": "tts_engine_1",
                    "tts_language": "language_1",
                    "tts_voice": "Arnold Schwarzenegger",
                },
                {
                    "conversation_engine": "conversation_engine_2",
                    "conversation_language": "language_2",
                    "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
                    "language": "language_2",
                    "name": "name_2",
                    "stt_engine": "stt_engine_2",
                    "stt_language": "language_2",
                    "tts_engine": "tts_engine_2",
                    "tts_language": "language_2",
                    "tts_voice": "The Voice",
                },
                {
                    "conversation_engine": "conversation_engine_3",
                    "conversation_language": "language_3",
                    "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
                    "language": "language_3",
                    "name": "name_3",
                    "stt_engine": None,
                    "stt_language": None,
                    "tts_engine": None,
                    "tts_language": None,
                    "tts_voice": None,
                },
            ],
            "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
        },
    }

    assert await async_setup_component(hass, "assist_pipeline", {})

    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 3
    assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"

    # The stored pipelines themselves survive the migration...
    assert sorted(migrated.name for migrated in store.data.values()) == [
        "name_1",
        "name_2",
        "name_3",
    ]
    # ...and the wake word settings, absent from the 1.1 data, are unset.
    for migrated in store.data.values():
        assert migrated.wake_word_entity is None
        assert migrated.wake_word_id is None
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_create_default_pipeline(hass: HomeAssistant) -> None:
    """async_create_default_pipeline only succeeds for engines that exist."""
    assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # Unknown STT/TTS engine ids: no pipeline is returned
    unknown_engines = await async_create_default_pipeline(
        hass,
        stt_engine_id="bla",
        tts_engine_id="bla",
        pipeline_name="Bla pipeline",
    )
    assert unknown_engines is None

    # The test engines are available, so a full pipeline is returned
    created = await async_create_default_pipeline(
        hass,
        stt_engine_id="test",
        tts_engine_id="test",
        pipeline_name="Test pipeline",
    )
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Test pipeline",
        stt_engine="test",
        stt_language="en-US",
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert created == expected
async def test_get_pipeline(hass: HomeAssistant) -> None:
    """async_get_pipeline returns the preferred pipeline, or one looked up by id."""
    assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    pipeline_store = data.pipeline_store
    assert len(pipeline_store.data) == 1

    # Without an id, the preferred pipeline is returned
    preferred = async_get_pipeline(hass, None)
    assert preferred.id == pipeline_store.async_get_preferred_item()

    # Looking the same pipeline up by id yields the very same object
    by_id = async_get_pipeline(hass, preferred.id)
    assert preferred is by_id
async def test_get_pipelines(hass: HomeAssistant) -> None:
    """async_get_pipelines yields only the default pipeline right after setup."""
    assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    all_pipelines = list(async_get_pipelines(hass))
    default_pipeline = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert all_pipelines == [default_pipeline]
@pytest.mark.parametrize(
    ("ha_language", "ha_country", "conv_language", "pipeline_language"),
    [
        ("en", None, "en", "en"),
        ("de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de"),
        ("en", "us", "en", "en"),
        ("en", "uk", "en", "en"),
        ("pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-BR", "pt"),
    ],
)
async def test_default_pipeline_no_stt_tts(
    hass: HomeAssistant,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
) -> None:
    """Without STT/TTS setup the default pipeline only has conversation settings.

    The conversation and pipeline languages are derived from the configured
    Home Assistant language and country.
    """
    hass.config.country = ha_country
    hass.config.language = ha_language
    assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
@pytest.mark.parametrize(
    (
        "ha_language",
        "ha_country",
        "conv_language",
        "pipeline_language",
        "stt_language",
        "tts_language",
    ),
    [
        ("en", None, "en", "en", "en", "en"),
        ("de", "de", "de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
        ("en", "us", "en", "en", "en", "en"),
        ("en", "uk", "en", "en", "en", "en"),
        ("pt", "pt", "pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-BR", "pt", "pt-br", "pt-br"),
    ],
)
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline(
    hass: HomeAssistant,
    mock_stt_provider_entity: MockSTTProviderEntity,
    mock_tts_provider: MockTTSProvider,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
    stt_language: str,
    tts_language: str,
) -> None:
    """The default pipeline picks up the mock STT/TTS engines and matching languages.

    Both engines are patched to support MANY_LANGUAGES during setup, so the
    chosen STT/TTS languages depend only on the Home Assistant language and
    country configuration.
    """
    hass.config.country = ha_country
    hass.config.language = ha_language

    with (
        patch.object(mock_stt_provider_entity, "_supported_languages", MANY_LANGUAGES),
        patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES),
    ):
        assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine="stt.mock_stt",
        stt_language=stt_language,
        tts_engine="test",
        tts_language=tts_language,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_stt_language(
    hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
) -> None:
    """The default pipeline has no STT engine when STT cannot handle the language."""
    stt_languages = ["smurfish"]
    with patch.object(mock_stt_provider_entity, "_supported_languages", stt_languages):
        assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
@pytest.mark.usefixtures("init_supporting_components")
@pytest.mark.usefixtures("disable_tts_entity")
async def test_default_pipeline_unsupported_tts_language(
    hass: HomeAssistant, mock_tts_provider: MockTTSProvider
) -> None:
    """The default pipeline has no TTS engine when TTS cannot handle the language."""
    tts_languages = ["smurfish"]
    with patch.object(mock_tts_provider, "_supported_languages", tts_languages):
        assert await async_setup_component(hass, "assist_pipeline", {})

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine="stt.mock_stt",
        stt_language="en-US",
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected
async def test_update_pipeline(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test async_update_pipeline.

    A full update replaces every field of the default pipeline and is
    persisted to storage. A following partial update changes only the
    fields it passes and keeps all others.
    """
    assert await async_setup_component(hass, "assist_pipeline", {})

    pipelines = list(async_get_pipelines(hass))
    assert pipelines == [
        Pipeline(
            conversation_engine="conversation.home_assistant",
            conversation_language="en",
            id=ANY,
            language="en",
            name="Home Assistant",
            stt_engine=None,
            stt_language=None,
            tts_engine=None,
            tts_language=None,
            tts_voice=None,
            wake_word_entity=None,
            wake_word_id=None,
        )
    ]

    # Full update: every field changes
    pipeline = pipelines[0]
    await async_update_pipeline(
        hass,
        pipeline,
        conversation_engine="homeassistant_1",
        conversation_language="de",
        language="de",
        name="Home Assistant 1",
        stt_engine="stt.test_1",
        stt_language="de",
        tts_engine="test_1",
        tts_language="de",
        tts_voice="test_voice",
        wake_word_entity="wake_word.test_1",
        wake_word_id="wake_word_id_1",
    )

    pipelines = list(async_get_pipelines(hass))
    pipeline = pipelines[0]
    assert pipelines == [
        Pipeline(
            conversation_engine="homeassistant_1",
            conversation_language="de",
            id=pipeline.id,
            language="de",
            name="Home Assistant 1",
            stt_engine="stt.test_1",
            stt_language="de",
            tts_engine="test_1",
            tts_language="de",
            tts_voice="test_voice",
            wake_word_entity="wake_word.test_1",
            wake_word_id="wake_word_id_1",
        )
    ]
    assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
    assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
        "conversation_engine": "homeassistant_1",
        "conversation_language": "de",
        "id": pipeline.id,
        "language": "de",
        "name": "Home Assistant 1",
        "stt_engine": "stt.test_1",
        "stt_language": "de",
        "tts_engine": "test_1",
        "tts_language": "de",
        "tts_voice": "test_voice",
        "wake_word_entity": "wake_word.test_1",
        "wake_word_id": "wake_word_id_1",
        "prefer_local_intents": False,
    }

    # Partial update: only the STT/TTS engines and languages change
    await async_update_pipeline(
        hass,
        pipeline,
        stt_engine="stt.test_2",
        stt_language="en",
        tts_engine="test_2",
        tts_language="en",
    )

    pipelines = list(async_get_pipelines(hass))
    assert pipelines == [
        Pipeline(
            conversation_engine="homeassistant_1",
            conversation_language="de",
            id=pipeline.id,
            language="de",
            name="Home Assistant 1",
            stt_engine="stt.test_2",
            stt_language="en",
            tts_engine="test_2",
            tts_language="en",
            tts_voice="test_voice",
            wake_word_entity="wake_word.test_1",
            wake_word_id="wake_word_id_1",
        )
    ]
    assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
    assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
        "conversation_engine": "homeassistant_1",
        "conversation_language": "de",
        "id": pipeline.id,
        "language": "de",
        "name": "Home Assistant 1",
        "stt_engine": "stt.test_2",
        "stt_language": "en",
        "tts_engine": "test_2",
        "tts_language": "en",
        "tts_voice": "test_voice",
        "wake_word_entity": "wake_word.test_1",
        "wake_word_id": "wake_word_id_1",
        "prefer_local_intents": False,
    }
def test_fallback_intent_filter() -> None:
    """Check which recognized intents the local fallback filter matches."""
    cases = [
        (intent.INTENT_GET_STATE, True),
        (intent.INTENT_NEVERMIND, False),
        (intent.INTENT_TURN_ON, False),
    ]
    for intent_name, expected in cases:
        recognized = RecognizeResult(
            intent=Intent(intent_name),
            intent_data=IntentData([]),
            entities={},
            entities_list=[],
        )
        assert _async_local_fallback_intent_filter(recognized) is expected
async def test_wake_word_detection_aborted(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test wake word stream is first detected, then aborted.

    The stored pipeline is updated after the run has been validated but
    before it is executed. The resulting event sequence is compared against
    the snapshot.
    """
    events: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        # 10 ms chunks followed by an empty chunk as the stream terminator
        for payload in (b"silence!", b"wake word!", b"part1", b"part2"):
            yield make_10ms_chunk(payload)
        yield b""

    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        session=mock_chat_session,
        device_id=None,
        stt_metadata=stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        stt_stream=audio_data(),
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output=None,
            wake_word_settings=assist_pipeline.WakeWordSettings(
                audio_seconds_to_buffer=1.5
            ),
            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
        ),
    )
    await pipeline_input.validate()

    # Rewrite the stored pipeline with its own settings (minus the id) between
    # validation and execution.
    pipeline_updates = pipeline.to_json()
    pipeline_updates.pop("id")
    await store.async_update_item(preferred_id, pipeline_updates)

    await pipeline_input.execute()

    assert process_events(events) == snapshot
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
    """Test that pipeline run equality uses unique id."""

    def ignore_event(event):
        pass

    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)

    def make_run() -> assist_pipeline.pipeline.PipelineRun:
        # Both runs share the same pipeline, stages and callback object
        return assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.STT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=ignore_event,
        )

    first_run = make_run()
    second_run = make_run()

    assert first_run == first_run  # noqa: PLR0124
    assert first_run != second_run
    assert first_run != 1234
async def test_tts_audio_output(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test using tts_audio_output with wav sets options correctly.

    Also checks that only options the TTS entity declares as supported are
    passed to its ``get_tts_audio``.
    """
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Verify TTS audio settings
    tts_options = pipeline_input.run.tts_stream.options
    assert tts_options is not None
    assert tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
    assert tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1

    with patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

        # Ensure that no unsupported options were passed in
        assert mock_get_tts_audio.called
        options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
        extra_options = set(options).difference(mock_tts_entity.supported_options)
        assert len(extra_options) == 0, extra_options
async def test_tts_wav_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    http_client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=events.append,
        tts_audio_output="wav",
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=run,
    )
    await pipeline_input.validate()

    # Advertise the preferred-format options as supported by the TTS entity
    preferred_format_options = [
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]
    supported_options = [
        *(mock_tts_entity.supported_options or []),
        *preferred_format_options,
    ]

    with (
        patch.object(mock_tts_entity, "_supported_options", supported_options),
        patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue
            # Fetching the media URL is what triggers the TTS
            assert event.data
            await http_client.get(event.data["tts_output"]["url"])

        assert mock_get_tts_audio.called
        received = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # The preferred format options must have reached get_tts_audio
    assert received.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    for option_name, expected_value in (
        (tts.ATTR_PREFERRED_SAMPLE_RATE, 16000),
        (tts.ATTR_PREFERRED_SAMPLE_CHANNELS, 1),
        (tts.ATTR_PREFERRED_SAMPLE_BYTES, 2),
    ):
        assert int(received.get(option_name)) == expected_value
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_entity: MockTTSEntity,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    http_client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    # Explicit audio output preferences instead of the "wav" shorthand
    requested_output = {
        tts.ATTR_PREFERRED_FORMAT: "flac",
        tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
    }
    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=events.append,
        tts_audio_output=requested_output,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=run,
    )
    await pipeline_input.validate()

    # Advertise the preferred-format options as supported by the TTS entity
    preferred_format_options = [
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]
    supported_options = [
        *(mock_tts_entity.supported_options or []),
        *preferred_format_options,
    ]

    with (
        patch.object(mock_tts_entity, "_supported_options", supported_options),
        patch.object(mock_tts_entity, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type != assist_pipeline.PipelineEventType.TTS_END:
                continue
            # Fetching the media URL is what triggers the TTS
            assert event.data
            await http_client.get(event.data["tts_output"]["url"])

        assert mock_get_tts_audio.called
        received = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # The preferred format options must have reached get_tts_audio
    assert received.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
    for option_name, expected_value in (
        (tts.ATTR_PREFERRED_SAMPLE_RATE, 48000),
        (tts.ATTR_PREFERRED_SAMPLE_CHANNELS, 2),
        (tts.ATTR_PREFERRED_SAMPLE_BYTES, 2),
    ):
        assert int(received.get(option_name)) == expected_value
async def test_sentence_trigger_overrides_conversation_agent(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that sentence triggers are checked before a non-default conversation agent."""
    automation_config = {
        "automation": {
            "trigger": {
                "platform": "conversation",
                "command": [
                    "test trigger sentence",
                ],
            },
            "action": {
                "set_conversation_response": "test trigger response",
            },
        }
    }
    assert await async_setup_component(hass, "automation", automation_config)

    events: list[assist_pipeline.PipelineEvent] = []
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test trigger sentence",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
            intent_agent="test-agent",  # not the default agent
        ),
    )

    # Report "test-agent" as a known agent so validation succeeds
    fake_agent_info = conversation.AgentInfo(
        id="test-agent",
        name="Test Agent",
        supports_streaming=False,
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=fake_agent_info,
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

    # The sentence trigger handled the input, so the agent was never asked
    mock_async_converse.assert_not_called()

    intent_end_events = [
        event
        for event in events
        if event.type == assist_pipeline.PipelineEventType.INTENT_END
    ]
    intent_end_event = intent_end_events[0] if intent_end_events else None
    assert (intent_end_event is not None) and intent_end_event.data
    assert intent_end_event.data["processed_locally"] is True

    response = intent_end_event.data["intent_output"]["response"]
    assert response["speech"]["plain"]["speech"] == "test trigger response"
async def test_prefer_local_intents(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that the default agent is checked first when local intents are preferred."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Handler for the OrderBeer intent from the custom sentences in the test config
    class OrderBeerIntentHandler(intent.IntentHandler):
        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            beer_response = intent_obj.create_response()
            beer_response.async_set_speech("Order confirmed")
            return beer_response

    intent.async_register(hass, OrderBeerIntentHandler())

    # Point the pipeline at a fake agent, but ask it to prefer local intents
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    await assist_pipeline.pipeline.async_update_pipeline(
        hass,
        assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id),
        conversation_engine="test-agent",
        prefer_local_intents=True,
    )
    updated_pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="I'd like to order a stout please",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=updated_pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Report "test-agent" as a known agent so validation succeeds
    fake_agent_info = conversation.AgentInfo(
        id="test-agent",
        name="Test Agent",
        supports_streaming=False,
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=fake_agent_info,
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

    # The local intent matched, so the test agent was never called
    mock_async_converse.assert_not_called()

    intent_end_events = [
        event
        for event in events
        if event.type == assist_pipeline.PipelineEventType.INTENT_END
    ]
    intent_end_event = intent_end_events[0] if intent_end_events else None
    assert (intent_end_event is not None) and intent_end_event.data
    assert intent_end_event.data["processed_locally"] is True

    response = intent_end_event.data["intent_output"]["response"]
    assert response["speech"]["plain"]["speech"] == "Order confirmed"
async def test_intent_continue_conversation(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a conversation agent flagging continue conversation gets response.

    The first run is answered by ``test-agent`` with ``continue_conversation``
    set. The pipeline is then switched to the default agent and a sentence
    trigger for the follow-up text is registered. The follow-up run must still
    be routed to ``test-agent`` because the conversation is being continued.
    """
    events: list[assist_pipeline.PipelineEvent] = []

    def intent_event_data() -> list[dict[str, Any] | None]:
        """Return the data of the INTENT_START / INTENT_END events, in order."""
        return [
            event.data
            for event in events
            if event.type
            in (
                assist_pipeline.PipelineEventType.INTENT_START,
                assist_pipeline.PipelineEventType.INTENT_END,
            )
        ]

    # Point the preferred pipeline at a fake test agent
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    # Re-fetch so the run below uses the updated pipeline
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        ),
    ):
        await pipeline_input.validate()

    # The test agent answers with a question and asks to continue
    response = intent.IntentResponse("en")
    response.async_set_speech("For how long?")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
            continue_conversation=True,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    results = intent_event_data()
    assert len(results) == 2
    assert results[1]["intent_output"]["continue_conversation"] is True

    # Change the conversation agent to the default one and register a
    # sentence trigger for the follow-up text; it should NOT be called.
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine=None
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": ["Hello"],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    # Because we did continue conversation, it should respond to the test agent again.
    events.clear()

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Hello",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
        ),
    ) as mock_prepare:
        await pipeline_input.validate()

    # It requested the test agent even though that is no longer the
    # pipeline's configured agent (second positional argument = agent id).
    assert mock_prepare.mock_calls[0].args[1] == "test-agent"

    response = intent.IntentResponse("en")
    response.async_set_speech("Timer set for 20 minutes")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    # It was still handled by the test agent, not the default agent
    results = intent_event_data()
    assert len(results) == 2
    assert results[0]["engine"] == "test-agent"
    assert results[1]["intent_output"]["continue_conversation"] is False
async def test_stt_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the STT language is used first when the conversation language is '*' (all languages)."""
    ws_client = await hass_ws_client(hass)
    captured: list[assist_pipeline.PipelineEvent] = []

    # A pipeline whose conversation language matches everything, so a concrete
    # language has to be taken from another setting (STT comes first).
    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
        "conversation_language": MATCH_ALL,
        "language": "en",
        "name": "test_name",
        "stt_engine": "test",
        "stt_language": "en-US",
        "tts_engine": "test",
        "tts_language": "en-US",
        "tts_voice": "Arnold Schwarzenegger",
        "wake_word_entity": None,
        "wake_word_id": None,
    }
    await ws_client.send_json_auto_id(create_request)
    create_response = await ws_client.receive_json()
    assert create_response["success"]

    pipeline = assist_pipeline.async_get_pipeline(
        hass, create_response["result"]["id"]
    )

    intent_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=captured.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=intent_run,
    )
    await pipeline_input.validate()

    converse_result = conversation.ConversationResult(
        intent.IntentResponse(pipeline.language)
    )
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=converse_result,
    ) as converse_mock:
        await pipeline_input.execute()

    assert process_events(captured) == snapshot

    # The intent stage must announce the STT language (en-US), not '*'
    start_events = [
        event
        for event in captured
        if event.type == assist_pipeline.PipelineEventType.INTENT_START
    ]
    assert start_events
    assert start_events[0].data.get("language") == pipeline.stt_language

    # ... and the conversation agent must receive it as well
    converse_mock.assert_called_once()
    converse_language = converse_mock.call_args.kwargs.get("language")
    assert converse_language == pipeline.stt_language
async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)
    events: list[assist_pipeline.PipelineEvent] = []

    # No STT language is configured, so the TTS language is the next fallback
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": "en-us",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    assert process_events(events) == snapshot

    # Check intent start event
    intent_start: assist_pipeline.PipelineEvent | None = next(
        (
            event
            for event in events
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # TTS language (en-us) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.tts_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.tts_language
    )
async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)
    events: list[assist_pipeline.PipelineEvent] = []

    # Neither STT nor TTS language is configured, so only the pipeline
    # language is left as a fallback
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    assert process_events(events) == snapshot

    # Check intent start event
    intent_start: assist_pipeline.PipelineEvent | None = next(
        (
            event
            for event in events
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # Pipeline language (en) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.language
    )
@pytest.mark.parametrize(
    ("to_stream_deltas", "expected_chunks", "chunk_text"),
    [
        # Size below STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello,",
                    " ",
                    "how",
                    " ",
                    "are",
                    " ",
                    "you",
                    "?",
                ],
            ),
            # We always stream when possible, so 1 chunk via streaming method
            1,
            "hello, how are you?",
        ),
        # Size above STREAM_RESPONSE_CHUNKS
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ". ",
                    "What ",
                    "about ",
                    "you",
                    "?",
                    "!",
                ],
            ),
            # The response is streamed. The first 15 deltas are grouped into
            # 1 chunk and the rest are streamed individually
            3,
            "hello, how are you? I'm doing well, thank you. What about you?!",
        ),
        # Stream a bit, then a tool call, then stream some more
        (
            (
                [
                    "hello, ",
                    "how ",
                    "are ",
                    "you",
                    "? ",
                ],
                {
                    "tool_calls": [
                        llm.ToolInput(
                            tool_name="test_tool",
                            tool_args={},
                            id="test_tool_id",
                        )
                    ],
                },
                [
                    "I'm ",
                    "doing ",
                    "well",
                    ", ",
                    "thank ",
                    "you",
                    ".",
                ],
            ),
            # 1 chunk before tool call, then 7 after
            8,
            "hello, how are you? I'm doing well, thank you.",
        ),
    ],
)
async def test_chat_log_tts_streaming(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
    mock_tts_entity: MockTTSEntity,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_deltas: tuple[dict | list[str], ...],
    expected_chunks: int,
    chunk_text: str,
) -> None:
    """Test that chat log events are streamed to the TTS entity.

    ``to_stream_deltas`` is a sequence of delta groups. A ``list[str]`` is
    streamed as assistant text content. A ``dict`` is yielded as-is (e.g. a
    tool call). ``expected_chunks`` is how many messages the TTS entity
    receives through its streaming input; ``chunk_text`` is their
    concatenation.
    """
    # All text that the LLM streams, ignoring non-text deltas such as tool calls
    text_deltas = [
        delta
        for deltas in to_stream_deltas
        if isinstance(deltas, list)
        for delta in deltas
    ]

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    # Re-fetch so the run below uses the updated pipeline
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
        ),
    )

    # Every text message the TTS entity receives via its streaming input
    received_tts: list[str] = []

    async def async_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Mock stream TTS audio: echo each received message back as bytes."""

        async def gen_data():
            async for msg in request.message_gen:
                received_tts.append(msg)
                yield msg.encode()

        return tts.TTSAudioResponse(
            extension="mp3",
            data_gen=gen_data(),
        )

    async def async_get_tts_audio(
        message: str,
        language: str,
        options: dict[str, Any] | None = None,
    ) -> tts.TtsAudioType:
        """Mock get TTS audio."""
        return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))

    mock_tts_entity.async_get_tts_audio = async_get_tts_audio
    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=True,
        ),
    ):
        await pipeline_input.validate()

    async def mock_converse(
        hass: HomeAssistant,
        text: str,
        conversation_id: str | None,
        context: Context,
        language: str | None = None,
        agent_id: str | None = None,
        device_id: str | None = None,
        extra_system_prompt: str | None = None,
    ):
        """Mock converse: stream the parametrized deltas into the chat log."""
        conversation_input = conversation.ConversationInput(
            text=text,
            context=context,
            conversation_id=conversation_id,
            device_id=device_id,
            language=language,
            agent_id=agent_id,
            extra_system_prompt=extra_system_prompt,
        )

        async def stream_llm_response():
            for deltas in to_stream_deltas:
                if isinstance(deltas, dict):
                    # Non-text delta (e.g. tool calls) passed through as-is
                    yield deltas
                else:
                    # Each text group starts a new assistant message
                    yield {"role": "assistant"}
                    for chunk in deltas:
                        yield {"content": chunk}

        with (
            chat_session.async_get_chat_session(hass, conversation_id) as session,
            conversation.async_get_chat_log(
                hass,
                session,
                conversation_input,
            ) as chat_log,
        ):
            await chat_log.async_provide_llm_data(
                conversation_input.as_llm_context("test"),
                user_llm_hass_api="assist",
                user_llm_prompt=None,
                user_extra_system_prompt=conversation_input.extra_system_prompt,
            )
            async for _content in chat_log.async_add_delta_content_stream(
                agent_id, stream_llm_response()
            ):
                pass

            # Assumes the last delta group is a text list (true for every
            # parametrized case); "".join on a dict would join its keys.
            intent_response = intent.IntentResponse(language)
            intent_response.async_set_speech("".join(to_stream_deltas[-1]))
            return conversation.ConversationResult(
                response=intent_response,
                conversation_id=chat_log.conversation_id,
                continue_conversation=chat_log.continue_conversation,
            )

    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema({})
    mock_tool.async_call.return_value = "Test response"

    with (
        patch(
            "homeassistant.helpers.llm.AssistAPI._async_get_tools",
            return_value=[mock_tool],
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
            mock_converse,
        ),
    ):
        await pipeline_input.execute()

    # The first event carries the TTS output token used to look up the stream
    stream = tts.async_get_stream(hass, events[0].data["tts_output"]["token"])
    assert stream is not None
    tts_result = "".join(
        [chunk.decode() async for chunk in stream.async_stream_result()]
    )

    streamed_text = "".join(text_deltas)
    assert tts_result == streamed_text
    assert len(received_tts) == expected_chunks
    assert "".join(received_tts) == chunk_text

    assert process_events(events) == snapshot