diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 69e738205c5..11de75801ba 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -32,6 +32,7 @@ from .agent_manager import ( ) from .chat_log import ( AssistantContent, + AssistantContentDeltaDict, ChatLog, Content, ConverseError, @@ -65,6 +66,7 @@ __all__ = [ "HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT", "AssistantContent", + "AssistantContentDeltaDict", "ChatLog", "Content", "ConversationEntity", diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 086e1374c1a..5dbd19ba275 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -3,11 +3,12 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, AsyncIterable, Generator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field, replace import logging +from typing import Literal, TypedDict import voluptuous as vol @@ -145,6 +146,14 @@ class ToolResultContent: type Content = SystemContent | UserContent | AssistantContent | ToolResultContent +class AssistantContentDeltaDict(TypedDict, total=False): + """Partial content to define an AssistantContent.""" + + role: Literal["assistant"] + content: str | None + tool_calls: list[llm.ToolInput] | None + + @dataclass class ChatLog: """Class holding the chat history of a specific conversation.""" @@ -155,6 +164,11 @@ class ChatLog: extra_system_prompt: str | None = None llm_api: llm.APIInstance | None = None + @property + def unresponded_tool_results(self) -> bool: + """Return if there are unresponded tool results.""" + return self.content[-1].role == "tool_result" + @callback def async_add_user_content(self, content: UserContent) -> None: """Add user content to the log.""" @@ -223,6 +237,77 @@ class ChatLog: self.content.append(response_content) yield response_content + async def async_add_delta_content_stream( + self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict] + ) -> AsyncGenerator[AssistantContent | ToolResultContent]: + """Stream content into the chat log. + + Returns a generator with all content that was added to the chat log. + + stream iterates over dictionaries with optional keys role, content and tool_calls. + + When a delta contains a role key, the current message is considered complete and + a new message is started. + + The keys content and tool_calls will be concatenated if they appear multiple times. + """ + current_content = "" + current_tool_calls: list[llm.ToolInput] = [] + tool_call_tasks: dict[str, asyncio.Task] = {} + + async for delta in stream: + LOGGER.debug("Received delta: %s", delta) + + # Indicates update to current message + if "role" not in delta: + if delta_content := delta.get("content"): + current_content += delta_content + if delta_tool_calls := delta.get("tool_calls"): + if self.llm_api is None: + raise ValueError("No LLM API configured") + current_tool_calls += delta_tool_calls + + # Start processing the tool calls as soon as we know about them + for tool_call in delta_tool_calls: + tool_call_tasks[tool_call.id] = self.hass.async_create_task( + self.llm_api.async_call_tool(tool_call), + name=f"llm_tool_{tool_call.id}", + ) + continue + + # Starting a new message + + if delta["role"] != "assistant": + raise ValueError(f"Only assistant role expected. Got {delta['role']}") + + # Yield the previous message if it has content + if current_content or current_tool_calls: + content = AssistantContent( + agent_id=agent_id, + content=current_content or None, + tool_calls=current_tool_calls or None, + ) + yield content + async for tool_result in self.async_add_assistant_content( + content, tool_call_tasks=tool_call_tasks + ): + yield tool_result + + current_content = delta.get("content") or "" + current_tool_calls = delta.get("tool_calls") or [] + + if current_content or current_tool_calls: + content = AssistantContent( + agent_id=agent_id, + content=current_content or None, + tool_calls=current_tool_calls or None, + ) + yield content + async for tool_result in self.async_add_assistant_content( + content, tool_call_tasks=tool_call_tasks + ): + yield tool_result + async def async_update_llm_data( self, conversing_domain: str, diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index eaa62bd1adc..4dee1d4b167 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,14 +1,15 @@ """Conversation support for OpenAI.""" -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable import json from typing import Any, Literal, cast import openai +from openai._streaming import AsyncStream from openai._types import NOT_GIVEN from openai.types.chat import ( ChatCompletionAssistantMessageParam, - ChatCompletionMessage, + ChatCompletionChunk, ChatCompletionMessageParam, ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam, @@ -70,32 +71,6 @@ def _format_tool( return ChatCompletionToolParam(type="function", function=tool_spec) -def _convert_message_to_param( - message: ChatCompletionMessage, -) -> ChatCompletionMessageParam: - """Convert from class to TypedDict.""" - tool_calls: list[ChatCompletionMessageToolCallParam] = [] - if message.tool_calls: - tool_calls = [ - ChatCompletionMessageToolCallParam( - id=tool_call.id, - function=Function( - arguments=tool_call.function.arguments, - name=tool_call.function.name, - ), - type=tool_call.type, - ) - for tool_call in message.tool_calls - ] - param = ChatCompletionAssistantMessageParam( - role=message.role, - content=message.content, - ) - if tool_calls: - param["tool_calls"] = tool_calls - return param - - def _convert_content_to_param( content: conversation.Content, ) -> ChatCompletionMessageParam: @@ -135,6 +110,74 @@ def _convert_content_to_param( ) +async def _transform_stream( + result: AsyncStream[ChatCompletionChunk], +) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: + """Transform an OpenAI delta stream into HA format.""" + current_tool_call: dict | None = None + + async for chunk in result: + LOGGER.debug("Received chunk: %s", chunk) + choice = chunk.choices[0] + + if choice.finish_reason: + if current_tool_call: + yield { + "tool_calls": [ + llm.ToolInput( + id=current_tool_call["id"], + tool_name=current_tool_call["tool_name"], + tool_args=json.loads(current_tool_call["tool_args"]), + ) + ] + } + + break + + delta = chunk.choices[0].delta + + # We can yield delta messages not continuing or starting tool calls + if current_tool_call is None and not delta.tool_calls: + yield { # type: ignore[misc] + key: value + for key in ("role", "content") + if (value := getattr(delta, key)) is not None + } + continue + + # When doing tool calls, we should always have a tool call + # object or we have gotten stopped above with a finish_reason set. + if ( + not delta.tool_calls + or not (delta_tool_call := delta.tool_calls[0]) + or not delta_tool_call.function + ): + raise ValueError("Expected delta with tool call") + + if current_tool_call and delta_tool_call.index == current_tool_call["index"]: + current_tool_call["tool_args"] += delta_tool_call.function.arguments or "" + continue + + # We got tool call with new index, so we need to yield the previous + if current_tool_call: + yield { + "tool_calls": [ + llm.ToolInput( + id=current_tool_call["id"], + tool_name=current_tool_call["tool_name"], + tool_args=json.loads(current_tool_call["tool_args"]), + ) + ] + } + + current_tool_call = { + "index": delta_tool_call.index, + "id": delta_tool_call.id, + "tool_name": delta_tool_call.function.name, + "tool_args": delta_tool_call.function.arguments or "", + } + + class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -234,6 +277,7 @@ class OpenAIConversationEntity( "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), "user": chat_log.conversation_id, + "stream": True, } if model.startswith("o"): @@ -247,39 +291,21 @@ class OpenAIConversationEntity( LOGGER.error("Error talking to OpenAI: %s", err) raise HomeAssistantError("Error talking to OpenAI") from err - LOGGER.debug("Response %s", result) - response = result.choices[0].message - messages.append(_convert_message_to_param(response)) - - tool_calls: list[llm.ToolInput] | None = None - if response.tool_calls: - tool_calls = [ - llm.ToolInput( - id=tool_call.id, - tool_name=tool_call.function.name, - tool_args=json.loads(tool_call.function.arguments), - ) - for tool_call in response.tool_calls - ] - messages.extend( [ - _convert_content_to_param(tool_response) - async for tool_response in chat_log.async_add_assistant_content( - conversation.AssistantContent( - agent_id=user_input.agent_id, - content=response.content or "", - tool_calls=tool_calls, - ) + _convert_content_to_param(content) + async for content in chat_log.async_add_delta_content_stream( + user_input.agent_id, _transform_stream(result) ) ] ) - if not tool_calls: + if not chat_log.unresponded_tool_results: break intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(response.content or "") + assert type(chat_log.content[-1]) is conversation.AssistantContent + intent_response.async_set_speech(chat_log.content[-1].content or "") return conversation.ConversationResult( response=intent_response, conversation_id=chat_log.conversation_id ) diff --git a/homeassistant/helpers/chat_session.py b/homeassistant/helpers/chat_session.py index 686272dd834..e7a4ecd2ca9 100644 --- a/homeassistant/helpers/chat_session.py +++ b/homeassistant/helpers/chat_session.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field from datetime import datetime, timedelta +import logging from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import ( @@ -27,6 +28,7 @@ DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session") DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup") CONVERSATION_TIMEOUT = timedelta(minutes=5) +LOGGER = logging.getLogger(__name__) current_session: ContextVar[ChatSession | None] = ContextVar( "current_session", default=None @@ -100,6 +102,7 @@ class SessionCleanup: # yielding session based on it. for conversation_id, session in list(all_sessions.items()): if session.last_updated + CONVERSATION_TIMEOUT < now: + LOGGER.debug("Cleaning up session %s", conversation_id) del all_sessions[conversation_id] session.async_cleanup() @@ -150,6 +153,7 @@ def async_get_chat_session( pass if session is None: + LOGGER.debug("Creating new session %s", conversation_id) session = ChatSession(conversation_id) current_session.set(session) diff --git a/tests/components/conversation/snapshots/test_chat_log.ambr b/tests/components/conversation/snapshots/test_chat_log.ambr index 4e94157c601..1ddbf68bb84 100644 --- a/tests/components/conversation/snapshots/test_chat_log.ambr +++ b/tests/components/conversation/snapshots/test_chat_log.ambr @@ -1,4 +1,154 @@ # serializer version: 1 +# name: test_add_delta_content_stream[deltas0] + list([ + ]) +# --- +# name: test_add_delta_content_stream[deltas1] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) +# --- +# name: test_add_delta_content_stream[deltas2] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test', + 'role': 'assistant', + 'tool_calls': None, + }), + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test 2', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) +# --- +# name: test_add_delta_content_stream[deltas3] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': None, + 'role': 'assistant', + 'tool_calls': list([ + dict({ + 'id': 'mock-tool-call-id', + 'tool_args': dict({ + 'param1': 'Test Param 1', + }), + 'tool_name': 'test_tool', + }), + ]), + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id', + 'tool_name': 'test_tool', + 'tool_result': 'Test Param 1', + }), + ]) +# --- +# name: test_add_delta_content_stream[deltas4] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test', + 'role': 'assistant', + 'tool_calls': list([ + dict({ + 'id': 'mock-tool-call-id', + 'tool_args': dict({ + 'param1': 'Test Param 1', + }), + 'tool_name': 'test_tool', + }), + ]), + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id', + 'tool_name': 'test_tool', + 'tool_result': 'Test Param 1', + }), + ]) +# --- +# name: test_add_delta_content_stream[deltas5] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test', + 'role': 'assistant', + 'tool_calls': list([ + dict({ + 'id': 'mock-tool-call-id', + 'tool_args': dict({ + 'param1': 'Test Param 1', + }), + 'tool_name': 'test_tool', + }), + ]), + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id', + 'tool_name': 'test_tool', + 'tool_result': 'Test Param 1', + }), + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test 2', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) +# --- +# name: test_add_delta_content_stream[deltas6] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': None, + 'role': 'assistant', + 'tool_calls': list([ + dict({ + 'id': 'mock-tool-call-id', + 'tool_args': dict({ + 'param1': 'Test Param 1', + }), + 'tool_name': 'test_tool', + }), + dict({ + 'id': 'mock-tool-call-id-2', + 'tool_args': dict({ + 'param1': 'Test Param 2', + }), + 'tool_name': 'test_tool', + }), + ]), + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id', + 'tool_name': 'test_tool', + 'tool_result': 'Test Param 1', + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id-2', + 'tool_name': 'test_tool', + 'tool_result': 'Test Param 2', + }), + ]) +# --- # name: test_template_error dict({ 'conversation_id': , diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index 1f659b8005e..090904c7063 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -282,7 +282,7 @@ async def test_extra_systen_prompt( @pytest.mark.parametrize( "prerun_tool_tasks", [ - None, + (), ("mock-tool-call-id",), ("mock-tool-call-id", "mock-tool-call-id-2"), ], @@ -290,7 +290,7 @@ async def test_extra_systen_prompt( async def test_tool_call( hass: HomeAssistant, mock_conversation_input: ConversationInput, - prerun_tool_tasks: tuple[str] | None, + prerun_tool_tasks: tuple[str], ) -> None: """Test using the session tool calling API.""" @@ -334,15 +334,13 @@ async def test_tool_call( ], ) - tool_call_tasks = None - if prerun_tool_tasks: - tool_call_tasks = { - tool_call_id: hass.async_create_task( - chat_log.llm_api.async_call_tool(content.tool_calls[0]), - tool_call_id, - ) - for tool_call_id in prerun_tool_tasks - } + tool_call_tasks = { + tool_call_id: hass.async_create_task( + chat_log.llm_api.async_call_tool(content.tool_calls[0]), + tool_call_id, + ) + for tool_call_id in prerun_tool_tasks + } with pytest.raises(ValueError): chat_log.async_add_assistant_content_without_tools(content) @@ -350,7 +348,7 @@ async def test_tool_call( results = [ tool_result_content async for tool_result_content in chat_log.async_add_assistant_content( - content, tool_call_tasks=tool_call_tasks + content, tool_call_tasks=tool_call_tasks or None ) ] @@ -382,37 +380,36 @@ async def test_tool_call_exception( ) mock_tool.async_call.side_effect = HomeAssistantError("Test error") - with patch( - "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] - ) as mock_get_tools: + with ( + patch( + "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] + ) as mock_get_tools, + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): mock_get_tools.return_value = [mock_tool] - - with ( - chat_session.async_get_chat_session(hass) as session, - async_get_chat_log(hass, session, mock_conversation_input) as chat_log, - ): - await chat_log.async_update_llm_data( - conversing_domain="test", - user_input=mock_conversation_input, - user_llm_hass_api="assist", - user_llm_prompt=None, + await chat_log.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="assist", + user_llm_prompt=None, + ) + result = None + async for tool_result_content in chat_log.async_add_assistant_content( + AssistantContent( + agent_id=mock_conversation_input.agent_id, + content="", + tool_calls=[ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ], ) - result = None - async for tool_result_content in chat_log.async_add_assistant_content( - AssistantContent( - agent_id=mock_conversation_input.agent_id, - content="", - tool_calls=[ - llm.ToolInput( - id="mock-tool-call-id", - tool_name="test_tool", - tool_args={"param1": "Test Param"}, - ) - ], - ) - ): - assert result is None - result = tool_result_content + ): + assert result is None + result = tool_result_content assert result == ToolResultContent( agent_id=mock_conversation_input.agent_id, @@ -420,3 +417,188 @@ async def test_tool_call_exception( tool_result={"error": "HomeAssistantError", "error_text": "Test error"}, tool_name="test_tool", ) + + +@pytest.mark.parametrize( + "deltas", + [ + [], + # With content + [ + {"role": "assistant"}, + {"content": "Test"}, + ], + # With 2 content + [ + {"role": "assistant"}, + {"content": "Test"}, + {"role": "assistant"}, + {"content": "Test 2"}, + ], + # With 1 tool call + [ + {"role": "assistant"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param 1"}, + ) + ] + }, + ], + # With content and 1 tool call + [ + {"role": "assistant"}, + {"content": "Test"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param 1"}, + ) + ] + }, + ], + # With 2 contents and 1 tool call + [ + {"role": "assistant"}, + {"content": "Test"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param 1"}, + ) + ] + }, + {"role": "assistant"}, + {"content": "Test 2"}, + ], + # With 2 tool calls + [ + {"role": "assistant"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param 1"}, + ) + ] + }, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id-2", + tool_name="test_tool", + tool_args={"param1": "Test Param 2"}, + ) + ] + }, + ], + ], +) +async def test_add_delta_content_stream( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, + snapshot: SnapshotAssertion, + deltas: list[dict], +) -> None: + """Test streaming deltas.""" + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + + async def tool_call( + hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext + ) -> str: + """Call the tool.""" + return tool_input.tool_args["param1"] + + mock_tool.async_call.side_effect = tool_call + + async def stream(): + """Yield deltas.""" + for d in deltas: + yield d + + with ( + patch( + "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] + ) as mock_get_tools, + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + mock_get_tools.return_value = [mock_tool] + await chat_log.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="assist", + user_llm_prompt=None, + ) + + results = [ + tool_result_content + async for tool_result_content in chat_log.async_add_delta_content_stream( + "mock-agent-id", stream() + ) + ] + + assert results == snapshot + assert chat_log.content[2:] == results + + +async def test_add_delta_content_stream_errors( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test streaming deltas error handling.""" + + async def stream(deltas): + """Yield deltas.""" + for d in deltas: + yield d + + with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + # Stream content without LLM API set + with pytest.raises(ValueError): # noqa: PT012 + async for _tool_result_content in chat_log.async_add_delta_content_stream( + "mock-agent-id", + stream( + [ + {"role": "assistant"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={}, + ) + ] + }, + ] + ), + ): + pass + + # Non assistant role + for role in "system", "user": + with pytest.raises(ValueError): # noqa: PT012 + async for ( + _tool_result_content + ) in chat_log.async_add_delta_content_stream( + "mock-agent-id", + stream([{"role": role}]), + ): + pass diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index 4ef8b8655ee..2db5be706ef 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -1,34 +1,64 @@ # serializer version: 1 -# name: test_unknown_hass_api - dict({ - 'conversation_id': 'my-conversation-id', - 'response': IntentResponse( - card=dict({ - }), - error_code=, - failed_results=list([ - ]), - intent=None, - intent_targets=list([ - ]), - language='en', - matched_states=list([ - ]), - reprompt=dict({ - }), - response_type=, - speech=dict({ - 'plain': dict({ - 'extra_data': None, - 'speech': 'Error preparing LLM API', +# name: test_function_call + list([ + dict({ + 'content': ''' + Current time is 16:00:00. Today's date is 2024-06-03. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + dict({ + 'content': 'Please call the test function', + 'role': 'user', + }), + dict({ + 'agent_id': 'conversation.openai', + 'content': None, + 'role': 'assistant', + 'tool_calls': list([ + dict({ + 'id': 'call_call_1', + 'tool_args': dict({ + 'param1': 'call1', + }), + 'tool_name': 'test_tool', + }), + dict({ + 'id': 'call_call_2', + 'tool_args': dict({ + 'param1': 'call2', + }), + 'tool_name': 'test_tool', }), - }), - speech_slots=dict({ - }), - success_results=list([ ]), - unmatched_states=list([ - ]), - ), - }) + }), + dict({ + 'agent_id': 'conversation.openai', + 'role': 'tool_result', + 'tool_call_id': 'call_call_1', + 'tool_name': 'test_tool', + 'tool_result': 'value1', + }), + dict({ + 'agent_id': 'conversation.openai', + 'role': 'tool_result', + 'tool_call_id': 'call_call_2', + 'tool_name': 'test_tool', + 'tool_result': 'value2', + }), + dict({ + 'agent_id': 'conversation.openai', + 'content': 'Cool', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) # --- diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 39ca1b53e28..9afdfc6a5a2 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -1,29 +1,130 @@ """Tests for the OpenAI integration.""" +from collections.abc import Generator +from dataclasses import dataclass, field from unittest.mock import AsyncMock, patch from freezegun import freeze_time from httpx import Response from openai import RateLimitError -from openai.types.chat.chat_completion import ChatCompletion, Choice -from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, - Function, +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage -import voluptuous as vol +import pytest +from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation -from homeassistant.components.conversation import trace +from homeassistant.components.conversation import chat_log +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant -from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import intent, llm +from homeassistant.helpers import chat_session, intent from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry +ASSIST_RESPONSE_FINISH = ( + # Assistant message + ChatCompletionChunk( + id="chatcmpl-B", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))], + ), + # Finish stream + ChatCompletionChunk( + id="chatcmpl-B", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[Choice(index=0, finish_reason="stop", delta=ChoiceDelta())], + ), +) + + +@pytest.fixture +def mock_create_stream() -> Generator[AsyncMock]: + """Mock stream response.""" + + async def mock_generator(stream): + for value in stream: + yield value + + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + AsyncMock(), + ) as mock_create: + mock_create.side_effect = lambda **kwargs: mock_generator( + mock_create.return_value.pop(0) + ) + + yield mock_create + + +@dataclass +class MockChatLog(chat_log.ChatLog): + """Mock chat log.""" + + _mock_tool_results: dict = field(default_factory=dict) + + def mock_tool_results(self, results: dict) -> None: + """Set tool results.""" + self._mock_tool_results = results + + @property + def llm_api(self): + """Return LLM API.""" + return self._llm_api + + @llm_api.setter + def llm_api(self, value): + """Set LLM API.""" + self._llm_api = value + + if not value: + return + + async def async_call_tool(tool_input): + """Call tool.""" + if tool_input.id not in self._mock_tool_results: + raise ValueError(f"Tool {tool_input.id} not found") + return self._mock_tool_results[tool_input.id] + + self._llm_api.async_call_tool = async_call_tool + + def latest_content(self) -> list[conversation.Content]: + """Return content from latest version chat log. + + The chat log makes copies until it's committed. Helper to get latest content. + """ + with ( + chat_session.async_get_chat_session( + self.hass, self.conversation_id + ) as session, + conversation.async_get_chat_log(self.hass, session) as chat_log, + ): + return chat_log.content + + +@pytest.fixture +async def mock_chat_log(hass: HomeAssistant) -> MockChatLog: + """Return mock chat logs.""" + with ( + patch( + "homeassistant.components.conversation.chat_log.ChatLog", + MockChatLog, + ), + chat_session.async_get_chat_session(hass, "mock-conversation-id") as session, + conversation.async_get_chat_log(hass, session) as chat_log, + ): + chat_log.async_add_user_content(conversation.UserContent("hello")) + return chat_log + async def test_entity( hass: HomeAssistant, @@ -83,348 +184,299 @@ async def test_conversation_agent( assert agent.supported_languages == "*" -@patch( - "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools" -) async def test_function_call( - mock_get_tools, hass: HomeAssistant, mock_config_entry_with_assist: MockConfigEntry, mock_init_component, + mock_create_stream: AsyncMock, + mock_chat_log: MockChatLog, + snapshot: SnapshotAssertion, ) -> None: """Test function call from the assistant.""" - agent_id = mock_config_entry_with_assist.entry_id - context = Context() - - mock_tool = AsyncMock() - mock_tool.name = "test_tool" - mock_tool.description = "Test function" - mock_tool.parameters = vol.Schema( - {vol.Optional("param1", description="Test parameters"): str} + mock_create_stream.return_value = [ + # Initial conversation + ( + # First tool call + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + id="call_call_1", + index=0, + function=ChoiceDeltaToolCallFunction( + name="test_tool", + arguments=None, + ), + ) + ] + ), + ) + ], + ), + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction( + name=None, + arguments='{"para', + ), + ) + ] + ), + ) + ], + ), + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction( + name=None, + arguments='m1":"call1"}', + ), + ) + ] + ), + ) + ], + ), + # Second tool call + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + id="call_call_2", + index=1, + function=ChoiceDeltaToolCallFunction( + name="test_tool", + arguments='{"param1":"call2"}', + ), + ) + ] + ), + ) + ], + ), + # Finish stream + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice(index=0, finish_reason="tool_calls", delta=ChoiceDelta()) + ], + ), + ), + # Response after tool responses + ASSIST_RESPONSE_FINISH, + ] + mock_chat_log.mock_tool_results( + { + "call_call_1": "value1", + "call_call_2": "value2", + } ) - mock_tool.async_call.return_value = "Test response" - mock_get_tools.return_value = [mock_tool] + with freeze_time("2024-06-03 23:00:00"): + result = await conversation.async_converse( + hass, + "Please call the test function", + "mock-conversation-id", + Context(), + agent_id="conversation.openai", + ) - def completion_result(*args, messages, **kwargs): - for message in messages: - role = message["role"] if isinstance(message, dict) else message.role - if role == "tool": - return ChatCompletion( - id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH", + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert mock_chat_log.latest_content() == snapshot + + +@pytest.mark.parametrize( + ("description", "messages"), + [ + ( + "Test function call started with missing arguments", + ( + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", choices=[ Choice( - finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="I have successfully called the function", - role="assistant", - function_call=None, - tool_calls=None, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + id="call_call_1", + index=0, + function=ChoiceDeltaToolCallFunction( + name="test_tool", + arguments=None, + ), + ) + ] ), ) ], + ), + ChatCompletionChunk( + id="chatcmpl-B", created=1700000000, model="gpt-4-1106-preview", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ) - - return ChatCompletion( - id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", - choices=[ - Choice( - finish_reason="tool_calls", - index=0, - message=ChatCompletionMessage( - content=None, - role="assistant", - function_call=None, - tool_calls=[ - ChatCompletionMessageToolCall( - id="call_AbCdEfGhIjKlMnOpQrStUvWx", - function=Function( - arguments='{"param1":"test_value"}', - name="test_tool", - ), - type="function", - ) - ], - ), - ) - ], - created=1700000000, - model="gpt-4-1106-preview", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))], + ), ), - ) - - with ( - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - side_effect=completion_result, - ) as mock_create, - freeze_time("2024-06-03 23:00:00"), - ): - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - ) - - assert ( - "Today's date is 2024-06-03." - in mock_create.mock_calls[1][2]["messages"][0]["content"] - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert mock_create.mock_calls[1][2]["messages"][3] == { - "role": "tool", - "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", - "content": '"Test response"', - } - mock_tool.async_call.assert_awaited_once_with( - hass, - llm.ToolInput( - id="call_AbCdEfGhIjKlMnOpQrStUvWx", - tool_name="test_tool", - tool_args={"param1": "test_value"}, ), - llm.LLMContext( - platform="openai_conversation", - context=context, - user_prompt="Please call the test function", - language="en", - assistant="conversation", - device_id=None, + ( + "Test invalid JSON", + ( + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + id="call_call_1", + index=0, + function=ChoiceDeltaToolCallFunction( + name="test_tool", + arguments=None, + ), + ) + ] + ), + ) + ], + ), + ChatCompletionChunk( + id="chatcmpl-A", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction( + name=None, + arguments='{"para', + ), + ) + ] + ), + ) + ], + ), + ChatCompletionChunk( + id="chatcmpl-B", + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content="Cool"), + finish_reason="tool_calls", + ) + ], + ), + ), ), - ) - - # Test Conversation tracing - traces = trace.async_get_traces() - assert traces - last_trace = traces[-1].as_dict() - trace_events = last_trace.get("events", []) - assert [event["event_type"] for event in trace_events] == [ - trace.ConversationTraceEventType.ASYNC_PROCESS, - trace.ConversationTraceEventType.AGENT_DETAIL, - trace.ConversationTraceEventType.TOOL_CALL, - ] - # AGENT_DETAIL event contains the raw prompt passed to the model - detail_event = trace_events[1] - assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] - assert ( - "Today's date is 2024-06-03." - in trace_events[1]["data"]["messages"][0]["content"] - ) - assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"] - - # Call it again, make sure we have updated prompt - with ( - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - side_effect=completion_result, - ) as mock_create, - freeze_time("2024-06-04 23:00:00"), - ): - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - ) - - assert ( - "Today's date is 2024-06-04." - in mock_create.mock_calls[1][2]["messages"][0]["content"] - ) - # Test old assert message not updated - assert ( - "Today's date is 2024-06-03." - in trace_events[1]["data"]["messages"][0]["content"] - ) - - -@patch( - "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools" + ], ) -async def test_function_exception( - mock_get_tools, +async def test_function_call_invalid( hass: HomeAssistant, mock_config_entry_with_assist: MockConfigEntry, mock_init_component, + mock_create_stream: AsyncMock, + mock_chat_log: MockChatLog, + description: str, + messages: tuple[ChatCompletionChunk], ) -> None: - """Test function call with exception.""" - agent_id = mock_config_entry_with_assist.entry_id - context = Context() + """Test function call containing invalid data.""" + mock_create_stream.return_value = [messages] - mock_tool = AsyncMock() - mock_tool.name = "test_tool" - mock_tool.description = "Test function" - mock_tool.parameters = vol.Schema( - {vol.Optional("param1", description="Test parameters"): str} - ) - mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception") - - mock_get_tools.return_value = [mock_tool] - - def completion_result(*args, messages, **kwargs): - for message in messages: - role = message["role"] if isinstance(message, dict) else message.role - if role == "tool": - return ChatCompletion( - id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - content="There was an error calling the function", - role="assistant", - function_call=None, - tool_calls=None, - ), - ) - ], - created=1700000000, - model="gpt-4-1106-preview", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ) - - return ChatCompletion( - id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", - choices=[ - Choice( - finish_reason="tool_calls", - index=0, - message=ChatCompletionMessage( - content=None, - role="assistant", - function_call=None, - tool_calls=[ - ChatCompletionMessageToolCall( - id="call_AbCdEfGhIjKlMnOpQrStUvWx", - function=Function( - arguments='{"param1":"test_value"}', - name="test_tool", - ), - type="function", - ) - ], - ), - ) - ], - created=1700000000, - model="gpt-4-1106-preview", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ) - - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - side_effect=completion_result, - ) as mock_create: - result = await conversation.async_converse( + with pytest.raises(ValueError): + await conversation.async_converse( hass, "Please call the test function", - None, - context, - agent_id=agent_id, + "mock-conversation-id", + Context(), + agent_id="conversation.openai", ) - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert mock_create.mock_calls[1][2]["messages"][3] == { - "role": "tool", - "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", - "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}', - } - mock_tool.async_call.assert_awaited_once_with( - hass, - llm.ToolInput( - id="call_AbCdEfGhIjKlMnOpQrStUvWx", - tool_name="test_tool", - tool_args={"param1": "test_value"}, - ), - llm.LLMContext( - platform="openai_conversation", - context=context, - user_prompt="Please call the test function", - language="en", - assistant="conversation", - device_id=None, - ), - ) - async def test_assist_api_tools_conversion( hass: HomeAssistant, mock_config_entry_with_assist: MockConfigEntry, mock_init_component, + mock_create_stream, ) -> None: """Test that we are able to convert actual tools from Assist API.""" for component in ( - "intent", - "todo", - "light", - "shopping_list", - "humidifier", + "calendar", "climate", - "media_player", - "vacuum", "cover", + "humidifier", + "intent", + "light", + "media_player", + "script", + "shopping_list", + "todo", + "vacuum", "weather", ): assert await async_setup_component(hass, component, {}) + hass.states.async_set(f"{component}.test", "on") + async_expose_entity(hass, "conversation", f"{component}.test", True) - agent_id = mock_config_entry_with_assist.entry_id - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - return_value=ChatCompletion( - id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - content="Hello, how can I help you?", - role="assistant", - function_call=None, - tool_calls=None, - ), - ) - ], - created=1700000000, - model="gpt-3.5-turbo-0613", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ), - ) as mock_create: - await conversation.async_converse( - hass, "hello", None, Context(), agent_id=agent_id - ) + mock_create_stream.return_value = [ASSIST_RESPONSE_FINISH] - tools = mock_create.mock_calls[0][2]["tools"] + await conversation.async_converse( + hass, "hello", None, Context(), agent_id="conversation.openai" + ) + + tools = mock_create_stream.mock_calls[0][2]["tools"] assert tools