From b897e6a85f388ffe521eadae69a7f0de5026a071 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 13 Jan 2025 14:17:12 -0600 Subject: [PATCH] Use STT/TTS languages for LLM fallback (#135533) --- .../components/assist_pipeline/pipeline.py | 15 +- .../assist_pipeline/snapshots/test_init.ambr | 102 ++++++++++++ tests/components/assist_pipeline/test_init.py | 154 +++++++++++++++++- 3 files changed, 265 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index c3a5b93ca6a..1b76121fcd2 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1021,9 +1021,18 @@ class PipelineRun: raise RuntimeError("Recognize intent was not prepared") if self.pipeline.conversation_language == MATCH_ALL: - # LLMs support all languages ('*') so use pipeline language for - # intent fallback. - input_language = self.pipeline.language + # LLMs support all languages ('*') so use languages from the + # pipeline for intent fallback. + # + # We prioritize the STT and TTS languages because they may be more + # specific, such as "zh-CN" instead of just "zh". This is necessary + # for languages whose intents are split out by region when + # preferring local intent matching. + input_language = ( + self.pipeline.stt_language + or self.pipeline.tts_language + or self.pipeline.language + ) else: input_language = self.pipeline.conversation_language diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index f63a28efbb7..171014fdc4a 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -474,6 +474,108 @@ }), ]) # --- +# name: test_stt_language_used_instead_of_conversation_language + list([ + dict({ + 'data': dict({ + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'conversation.home_assistant', + 'intent_input': 'test input', + 'language': 'en-US', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + }), + }), + }), + 'processed_locally': True, + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- +# name: test_tts_language_used_instead_of_conversation_language + list([ + dict({ + 'data': dict({ + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'conversation.home_assistant', + 'intent_input': 'test input', + 'language': 'en-us', + 'prefer_local_intents': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + }), + }), + }), + 'processed_locally': True, + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- # name: test_wake_word_detection_aborted list([ dict({ diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index d4cce4e2e98..a2cb9ef382a 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -1102,13 +1102,13 @@ async def test_prefer_local_intents( ) -async def test_pipeline_language_used_instead_of_conversation_language( +async def test_stt_language_used_instead_of_conversation_language( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, snapshot: SnapshotAssertion, ) -> None: - """Test that the pipeline language is used when the conversation language is '*' (all languages).""" + """Test that the STT language is used first when the conversation language is '*' (all languages).""" client = await hass_ws_client(hass) events: list[assist_pipeline.PipelineEvent] = [] @@ -1165,7 +1165,155 @@ async def test_pipeline_language_used_instead_of_conversation_language( assert intent_start is not None - # Pipeline language (en) should be used instead of '*' + # STT language (en-US) should be used instead of '*' + assert intent_start.data.get("language") == pipeline.stt_language + + # Check input to async_converse + mock_async_converse.assert_called_once() + assert ( + mock_async_converse.call_args_list[0].kwargs.get("language") + == pipeline.stt_language + ) + + +async def test_tts_language_used_instead_of_conversation_language( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test that the TTS language is used after STT when the conversation language is '*' (all languages).""" + client = await hass_ws_client(hass) + + events: list[assist_pipeline.PipelineEvent] = [] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": MATCH_ALL, + "language": "en", + "name": "test_name", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": "en-us", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test input", + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + intent.IntentResponse(pipeline.language) + ), + ) as mock_async_converse: + await pipeline_input.execute() + + # Check intent start event + assert process_events(events) == snapshot + intent_start: assist_pipeline.PipelineEvent | None = None + for event in events: + if event.type == assist_pipeline.PipelineEventType.INTENT_START: + intent_start = event + break + + assert intent_start is not None + + # STT language (en-US) should be used instead of '*' + assert intent_start.data.get("language") == pipeline.tts_language + + # Check input to async_converse + mock_async_converse.assert_called_once() + assert ( + mock_async_converse.call_args_list[0].kwargs.get("language") + == pipeline.tts_language + ) + + +async def test_pipeline_language_used_instead_of_conversation_language( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test that the pipeline language is used last when the conversation language is '*' (all languages).""" + client = await hass_ws_client(hass) + + events: list[assist_pipeline.PipelineEvent] = [] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": MATCH_ALL, + "language": "en", + "name": "test_name", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + intent_input="test input", + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.INTENT, + end_stage=assist_pipeline.PipelineStage.INTENT, + event_callback=events.append, + ), + ) + await pipeline_input.validate() + + with patch( + "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", + return_value=conversation.ConversationResult( + intent.IntentResponse(pipeline.language) + ), + ) as mock_async_converse: + await pipeline_input.execute() + + # Check intent start event + assert process_events(events) == snapshot + intent_start: assist_pipeline.PipelineEvent | None = None + for event in events: + if event.type == assist_pipeline.PipelineEventType.INTENT_START: + intent_start = event + break + + assert intent_start is not None + + # STT language (en-US) should be used instead of '*' assert intent_start.data.get("language") == pipeline.language # Check input to async_converse