Add support for extra_system_prompt to OpenAI (#134931)
parent
9532e98166
commit
d13c14eedb
|
@ -1,6 +1,7 @@
|
|||
"""Conversation support for OpenAI."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
|
@ -73,6 +74,14 @@ def _format_tool(
|
|||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatHistory:
|
||||
"""Class holding the chat history."""
|
||||
|
||||
extra_system_prompt: str | None = None
|
||||
messages: list[ChatCompletionMessageParam] = field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
|
@ -84,7 +93,7 @@ class OpenAIConversationEntity(
|
|||
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.entry = entry
|
||||
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
||||
self.history: dict[str, ChatHistory] = {}
|
||||
self._attr_unique_id = entry.entry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, entry.entry_id)},
|
||||
|
@ -157,13 +166,14 @@ class OpenAIConversationEntity(
|
|||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
|
||||
history: ChatHistory | None = None
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
messages = []
|
||||
|
||||
elif user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
messages = self.history[conversation_id]
|
||||
history = self.history.get(conversation_id)
|
||||
|
||||
else:
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
|
@ -176,7 +186,8 @@ class OpenAIConversationEntity(
|
|||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
messages = []
|
||||
if history is None:
|
||||
history = ChatHistory(user_input.extra_system_prompt)
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
|
@ -217,20 +228,31 @@ class OpenAIConversationEntity(
|
|||
if llm_api:
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
|
||||
extra_system_prompt = (
|
||||
# Take new system prompt if one was given
|
||||
user_input.extra_system_prompt or history.extra_system_prompt
|
||||
)
|
||||
|
||||
if extra_system_prompt:
|
||||
prompt_parts.append(extra_system_prompt)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
# Create a copy of the variable because we attach it to the trace
|
||||
messages = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
*messages[1:],
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||
]
|
||||
history = ChatHistory(
|
||||
extra_system_prompt,
|
||||
[
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
*history.messages[1:],
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||
],
|
||||
)
|
||||
|
||||
LOGGER.debug("Prompt: %s", messages)
|
||||
LOGGER.debug("Prompt: %s", history.messages)
|
||||
LOGGER.debug("Tools: %s", tools)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{"messages": messages, "tools": llm_api.tools if llm_api else None},
|
||||
{"messages": history.messages, "tools": llm_api.tools if llm_api else None},
|
||||
)
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
@ -240,7 +262,7 @@ class OpenAIConversationEntity(
|
|||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
messages=messages,
|
||||
messages=history.messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
|
@ -286,7 +308,7 @@ class OpenAIConversationEntity(
|
|||
param["tool_calls"] = tool_calls
|
||||
return param
|
||||
|
||||
messages.append(message_convert(response))
|
||||
history.messages.append(message_convert(response))
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
if not tool_calls or not llm_api:
|
||||
|
@ -309,7 +331,7 @@ class OpenAIConversationEntity(
|
|||
tool_response["error_text"] = str(e)
|
||||
|
||||
LOGGER.debug("Tool response: %s", tool_response)
|
||||
messages.append(
|
||||
history.messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=tool_call.id,
|
||||
|
@ -317,7 +339,7 @@ class OpenAIConversationEntity(
|
|||
)
|
||||
)
|
||||
|
||||
self.history[conversation_id] = messages
|
||||
self.history[conversation_id] = history
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response.content or "")
|
||||
|
|
|
@ -149,6 +149,107 @@ async def test_template_variables(
|
|||
)
|
||||
|
||||
|
||||
async def test_extra_systen_prompt(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template variables work."""
|
||||
extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it."
|
||||
extra_system_prompt2 = (
|
||||
"User person.paulus came home. Asked him what he wants to do."
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.models.AsyncModels.list",
|
||||
),
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create,
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
extra_system_prompt=extra_system_prompt,
|
||||
)
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||
extra_system_prompt
|
||||
)
|
||||
|
||||
conversation_id = result.conversation_id
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
conversation_id,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
extra_system_prompt=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||
extra_system_prompt
|
||||
)
|
||||
|
||||
# Verify that we take new system prompts
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
conversation_id,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
extra_system_prompt=extra_system_prompt2,
|
||||
)
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||
extra_system_prompt2
|
||||
)
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
conversation_id,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||
extra_system_prompt2
|
||||
)
|
||||
|
||||
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
|
Loading…
Reference in New Issue