diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 9a32821e3a0..59bd987d90e 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -117,7 +117,7 @@ async def async_pipeline_from_audio_stream( """ with chat_session.async_get_chat_session(hass, conversation_id) as session: pipeline_input = PipelineInput( - conversation_id=session.conversation_id, + session=session, device_id=device_id, stt_metadata=stt_metadata, stt_stream=stt_stream, diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 75811a0ec36..038874d1966 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -96,6 +96,9 @@ ENGINE_LANGUAGE_PAIRS = ( ) KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN) +KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey( + "pipeline_conversation_data" +) def validate_language(data: dict[str, Any]) -> Any: @@ -590,6 +593,12 @@ class PipelineRun: _device_id: str | None = None """Optional device id set during run start.""" + _conversation_data: PipelineConversationData | None = None + """Data tied to the conversation ID.""" + + _intent_agent_only = False + """If request should only be handled by agent, ignoring sentence triggers and local processing.""" + def __post_init__(self) -> None: """Set language for pipeline.""" self.language = self.pipeline.language or self.hass.config.language @@ -1007,19 +1016,36 @@ class PipelineRun: yield chunk.audio - async def prepare_recognize_intent(self) -> None: + async def prepare_recognize_intent(self, session: chat_session.ChatSession) -> None: """Prepare recognizing an intent.""" - agent_info = conversation.async_get_agent_info( - self.hass, - self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT, + self._conversation_data = async_get_pipeline_conversation_data( + self.hass, session ) - if agent_info is None: - engine = self.pipeline.conversation_engine or "default" - raise IntentRecognitionError( - code="intent-not-supported", - message=f"Intent recognition engine {engine} is not found", + if self._conversation_data.continue_conversation_agent is not None: + agent_info = conversation.async_get_agent_info( + self.hass, self._conversation_data.continue_conversation_agent ) + self._conversation_data.continue_conversation_agent = None + if agent_info is None: + raise IntentRecognitionError( + code="intent-agent-not-found", + message=f"Intent recognition engine {self._conversation_data.continue_conversation_agent} asked for follow-up but is no longer found", + ) + self._intent_agent_only = True + + else: + agent_info = conversation.async_get_agent_info( + self.hass, + self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT, + ) + + if agent_info is None: + engine = self.pipeline.conversation_engine or "default" + raise IntentRecognitionError( + code="intent-not-supported", + message=f"Intent recognition engine {engine} is not found", + ) self.intent_agent = agent_info.id @@ -1031,7 +1057,7 @@ class PipelineRun: conversation_extra_system_prompt: str | None, ) -> str: """Run intent recognition portion of pipeline. Returns text to speak.""" - if self.intent_agent is None: + if self.intent_agent is None or self._conversation_data is None: raise RuntimeError("Recognize intent was not prepared") if self.pipeline.conversation_language == MATCH_ALL: @@ -1078,7 +1104,7 @@ class PipelineRun: agent_id = self.intent_agent processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT intent_response: intent.IntentResponse | None = None - if not processed_locally: + if not processed_locally and not self._intent_agent_only: # Sentence triggers override conversation agent if ( trigger_response_text @@ -1195,6 +1221,9 @@ class PipelineRun: ) ) + if conversation_result.continue_conversation: + self._conversation_data.continue_conversation_agent = agent_id + return speech async def prepare_text_to_speech(self) -> None: @@ -1458,8 +1487,8 @@ class PipelineInput: run: PipelineRun - conversation_id: str - """Identifier for the conversation.""" + session: chat_session.ChatSession + """Session for the conversation.""" stt_metadata: stt.SpeechMetadata | None = None """Metadata of stt input audio. Required when start_stage = stt.""" @@ -1484,7 +1513,9 @@ class PipelineInput: async def execute(self) -> None: """Run pipeline.""" - self.run.start(conversation_id=self.conversation_id, device_id=self.device_id) + self.run.start( + conversation_id=self.session.conversation_id, device_id=self.device_id + ) current_stage: PipelineStage | None = self.run.start_stage stt_audio_buffer: list[EnhancedAudioChunk] = [] stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None @@ -1568,7 +1599,7 @@ class PipelineInput: assert intent_input is not None tts_input = await self.run.recognize_intent( intent_input, - self.conversation_id, + self.session.conversation_id, self.device_id, self.conversation_extra_system_prompt, ) @@ -1652,7 +1683,7 @@ class PipelineInput: <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT) <= end_stage_index ): - prepare_tasks.append(self.run.prepare_recognize_intent()) + prepare_tasks.append(self.run.prepare_recognize_intent(self.session)) if ( start_stage_index @@ -1931,7 +1962,7 @@ class PipelineRunDebug: class PipelineStore(Store[SerializedPipelineStorageCollection]): - """Store entity registry data.""" + """Store pipeline data.""" async def _async_migrate_func( self, @@ -2013,3 +2044,37 @@ async def async_run_migrations(hass: HomeAssistant) -> None: for pipeline, attr_updates in updates: await async_update_pipeline(hass, pipeline, **attr_updates) + + +@dataclass +class PipelineConversationData: + """Hold data for the duration of a conversation.""" + + continue_conversation_agent: str | None = None + """The agent that requested the conversation to be continued.""" + + +@callback +def async_get_pipeline_conversation_data( + hass: HomeAssistant, session: chat_session.ChatSession +) -> PipelineConversationData: + """Get the pipeline data for a specific conversation.""" + all_conversation_data = hass.data.get(KEY_PIPELINE_CONVERSATION_DATA) + if all_conversation_data is None: + all_conversation_data = {} + hass.data[KEY_PIPELINE_CONVERSATION_DATA] = all_conversation_data + + data = all_conversation_data.get(session.conversation_id) + + if data is not None: + return data + + @callback + def do_cleanup() -> None: + """Handle cleanup.""" + all_conversation_data.pop(session.conversation_id) + + session.async_on_cleanup(do_cleanup) + + data = all_conversation_data[session.conversation_id] = PipelineConversationData() + return data diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index d2d54a1b7c3..937b3a0ea45 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -239,7 +239,7 @@ async def websocket_run( with chat_session.async_get_chat_session( hass, msg.get("conversation_id") ) as session: - input_args["conversation_id"] = session.conversation_id + input_args["session"] = session pipeline_input = PipelineInput(**input_args) try: diff --git a/homeassistant/components/conversation/models.py b/homeassistant/components/conversation/models.py index 08a68fa0164..7bdd13afc01 100644 --- a/homeassistant/components/conversation/models.py +++ b/homeassistant/components/conversation/models.py @@ -62,12 +62,14 @@ class ConversationResult: response: intent.IntentResponse conversation_id: str | None = None + continue_conversation: bool = False def as_dict(self) -> dict[str, Any]: """Return result as a dict.""" return { "response": self.response.as_dict(), "conversation_id": self.conversation_id, + "continue_conversation": self.continue_conversation, } diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index 016b1c3494d..0af74621153 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -284,7 +284,10 @@ class EsphomeAssistSatellite( elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: assert event.data is not None data_to_send = { - "conversation_id": event.data["intent_output"]["conversation_id"] or "", + "conversation_id": event.data["intent_output"]["conversation_id"], + "continue_conversation": event.data["intent_output"][ + "continue_conversation" + ], } elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: assert event.data is not None diff --git a/tests/components/anthropic/snapshots/test_conversation.ambr b/tests/components/anthropic/snapshots/test_conversation.ambr index 93f3b03d9af..de414019317 100644 --- a/tests/components/anthropic/snapshots/test_conversation.ambr +++ b/tests/components/anthropic/snapshots/test_conversation.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_unknown_hass_api dict({ + 'continue_conversation': False, 'conversation_id': '1234', 'response': IntentResponse( card=dict({ diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 02ec7c04607..a0549f27f05 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections.abc import AsyncIterable, Generator from pathlib import Path from typing import Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest @@ -24,7 +24,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import Platform from homeassistant.core import HomeAssistant -from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import chat_session, device_registry as dr from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.setup import async_setup_component @@ -379,3 +379,14 @@ def pipeline_storage(pipeline_data) -> PipelineStorageCollection: def make_10ms_chunk(header: bytes) -> bytes: """Return 10ms of zeros with the given header.""" return header + bytes(BYTES_PER_CHUNK - len(header)) + + +@pytest.fixture +def mock_chat_session(hass: HomeAssistant) -> Generator[chat_session.ChatSession]: + """Mock the ulid of chat sessions.""" + # pylint: disable-next=contextmanager-generator-missing-cleanup + with ( + patch("homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"), + chat_session.async_get_chat_session(hass) as session, + ): + yield session diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 11e6bc2339a..f5e5f813db6 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -45,6 +45,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -137,6 +138,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -229,6 +231,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -345,6 +348,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -432,7 +436,7 @@ list([ dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -440,7 +444,7 @@ }), dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -452,6 +456,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -484,7 +489,7 @@ list([ dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -492,7 +497,7 @@ }), dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -504,6 +509,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -536,7 +542,7 @@ list([ dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -544,7 +550,7 @@ }), dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -556,6 +562,7 @@ dict({ 'data': dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -588,7 +595,7 @@ list([ dict({ 'data': dict({ - 'conversation_id': 'mock-conversation-id', + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index f677fa6d8cf..509f2072509 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -43,6 +43,7 @@ # name: test_audio_pipeline.4 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -127,6 +128,7 @@ # name: test_audio_pipeline_debug.4 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -223,6 +225,7 @@ # name: test_audio_pipeline_with_enhancements.4 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -329,6 +332,7 @@ # name: test_audio_pipeline_with_wake_word_no_timeout.6 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -596,6 +600,7 @@ # name: test_pipeline_empty_tts_output.2 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -715,6 +720,7 @@ # name: test_text_only_pipeline[extra_msg0].2 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -762,6 +768,7 @@ # name: test_text_only_pipeline[extra_msg1].2 dict({ 'intent_output': dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 1651950c173..e983e4a96e3 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -27,7 +27,7 @@ from homeassistant.components.assist_pipeline.const import ( ) from homeassistant.const import MATCH_ALL from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import intent +from homeassistant.helpers import chat_session, intent from homeassistant.setup import async_setup_component from .conftest import ( @@ -675,6 +675,7 @@ async def test_wake_word_detection_aborted( mock_wake_word_provider_entity: MockWakeWordEntity, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, + mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test creating a pipeline from an audio stream with wake word.""" @@ -693,7 +694,7 @@ async def test_wake_word_detection_aborted( pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( - conversation_id="mock-conversation-id", + session=mock_chat_session, device_id=None, stt_metadata=stt.SpeechMetadata( language="", @@ -766,6 +767,7 @@ async def test_tts_audio_output( mock_tts_provider: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, + mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test using tts_audio_output with wav sets options correctly.""" @@ -780,7 +782,7 @@ async def test_tts_audio_output( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id="mock-conversation-id", + session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -823,6 +825,7 @@ async def test_tts_wav_preferred_format( hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, + mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" @@ -837,7 +840,7 @@ async def test_tts_wav_preferred_format( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id="mock-conversation-id", + session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -891,6 +894,7 @@ async def test_tts_dict_preferred_format( hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, + mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" @@ -905,7 +909,7 @@ async def test_tts_dict_preferred_format( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id="mock-conversation-id", + session=mock_chat_session, device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -962,6 +966,7 @@ async def test_tts_dict_preferred_format( async def test_sentence_trigger_overrides_conversation_agent( hass: HomeAssistant, init_components, + mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that sentence triggers are checked before a non-default conversation agent.""" @@ -991,7 +996,7 @@ async def test_sentence_trigger_overrides_conversation_agent( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test trigger sentence", - conversation_id="mock-conversation-id", + session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1039,6 +1044,7 @@ async def test_sentence_trigger_overrides_conversation_agent( async def test_prefer_local_intents( hass: HomeAssistant, init_components, + mock_chat_session: chat_session.ChatSession, pipeline_data: assist_pipeline.pipeline.PipelineData, ) -> None: """Test that the default agent is checked first when local intents are preferred.""" @@ -1069,7 +1075,7 @@ async def test_prefer_local_intents( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="I'd like to order a stout please", - conversation_id="mock-conversation-id", + session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1113,10 +1119,150 @@ async def test_prefer_local_intents( ) +async def test_intent_continue_conversation( + hass: HomeAssistant, + init_components, + mock_chat_session: chat_session.ChatSession, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that a conversation agent flagging continue conversation gets response.""" + events: list[assist_pipeline.PipelineEvent] = [] + + # Fake a test agent and prefer local intents + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine="test-agent" + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="Set a timer", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + + # Ensure prepare succeeds + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ): + await pipeline_input.validate() + + response = intent.IntentResponse("en") + response.async_set_speech("For how long?") + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + response=response, + conversation_id=mock_chat_session.conversation_id, + continue_conversation=True, + ), + ) as mock_async_converse: + await pipeline_input.execute() + + mock_async_converse.assert_called() + + results = [ + event.data + for event in events + if event.type + in ( + assist_pipeline.PipelineEventType.INTENT_START, + assist_pipeline.PipelineEventType.INTENT_END, + ) + ] + assert results[1]["intent_output"]["continue_conversation"] is True + + # Change conversation agent to default one and register sentence trigger that should not be called + await assist_pipeline.pipeline.async_update_pipeline( + hass, pipeline, conversation_engine=None + ) + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "conversation", + "command": ["Hello"], + }, + "action": { + "set_conversation_response": "test trigger response", + }, + } + }, + ) + + # Because we did continue conversation, it should respond to the test agent again. + events.clear() + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="Hello", + session=mock_chat_session, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + + # Ensure prepare succeeds + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", + return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + ) as mock_prepare: + await pipeline_input.validate() + + # It requested test agent even if that was not default agent. + assert mock_prepare.mock_calls[0][1][1] == "test-agent" + + response = intent.IntentResponse("en") + response.async_set_speech("Timer set for 20 minutes") + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + response=response, + conversation_id=mock_chat_session.conversation_id, + ), + ) as mock_async_converse: + await pipeline_input.execute() + + mock_async_converse.assert_called() + + # Snapshot will show it was still handled by the test agent and not default agent + results = [ + event.data + for event in events + if event.type + in ( + assist_pipeline.PipelineEventType.INTENT_START, + assist_pipeline.PipelineEventType.INTENT_END, + ) + ] + assert results[0]["engine"] == "test-agent" + assert results[1]["intent_output"]["continue_conversation"] is False + + async def test_stt_language_used_instead_of_conversation_language( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, + mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test that the STT language is used first when the conversation language is '*' (all languages).""" @@ -1147,7 +1293,7 @@ async def test_stt_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", - conversation_id="mock-conversation-id", + session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1192,6 +1338,7 @@ async def test_tts_language_used_instead_of_conversation_language( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, + mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test that the TTS language is used after STT when the conversation language is '*' (all languages).""" @@ -1222,7 +1369,7 @@ async def test_tts_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", - conversation_id="mock-conversation-id", + session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1267,6 +1414,7 @@ async def test_pipeline_language_used_instead_of_conversation_language( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, + mock_chat_session: chat_session.ChatSession, snapshot: SnapshotAssertion, ) -> None: """Test that the pipeline language is used last when the conversation language is '*' (all languages).""" @@ -1297,7 +1445,7 @@ async def test_pipeline_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", - conversation_id="mock-conversation-id", + session=mock_chat_session, run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), diff --git a/tests/components/conversation/__init__.py b/tests/components/conversation/__init__.py index 314188dbd82..eeab8b6b9af 100644 --- a/tests/components/conversation/__init__.py +++ b/tests/components/conversation/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from typing import Literal from unittest.mock import patch @@ -49,7 +50,7 @@ class MockAgent(conversation.AbstractConversationAgent): @pytest.fixture -async def mock_chat_log(hass: HomeAssistant) -> MockChatLog: +async def mock_chat_log(hass: HomeAssistant) -> AsyncGenerator[MockChatLog]: """Return mock chat logs.""" # pylint: disable-next=contextmanager-generator-missing-cleanup with ( diff --git a/tests/components/conversation/snapshots/test_chat_log.ambr b/tests/components/conversation/snapshots/test_chat_log.ambr index 1ddbf68bb84..ff8ebf724cd 100644 --- a/tests/components/conversation/snapshots/test_chat_log.ambr +++ b/tests/components/conversation/snapshots/test_chat_log.ambr @@ -151,6 +151,7 @@ # --- # name: test_template_error dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -171,6 +172,7 @@ # --- # name: test_unknown_llm_api dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ diff --git a/tests/components/conversation/snapshots/test_default_agent.ambr b/tests/components/conversation/snapshots/test_default_agent.ambr index c2b16ea2912..02e4ef1befe 100644 --- a/tests/components/conversation/snapshots/test_default_agent.ambr +++ b/tests/components/conversation/snapshots/test_default_agent.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_custom_sentences dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -26,6 +27,7 @@ # --- # name: test_custom_sentences.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -51,6 +53,7 @@ # --- # name: test_custom_sentences_config dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -76,6 +79,7 @@ # --- # name: test_intent_alias_added_removed dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -106,6 +110,7 @@ # --- # name: test_intent_alias_added_removed.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -136,6 +141,7 @@ # --- # name: test_intent_alias_added_removed.2 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -156,6 +162,7 @@ # --- # name: test_intent_conversion_not_expose_new dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -176,6 +183,7 @@ # --- # name: test_intent_conversion_not_expose_new.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -206,6 +214,7 @@ # --- # name: test_intent_entity_added_removed dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -236,6 +245,7 @@ # --- # name: test_intent_entity_added_removed.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -266,6 +276,7 @@ # --- # name: test_intent_entity_added_removed.2 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -296,6 +307,7 @@ # --- # name: test_intent_entity_added_removed.3 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -316,6 +328,7 @@ # --- # name: test_intent_entity_exposed dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -346,6 +359,7 @@ # --- # name: test_intent_entity_fail_if_unexposed dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -366,6 +380,7 @@ # --- # name: test_intent_entity_remove_custom_name dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -386,6 +401,7 @@ # --- # name: test_intent_entity_remove_custom_name.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -416,6 +432,7 @@ # --- # name: test_intent_entity_remove_custom_name.2 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -436,6 +453,7 @@ # --- # name: test_intent_entity_renamed dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -466,6 +484,7 @@ # --- # name: test_intent_entity_renamed.1 dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ diff --git a/tests/components/conversation/snapshots/test_http.ambr b/tests/components/conversation/snapshots/test_http.ambr index c6ac6c2df9c..849a5b17102 100644 --- a/tests/components/conversation/snapshots/test_http.ambr +++ b/tests/components/conversation/snapshots/test_http.ambr @@ -202,6 +202,7 @@ # --- # name: test_http_api_handle_failure dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -222,6 +223,7 @@ # --- # name: test_http_api_no_match dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -242,6 +244,7 @@ # --- # name: test_http_api_unexpected_failure dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -262,6 +265,7 @@ # --- # name: test_http_processing_intent[None] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -292,6 +296,7 @@ # --- # name: test_http_processing_intent[conversation.home_assistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -322,6 +327,7 @@ # --- # name: test_http_processing_intent[homeassistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -352,6 +358,7 @@ # --- # name: test_ws_api[payload0] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -372,6 +379,7 @@ # --- # name: test_ws_api[payload1] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -392,6 +400,7 @@ # --- # name: test_ws_api[payload2] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -412,6 +421,7 @@ # --- # name: test_ws_api[payload3] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -432,6 +442,7 @@ # --- # name: test_ws_api[payload4] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -452,6 +463,7 @@ # --- # name: test_ws_api[payload5] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr index 911c7043a6d..3d843d4e32a 100644 --- a/tests/components/conversation/snapshots/test_init.ambr +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_custom_agent dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -44,6 +45,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-None] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -74,6 +76,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -104,6 +107,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-homeassistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -134,6 +138,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-None] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -164,6 +169,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -194,6 +200,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-homeassistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -224,6 +231,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-None] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -254,6 +262,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -284,6 +293,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -314,6 +324,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-None] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -344,6 +355,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ @@ -374,6 +386,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] dict({ + 'continue_conversation': False, 'conversation_id': , 'response': dict({ 'card': dict({ diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py index 30535236970..56914a0b829 100644 --- a/tests/components/esphome/test_assist_satellite.py +++ b/tests/components/esphome/test_assist_satellite.py @@ -25,7 +25,7 @@ from aioesphomeapi import ( ) import pytest -from homeassistant.components import assist_satellite, tts +from homeassistant.components import assist_satellite, conversation, tts from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType from homeassistant.components.assist_satellite import ( AssistSatelliteConfiguration, @@ -285,12 +285,21 @@ async def test_pipeline_api_audio( event_callback( PipelineEvent( type=PipelineEventType.INTENT_END, - data={"intent_output": {"conversation_id": conversation_id}}, + data={ + "intent_output": conversation.ConversationResult( + response=intent_helper.IntentResponse("en"), + conversation_id=conversation_id, + continue_conversation=True, + ).as_dict() + }, ) ) assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, - {"conversation_id": conversation_id}, + { + "conversation_id": conversation_id, + "continue_conversation": True, + }, ) # TTS @@ -484,7 +493,12 @@ async def test_pipeline_udp_audio( event_callback( PipelineEvent( type=PipelineEventType.INTENT_END, - data={"intent_output": {"conversation_id": conversation_id}}, + data={ + "intent_output": conversation.ConversationResult( + response=intent_helper.IntentResponse("en"), + conversation_id=conversation_id, + ).as_dict() + }, ) ) @@ -690,7 +704,12 @@ async def test_pipeline_media_player( event_callback( PipelineEvent( type=PipelineEventType.INTENT_END, - data={"intent_output": {"conversation_id": conversation_id}}, + data={ + "intent_output": conversation.ConversationResult( + response=intent_helper.IntentResponse("en"), + conversation_id=conversation_id, + ).as_dict() + }, ) ) diff --git a/tests/components/mobile_app/test_webhook.py b/tests/components/mobile_app/test_webhook.py index dda5f369ad5..b071caebd16 100644 --- a/tests/components/mobile_app/test_webhook.py +++ b/tests/components/mobile_app/test_webhook.py @@ -1081,6 +1081,7 @@ async def test_webhook_handle_conversation_process( }, }, "conversation_id": None, + "continue_conversation": False, } diff --git a/tests/components/ollama/snapshots/test_conversation.ambr b/tests/components/ollama/snapshots/test_conversation.ambr index 93f3b03d9af..de414019317 100644 --- a/tests/components/ollama/snapshots/test_conversation.ambr +++ b/tests/components/ollama/snapshots/test_conversation.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_unknown_hass_api dict({ + 'continue_conversation': False, 'conversation_id': '1234', 'response': IntentResponse( card=dict({ diff --git a/tests/syrupy.py b/tests/syrupy.py index 3c8e398f0f8..e028d5839cb 100644 --- a/tests/syrupy.py +++ b/tests/syrupy.py @@ -109,7 +109,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer): serializable_data = cls._serializable_issue_registry_entry(data) elif isinstance(data, dict) and "flow_id" in data and "handler" in data: serializable_data = cls._serializable_flow_result(data) - elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}: + elif isinstance(data, dict) and set(data) == { + "conversation_id", + "response", + "continue_conversation", + }: serializable_data = cls._serializable_conversation_result(data) elif isinstance(data, vol.Schema): serializable_data = voluptuous_serialize.convert(data)