Stream OpenAI messages into the chat log (#137400)

pull/138010/head^2
Paulus Schoutsen 2025-02-09 00:01:24 -05:00 committed by GitHub
parent a526baa831
commit df307aeb6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 964 additions and 433 deletions

View File

@ -32,6 +32,7 @@ from .agent_manager import (
)
from .chat_log import (
AssistantContent,
AssistantContentDeltaDict,
ChatLog,
Content,
ConverseError,
@ -65,6 +66,7 @@ __all__ = [
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"AssistantContent",
"AssistantContentDeltaDict",
"ChatLog",
"Content",
"ConversationEntity",

View File

@ -3,11 +3,12 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, Generator
from collections.abc import AsyncGenerator, AsyncIterable, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
import logging
from typing import Literal, TypedDict
import voluptuous as vol
@ -145,6 +146,14 @@ class ToolResultContent:
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
class AssistantContentDeltaDict(TypedDict, total=False):
"""Partial content to define an AssistantContent."""
role: Literal["assistant"]
content: str | None
tool_calls: list[llm.ToolInput] | None
@dataclass
class ChatLog:
"""Class holding the chat history of a specific conversation."""
@ -155,6 +164,11 @@ class ChatLog:
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
@property
def unresponded_tool_results(self) -> bool:
"""Return if there are unresponded tool results."""
return self.content[-1].role == "tool_result"
@callback
def async_add_user_content(self, content: UserContent) -> None:
"""Add user content to the log."""
@ -223,6 +237,77 @@ class ChatLog:
self.content.append(response_content)
yield response_content
async def async_add_delta_content_stream(
self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
) -> AsyncGenerator[AssistantContent | ToolResultContent]:
"""Stream content into the chat log.
Returns a generator with all content that was added to the chat log.
stream iterates over dictionaries with optional keys role, content and tool_calls.
When a delta contains a role key, the current message is considered complete and
a new message is started.
The keys content and tool_calls will be concatenated if they appear multiple times.
"""
current_content = ""
current_tool_calls: list[llm.ToolInput] = []
tool_call_tasks: dict[str, asyncio.Task] = {}
async for delta in stream:
LOGGER.debug("Received delta: %s", delta)
# Indicates update to current message
if "role" not in delta:
if delta_content := delta.get("content"):
current_content += delta_content
if delta_tool_calls := delta.get("tool_calls"):
if self.llm_api is None:
raise ValueError("No LLM API configured")
current_tool_calls += delta_tool_calls
# Start processing the tool calls as soon as we know about them
for tool_call in delta_tool_calls:
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
self.llm_api.async_call_tool(tool_call),
name=f"llm_tool_{tool_call.id}",
)
continue
# Starting a new message
if delta["role"] != "assistant":
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
# Yield the previous message if it has content
if current_content or current_tool_calls:
content = AssistantContent(
agent_id=agent_id,
content=current_content or None,
tool_calls=current_tool_calls or None,
)
yield content
async for tool_result in self.async_add_assistant_content(
content, tool_call_tasks=tool_call_tasks
):
yield tool_result
current_content = delta.get("content") or ""
current_tool_calls = delta.get("tool_calls") or []
if current_content or current_tool_calls:
content = AssistantContent(
agent_id=agent_id,
content=current_content or None,
tool_calls=current_tool_calls or None,
)
yield content
async for tool_result in self.async_add_assistant_content(
content, tool_call_tasks=tool_call_tasks
):
yield tool_result
async def async_update_llm_data(
self,
conversing_domain: str,

View File

@ -1,14 +1,15 @@
"""Conversation support for OpenAI."""
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
import openai
from openai._streaming import AsyncStream
from openai._types import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
@ -70,32 +71,6 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec)
def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
tool_calls = [
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=tool_call.function.arguments,
name=tool_call.function.name,
),
type=tool_call.type,
)
for tool_call in message.tool_calls
]
param = ChatCompletionAssistantMessageParam(
role=message.role,
content=message.content,
)
if tool_calls:
param["tool_calls"] = tool_calls
return param
def _convert_content_to_param(
content: conversation.Content,
) -> ChatCompletionMessageParam:
@ -135,6 +110,74 @@ def _convert_content_to_param(
)
async def _transform_stream(
result: AsyncStream[ChatCompletionChunk],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform an OpenAI delta stream into HA format."""
current_tool_call: dict | None = None
async for chunk in result:
LOGGER.debug("Received chunk: %s", chunk)
choice = chunk.choices[0]
if choice.finish_reason:
if current_tool_call:
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call["id"],
tool_name=current_tool_call["tool_name"],
tool_args=json.loads(current_tool_call["tool_args"]),
)
]
}
break
delta = chunk.choices[0].delta
# We can yield delta messages not continuing or starting tool calls
if current_tool_call is None and not delta.tool_calls:
yield { # type: ignore[misc]
key: value
for key in ("role", "content")
if (value := getattr(delta, key)) is not None
}
continue
# When doing tool calls, we should always have a tool call
# object or we have gotten stopped above with a finish_reason set.
if (
not delta.tool_calls
or not (delta_tool_call := delta.tool_calls[0])
or not delta_tool_call.function
):
raise ValueError("Expected delta with tool call")
if current_tool_call and delta_tool_call.index == current_tool_call["index"]:
current_tool_call["tool_args"] += delta_tool_call.function.arguments or ""
continue
# We got tool call with new index, so we need to yield the previous
if current_tool_call:
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call["id"],
tool_name=current_tool_call["tool_name"],
tool_args=json.loads(current_tool_call["tool_args"]),
)
]
}
current_tool_call = {
"index": delta_tool_call.index,
"id": delta_tool_call.id,
"tool_name": delta_tool_call.function.name,
"tool_args": delta_tool_call.function.arguments or "",
}
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -234,6 +277,7 @@ class OpenAIConversationEntity(
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
"stream": True,
}
if model.startswith("o"):
@ -247,39 +291,21 @@ class OpenAIConversationEntity(
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err
LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(_convert_message_to_param(response))
tool_calls: list[llm.ToolInput] | None = None
if response.tool_calls:
tool_calls = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
for tool_call in response.tool_calls
]
messages.extend(
[
_convert_content_to_param(tool_response)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response.content or "",
tool_calls=tool_calls,
)
_convert_content_to_param(content)
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_stream(result)
)
]
)
if not tool_calls:
if not chat_log.unresponded_tool_results:
break
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "")
assert type(chat_log.content[-1]) is conversation.AssistantContent
intent_response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=intent_response, conversation_id=chat_log.conversation_id
)

View File

@ -7,6 +7,7 @@ from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import logging
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
@ -27,6 +28,7 @@ DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
CONVERSATION_TIMEOUT = timedelta(minutes=5)
LOGGER = logging.getLogger(__name__)
current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
@ -100,6 +102,7 @@ class SessionCleanup:
# yielding session based on it.
for conversation_id, session in list(all_sessions.items()):
if session.last_updated + CONVERSATION_TIMEOUT < now:
LOGGER.debug("Cleaning up session %s", conversation_id)
del all_sessions[conversation_id]
session.async_cleanup()
@ -150,6 +153,7 @@ def async_get_chat_session(
pass
if session is None:
LOGGER.debug("Creating new session %s", conversation_id)
session = ChatSession(conversation_id)
current_session.set(session)

View File

@ -1,4 +1,154 @@
# serializer version: 1
# name: test_add_delta_content_stream[deltas0]
list([
])
# ---
# name: test_add_delta_content_stream[deltas1]
list([
dict({
'agent_id': 'mock-agent-id',
'content': 'Test',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---
# name: test_add_delta_content_stream[deltas2]
list([
dict({
'agent_id': 'mock-agent-id',
'content': 'Test',
'role': 'assistant',
'tool_calls': None,
}),
dict({
'agent_id': 'mock-agent-id',
'content': 'Test 2',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---
# name: test_add_delta_content_stream[deltas3]
list([
dict({
'agent_id': 'mock-agent-id',
'content': None,
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id',
'tool_name': 'test_tool',
'tool_result': 'Test Param 1',
}),
])
# ---
# name: test_add_delta_content_stream[deltas4]
list([
dict({
'agent_id': 'mock-agent-id',
'content': 'Test',
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id',
'tool_name': 'test_tool',
'tool_result': 'Test Param 1',
}),
])
# ---
# name: test_add_delta_content_stream[deltas5]
list([
dict({
'agent_id': 'mock-agent-id',
'content': 'Test',
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id',
'tool_name': 'test_tool',
'tool_result': 'Test Param 1',
}),
dict({
'agent_id': 'mock-agent-id',
'content': 'Test 2',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---
# name: test_add_delta_content_stream[deltas6]
list([
dict({
'agent_id': 'mock-agent-id',
'content': None,
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
}),
'tool_name': 'test_tool',
}),
dict({
'id': 'mock-tool-call-id-2',
'tool_args': dict({
'param1': 'Test Param 2',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id',
'tool_name': 'test_tool',
'tool_result': 'Test Param 1',
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id-2',
'tool_name': 'test_tool',
'tool_result': 'Test Param 2',
}),
])
# ---
# name: test_template_error
dict({
'conversation_id': <ANY>,

View File

@ -282,7 +282,7 @@ async def test_extra_systen_prompt(
@pytest.mark.parametrize(
"prerun_tool_tasks",
[
None,
(),
("mock-tool-call-id",),
("mock-tool-call-id", "mock-tool-call-id-2"),
],
@ -290,7 +290,7 @@ async def test_extra_systen_prompt(
async def test_tool_call(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
prerun_tool_tasks: tuple[str] | None,
prerun_tool_tasks: tuple[str],
) -> None:
"""Test using the session tool calling API."""
@ -334,15 +334,13 @@ async def test_tool_call(
],
)
tool_call_tasks = None
if prerun_tool_tasks:
tool_call_tasks = {
tool_call_id: hass.async_create_task(
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
tool_call_id,
)
for tool_call_id in prerun_tool_tasks
}
tool_call_tasks = {
tool_call_id: hass.async_create_task(
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
tool_call_id,
)
for tool_call_id in prerun_tool_tasks
}
with pytest.raises(ValueError):
chat_log.async_add_assistant_content_without_tools(content)
@ -350,7 +348,7 @@ async def test_tool_call(
results = [
tool_result_content
async for tool_result_content in chat_log.async_add_assistant_content(
content, tool_call_tasks=tool_call_tasks
content, tool_call_tasks=tool_call_tasks or None
)
]
@ -382,37 +380,36 @@ async def test_tool_call_exception(
)
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
with patch(
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
with (
patch(
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools,
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
mock_get_tools.return_value = [mock_tool]
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
):
assert result is None
result = tool_result_content
):
assert result is None
result = tool_result_content
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
@ -420,3 +417,188 @@ async def test_tool_call_exception(
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)
@pytest.mark.parametrize(
"deltas",
[
[],
# With content
[
{"role": "assistant"},
{"content": "Test"},
],
# With 2 content
[
{"role": "assistant"},
{"content": "Test"},
{"role": "assistant"},
{"content": "Test 2"},
],
# With 1 tool call
[
{"role": "assistant"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param 1"},
)
]
},
],
# With content and 1 tool call
[
{"role": "assistant"},
{"content": "Test"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param 1"},
)
]
},
],
# With 2 contents and 1 tool call
[
{"role": "assistant"},
{"content": "Test"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param 1"},
)
]
},
{"role": "assistant"},
{"content": "Test 2"},
],
# With 2 tool calls
[
{"role": "assistant"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param 1"},
)
]
},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id-2",
tool_name="test_tool",
tool_args={"param1": "Test Param 2"},
)
]
},
],
],
)
async def test_add_delta_content_stream(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
snapshot: SnapshotAssertion,
deltas: list[dict],
) -> None:
"""Test streaming deltas."""
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
async def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
) -> str:
"""Call the tool."""
return tool_input.tool_args["param1"]
mock_tool.async_call.side_effect = tool_call
async def stream():
"""Yield deltas."""
for d in deltas:
yield d
with (
patch(
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools,
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
results = [
tool_result_content
async for tool_result_content in chat_log.async_add_delta_content_stream(
"mock-agent-id", stream()
)
]
assert results == snapshot
assert chat_log.content[2:] == results
async def test_add_delta_content_stream_errors(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Test streaming deltas error handling."""
async def stream(deltas):
"""Yield deltas."""
for d in deltas:
yield d
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
# Stream content without LLM API set
with pytest.raises(ValueError): # noqa: PT012
async for _tool_result_content in chat_log.async_add_delta_content_stream(
"mock-agent-id",
stream(
[
{"role": "assistant"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={},
)
]
},
]
),
):
pass
# Non assistant role
for role in "system", "user":
with pytest.raises(ValueError): # noqa: PT012
async for (
_tool_result_content
) in chat_log.async_add_delta_content_stream(
"mock-agent-id",
stream([{"role": role}]),
):
pass

View File

@ -1,34 +1,64 @@
# serializer version: 1
# name: test_unknown_hass_api
dict({
'conversation_id': 'my-conversation-id',
'response': IntentResponse(
card=dict({
}),
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
failed_results=list([
]),
intent=None,
intent_targets=list([
]),
language='en',
matched_states=list([
]),
reprompt=dict({
}),
response_type=<IntentResponseType.ERROR: 'error'>,
speech=dict({
'plain': dict({
'extra_data': None,
'speech': 'Error preparing LLM API',
# name: test_function_call
list([
dict({
'content': '''
Current time is 16:00:00. Today's date is 2024-06-03.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Please call the test function',
'role': 'user',
}),
dict({
'agent_id': 'conversation.openai',
'content': None,
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',
}),
'tool_name': 'test_tool',
}),
dict({
'id': 'call_call_2',
'tool_args': dict({
'param1': 'call2',
}),
'tool_name': 'test_tool',
}),
}),
speech_slots=dict({
}),
success_results=list([
]),
unmatched_states=list([
]),
),
})
}),
dict({
'agent_id': 'conversation.openai',
'role': 'tool_result',
'tool_call_id': 'call_call_1',
'tool_name': 'test_tool',
'tool_result': 'value1',
}),
dict({
'agent_id': 'conversation.openai',
'role': 'tool_result',
'tool_call_id': 'call_call_2',
'tool_name': 'test_tool',
'tool_result': 'value2',
}),
dict({
'agent_id': 'conversation.openai',
'content': 'Cool',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---

View File

@ -1,29 +1,130 @@
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from dataclasses import dataclass, field
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from httpx import Response
from openai import RateLimitError
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
import voluptuous as vol
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.components.conversation import chat_log
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from homeassistant.helpers import chat_session, intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
ASSIST_RESPONSE_FINISH = (
# Assistant message
ChatCompletionChunk(
id="chatcmpl-B",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
),
# Finish stream
ChatCompletionChunk(
id="chatcmpl-B",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[Choice(index=0, finish_reason="stop", delta=ChoiceDelta())],
),
)
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(stream):
for value in stream:
yield value
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0)
)
yield mock_create
@dataclass
class MockChatLog(chat_log.ChatLog):
"""Mock chat log."""
_mock_tool_results: dict = field(default_factory=dict)
def mock_tool_results(self, results: dict) -> None:
"""Set tool results."""
self._mock_tool_results = results
@property
def llm_api(self):
"""Return LLM API."""
return self._llm_api
@llm_api.setter
def llm_api(self, value):
"""Set LLM API."""
self._llm_api = value
if not value:
return
async def async_call_tool(tool_input):
"""Call tool."""
if tool_input.id not in self._mock_tool_results:
raise ValueError(f"Tool {tool_input.id} not found")
return self._mock_tool_results[tool_input.id]
self._llm_api.async_call_tool = async_call_tool
def latest_content(self) -> list[conversation.Content]:
"""Return content from latest version chat log.
The chat log makes copies until it's committed. Helper to get latest content.
"""
with (
chat_session.async_get_chat_session(
self.hass, self.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
return chat_log.content
@pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
"""Return mock chat logs."""
with (
patch(
"homeassistant.components.conversation.chat_log.ChatLog",
MockChatLog,
),
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
conversation.async_get_chat_log(hass, session) as chat_log,
):
chat_log.async_add_user_content(conversation.UserContent("hello"))
return chat_log
async def test_entity(
hass: HomeAssistant,
@ -83,348 +184,299 @@ async def test_conversation_agent(
assert agent.supported_languages == "*"
@patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog,
snapshot: SnapshotAssertion,
) -> None:
"""Test function call from the assistant."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
mock_create_stream.return_value = [
# Initial conversation
(
# First tool call
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
id="call_call_1",
index=0,
function=ChoiceDeltaToolCallFunction(
name="test_tool",
arguments=None,
),
)
]
),
)
],
),
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(
name=None,
arguments='{"para',
),
)
]
),
)
],
),
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(
name=None,
arguments='m1":"call1"}',
),
)
]
),
)
],
),
# Second tool call
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
id="call_call_2",
index=1,
function=ChoiceDeltaToolCallFunction(
name="test_tool",
arguments='{"param1":"call2"}',
),
)
]
),
)
],
),
# Finish stream
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(index=0, finish_reason="tool_calls", delta=ChoiceDelta())
],
),
),
# Response after tool responses
ASSIST_RESPONSE_FINISH,
]
mock_chat_log.mock_tool_results(
{
"call_call_1": "value1",
"call_call_2": "value2",
}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
with freeze_time("2024-06-03 23:00:00"):
result = await conversation.async_converse(
hass,
"Please call the test function",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
)
def completion_result(*args, messages, **kwargs):
for message in messages:
role = message["role"] if isinstance(message, dict) else message.role
if role == "tool":
return ChatCompletion(
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_chat_log.latest_content() == snapshot
@pytest.mark.parametrize(
("description", "messages"),
[
(
"Test function call started with missing arguments",
(
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="I have successfully called the function",
role="assistant",
function_call=None,
tool_calls=None,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
id="call_call_1",
index=0,
function=ChoiceDeltaToolCallFunction(
name="test_tool",
arguments=None,
),
)
]
),
)
],
),
ChatCompletionChunk(
id="chatcmpl-B",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
return ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
function=Function(
arguments='{"param1":"test_value"}',
name="test_tool",
),
type="function",
)
],
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(content="Cool"))],
),
),
)
with (
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=completion_result,
) as mock_create,
freeze_time("2024-06-03 23:00:00"),
):
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert (
"Today's date is 2024-06-03."
in mock_create.mock_calls[1][2]["messages"][0]["content"]
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[1][2]["messages"][3] == {
"role": "tool",
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
"content": '"Test response"',
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="openai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
(
"Test invalid JSON",
(
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
id="call_call_1",
index=0,
function=ChoiceDeltaToolCallFunction(
name="test_tool",
arguments=None,
),
)
]
),
)
],
),
ChatCompletionChunk(
id="chatcmpl-A",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(
name=None,
arguments='{"para',
),
)
]
),
)
],
),
ChatCompletionChunk(
id="chatcmpl-B",
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion.chunk",
choices=[
Choice(
index=0,
delta=ChoiceDelta(content="Cool"),
finish_reason="tool_calls",
)
],
),
),
),
)
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert (
"Today's date is 2024-06-03."
in trace_events[1]["data"]["messages"][0]["content"]
)
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
# Call it again, make sure we have updated prompt
with (
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=completion_result,
) as mock_create,
freeze_time("2024-06-04 23:00:00"),
):
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert (
"Today's date is 2024-06-04."
in mock_create.mock_calls[1][2]["messages"][0]["content"]
)
# Test old assert message not updated
assert (
"Today's date is 2024-06-03."
in trace_events[1]["data"]["messages"][0]["content"]
)
@patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
],
)
async def test_function_exception(
mock_get_tools,
async def test_function_call_invalid(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog,
description: str,
messages: tuple[ChatCompletionChunk],
) -> None:
"""Test function call with exception."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
"""Test function call containing invalid data."""
mock_create_stream.return_value = [messages]
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
role = message["role"] if isinstance(message, dict) else message.role
if role == "tool":
return ChatCompletion(
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="There was an error calling the function",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
return ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
function=Function(
arguments='{"param1":"test_value"}',
name="test_tool",
),
type="function",
)
],
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=completion_result,
) as mock_create:
result = await conversation.async_converse(
with pytest.raises(ValueError):
await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[1][2]["messages"][3] == {
"role": "tool",
"tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
"content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="openai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
),
)
async def test_assist_api_tools_conversion(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream,
) -> None:
"""Test that we are able to convert actual tools from Assist API."""
for component in (
"intent",
"todo",
"light",
"shopping_list",
"humidifier",
"calendar",
"climate",
"media_player",
"vacuum",
"cover",
"humidifier",
"intent",
"light",
"media_player",
"script",
"shopping_list",
"todo",
"vacuum",
"weather",
):
assert await async_setup_component(hass, component, {})
hass.states.async_set(f"{component}.test", "on")
async_expose_entity(hass, "conversation", f"{component}.test", True)
agent_id = mock_config_entry_with_assist.entry_id
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="Hello, how can I help you?",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-3.5-turbo-0613",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
),
) as mock_create:
await conversation.async_converse(
hass, "hello", None, Context(), agent_id=agent_id
)
mock_create_stream.return_value = [ASSIST_RESPONSE_FINISH]
tools = mock_create.mock_calls[0][2]["tools"]
await conversation.async_converse(
hass, "hello", None, Context(), agent_id="conversation.openai"
)
tools = mock_create_stream.mock_calls[0][2]["tools"]
assert tools