Add support for extra_system_prompt to OpenAI (#134931)

pull/134942/head
Paulus Schoutsen 2025-01-06 17:01:13 -05:00 committed by GitHub
parent 9532e98166
commit d13c14eedb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 15 deletions

View File

@ -1,6 +1,7 @@
"""Conversation support for OpenAI."""
from collections.abc import Callable
from dataclasses import dataclass, field
import json
from typing import Any, Literal
@ -73,6 +74,14 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec)
@dataclass
class ChatHistory:
"""Class holding the chat history."""
extra_system_prompt: str | None = None
messages: list[ChatCompletionMessageParam] = field(default_factory=list)
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -84,7 +93,7 @@ class OpenAIConversationEntity(
def __init__(self, entry: OpenAIConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
self.history: dict[str, ChatHistory] = {}
self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
@ -157,13 +166,14 @@ class OpenAIConversationEntity(
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
history: ChatHistory | None = None
if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()
messages = []
elif user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
history = self.history.get(conversation_id)
else:
# Conversation IDs are ULIDs. We generate a new one if not provided.
@ -176,7 +186,8 @@ class OpenAIConversationEntity(
except ValueError:
conversation_id = user_input.conversation_id
messages = []
if history is None:
history = ChatHistory(user_input.extra_system_prompt)
if (
user_input.context
@ -217,20 +228,31 @@ class OpenAIConversationEntity(
if llm_api:
prompt_parts.append(llm_api.api_prompt)
extra_system_prompt = (
# Take new system prompt if one was given
user_input.extra_system_prompt or history.extra_system_prompt
)
if extra_system_prompt:
prompt_parts.append(extra_system_prompt)
prompt = "\n".join(prompt_parts)
# Create a copy of the variable because we attach it to the trace
messages = [
ChatCompletionSystemMessageParam(role="system", content=prompt),
*messages[1:],
ChatCompletionUserMessageParam(role="user", content=user_input.text),
]
history = ChatHistory(
extra_system_prompt,
[
ChatCompletionSystemMessageParam(role="system", content=prompt),
*history.messages[1:],
ChatCompletionUserMessageParam(role="user", content=user_input.text),
],
)
LOGGER.debug("Prompt: %s", messages)
LOGGER.debug("Prompt: %s", history.messages)
LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": messages, "tools": llm_api.tools if llm_api else None},
{"messages": history.messages, "tools": llm_api.tools if llm_api else None},
)
client = self.entry.runtime_data
@ -240,7 +262,7 @@ class OpenAIConversationEntity(
try:
result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
messages=history.messages,
tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
@ -286,7 +308,7 @@ class OpenAIConversationEntity(
param["tool_calls"] = tool_calls
return param
messages.append(message_convert(response))
history.messages.append(message_convert(response))
tool_calls = response.tool_calls
if not tool_calls or not llm_api:
@ -309,7 +331,7 @@ class OpenAIConversationEntity(
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
messages.append(
history.messages.append(
ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
@ -317,7 +339,7 @@ class OpenAIConversationEntity(
)
)
self.history[conversation_id] = messages
self.history[conversation_id] = history
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "")

View File

@ -149,6 +149,107 @@ async def test_template_variables(
)
async def test_extra_systen_prompt(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template variables work."""
extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it."
extra_system_prompt2 = (
"User person.paulus came home. Asked him what he wants to do."
)
with (
patch(
"openai.resources.models.AsyncModels.list",
),
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
) as mock_create,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=mock_config_entry.entry_id,
extra_system_prompt=extra_system_prompt,
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
extra_system_prompt
)
conversation_id = result.conversation_id
# Verify that follow-up conversations with no system prompt take previous one
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
) as mock_create:
result = await conversation.async_converse(
hass,
"hello",
conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
extra_system_prompt=None,
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
extra_system_prompt
)
# Verify that we take new system prompts
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
) as mock_create:
result = await conversation.async_converse(
hass,
"hello",
conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
extra_system_prompt=extra_system_prompt2,
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
extra_system_prompt2
)
# Verify that follow-up conversations with no system prompt take previous one
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
) as mock_create:
result = await conversation.async_converse(
hass,
"hello",
conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
extra_system_prompt2
)
async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,