From 7a484ee0ae0c8f21bea12db74af925bc59dcf875 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 6 Jan 2025 12:58:42 -0500 Subject: [PATCH] Add extra prompt to assist pipeline and conversation (#124743) * Add extra prompt to assist pipeline and conversation * extra_prompt -> extra_system_prompt * Fix rebase * Fix tests --- .../components/assist_pipeline/__init__.py | 2 ++ .../components/assist_pipeline/pipeline.py | 13 ++++++++++++- .../components/conversation/agent_manager.py | 2 ++ homeassistant/components/conversation/models.py | 4 ++++ tests/components/conversation/test_agent_manager.py | 2 ++ tests/components/conversation/test_trigger.py | 4 ++++ 6 files changed, 26 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index ec6d8a646b6..851c873bb12 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -108,6 +108,7 @@ async def async_pipeline_from_audio_stream( device_id: str | None = None, start_stage: PipelineStage = PipelineStage.STT, end_stage: PipelineStage = PipelineStage.TTS, + conversation_extra_system_prompt: str | None = None, ) -> None: """Create an audio pipeline from an audio stream. @@ -119,6 +120,7 @@ async def async_pipeline_from_audio_stream( stt_metadata=stt_metadata, stt_stream=stt_stream, wake_word_phrase=wake_word_phrase, + conversation_extra_system_prompt=conversation_extra_system_prompt, run=PipelineRun( hass, context=context, diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 7dda24c4023..c3a5b93ca6a 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1010,7 +1010,11 @@ class PipelineRun: self.intent_agent = agent_info.id async def recognize_intent( - self, intent_input: str, conversation_id: str | None, device_id: str | None + self, + intent_input: str, + conversation_id: str | None, + device_id: str | None, + conversation_extra_system_prompt: str | None, ) -> str: """Run intent recognition portion of pipeline. Returns text to speak.""" if self.intent_agent is None: @@ -1045,6 +1049,7 @@ class PipelineRun: device_id=device_id, language=input_language, agent_id=self.intent_agent, + extra_system_prompt=conversation_extra_system_prompt, ) processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT @@ -1392,8 +1397,13 @@ class PipelineInput: """Input for text-to-speech. Required when start_stage = tts.""" conversation_id: str | None = None + """Identifier for the conversation.""" + + conversation_extra_system_prompt: str | None = None + """Extra prompt information for the conversation agent.""" device_id: str | None = None + """Identifier of the device that is processing the input/output of the pipeline.""" async def execute(self) -> None: """Run pipeline.""" @@ -1483,6 +1493,7 @@ class PipelineInput: intent_input, self.conversation_id, self.device_id, + self.conversation_extra_system_prompt, ) if tts_input.strip(): current_stage = PipelineStage.TTS diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index 7516d9d22ef..97dc5e1292e 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -75,6 +75,7 @@ async def async_converse( language: str | None = None, agent_id: str | None = None, device_id: str | None = None, + extra_system_prompt: str | None = None, ) -> ConversationResult: """Process text and get intent.""" agent = async_get_agent(hass, agent_id) @@ -99,6 +100,7 @@ async def async_converse( device_id=device_id, language=language, agent_id=agent_id, + extra_system_prompt=extra_system_prompt, ) with async_conversation_trace() as trace: trace.add_event( diff --git a/homeassistant/components/conversation/models.py b/homeassistant/components/conversation/models.py index 10218e76751..9462c597f23 100644 --- a/homeassistant/components/conversation/models.py +++ b/homeassistant/components/conversation/models.py @@ -40,6 +40,9 @@ class ConversationInput: agent_id: str | None = None """Agent to use for processing.""" + extra_system_prompt: str | None = None + """Extra prompt to provide extra info to LLMs how to understand the command.""" + def as_dict(self) -> dict[str, Any]: """Return input as a dict.""" return { @@ -49,6 +52,7 @@ class ConversationInput: "device_id": self.device_id, "language": self.language, "agent_id": self.agent_id, + "extra_system_prompt": self.extra_system_prompt, } diff --git a/tests/components/conversation/test_agent_manager.py b/tests/components/conversation/test_agent_manager.py index 47b58a522a8..3f98c9bcd69 100644 --- a/tests/components/conversation/test_agent_manager.py +++ b/tests/components/conversation/test_agent_manager.py @@ -22,6 +22,7 @@ async def test_async_converse(hass: HomeAssistant, init_components) -> None: language="test lang", agent_id="conversation.home_assistant", device_id="test device id", + extra_system_prompt="test extra prompt", ) assert mock_process.called @@ -32,3 +33,4 @@ async def test_async_converse(hass: HomeAssistant, init_components) -> None: assert conversation_input.language == "test lang" assert conversation_input.agent_id == "conversation.home_assistant" assert conversation_input.device_id == "test device id" + assert conversation_input.extra_system_prompt == "test extra prompt" diff --git a/tests/components/conversation/test_trigger.py b/tests/components/conversation/test_trigger.py index 50fac51c87a..9b57bb43b58 100644 --- a/tests/components/conversation/test_trigger.py +++ b/tests/components/conversation/test_trigger.py @@ -88,6 +88,7 @@ async def test_if_fires_on_event( "device_id": None, "language": "en", "text": "Ha ha ha", + "extra_system_prompt": None, }, } @@ -235,6 +236,7 @@ async def test_response_same_sentence( "device_id": None, "language": "en", "text": "test sentence", + "extra_system_prompt": None, }, } @@ -412,6 +414,7 @@ async def test_same_trigger_multiple_sentences( "device_id": None, "language": "en", "text": "hello", + "extra_system_prompt": None, }, } @@ -639,6 +642,7 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall]) "device_id": None, "language": "en", "text": "play the white album by the beatles", + "extra_system_prompt": None, }, }