ChatSession: Split native content out of message class ()

Split native content out of message class
pull/136685/head
Paulus Schoutsen 2025-01-28 00:12:42 -05:00 committed by GitHub
parent 48a91540e1
commit 5690516852
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 59 additions and 73 deletions
homeassistant/components
assist_pipeline
openai_conversation
tests/components/conversation

View File

@ -1101,11 +1101,10 @@ class PipelineRun:
"speech", ""
)
chat_session.async_add_message(
conversation.ChatMessage(
conversation.Content(
role="assistant",
agent_id=agent_id,
content=speech,
native=intent_response,
)
)
conversation_result = conversation.ConversationResult(

View File

@ -48,21 +48,28 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
from .session import (
ChatSession,
Content,
ConverseError,
NativeContent,
async_get_chat_session,
)
from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"ChatMessage",
"ChatSession",
"Content",
"ConversationEntity",
"ConversationEntityFeature",
"ConversationInput",
"ConversationResult",
"ConversationTraceEventType",
"ConverseError",
"NativeContent",
"async_conversation_trace_append",
"async_converse",
"async_get_agent_info",

View File

@ -62,7 +62,7 @@ from .const import (
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .session import ChatMessage, async_get_chat_session
from .session import Content, async_get_chat_session
from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__)
@ -374,11 +374,10 @@ class DefaultAgent(ConversationEntity):
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_session.async_add_message(
ChatMessage(
Content(
role="assistant",
agent_id=user_input.agent_id,
content=speech,
native=response,
)
)

View File

@ -126,7 +126,7 @@ async def async_get_chat_session(
else:
history = ChatSession(hass, conversation_id, user_input.agent_id)
message: ChatMessage = ChatMessage(
message: Content = Content(
role="user",
agent_id=user_input.agent_id,
content=user_input.text,
@ -169,23 +169,21 @@ class ConverseError(HomeAssistantError):
@dataclass
class ChatMessage[_NativeT]:
"""Base class for chat messages.
class Content:
"""Base class for chat messages."""
When role is native, the content is to be ignored and message
is only meant for storing the native object.
"""
role: Literal["system", "assistant", "user", "native"]
role: Literal["system", "assistant", "user"]
agent_id: str | None
content: str
native: _NativeT | None = field(default=None)
# Validate in post-init that if role is native, there is no content and a native object exists
def __post_init__(self) -> None:
"""Validate native message."""
if self.role == "native" and self.native is None:
raise ValueError("Native message must have a native object")
@dataclass(frozen=True)
class NativeContent[_NativeT]:
"""Native content."""
role: str = field(init=False, default="native")
agent_id: str
content: _NativeT
@dataclass
@ -196,15 +194,15 @@ class ChatSession[_NativeT]:
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
)
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
"""Process intent."""
if message.role == "system":
raise ValueError("Cannot add system messages to history")
@ -216,7 +214,7 @@ class ChatSession[_NativeT]:
@callback
def async_get_messages(
self, agent_id: str | None = None
) -> list[ChatMessage[_NativeT]]:
) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs.
@ -328,7 +326,7 @@ class ChatSession[_NativeT]:
self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt
self.messages[0] = ChatMessage(
self.messages[0] = Content(
role="system",
agent_id=user_input.agent_id,
content=prompt,

View File

@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
def _chat_message_convert(
message: conversation.ChatMessage[ChatCompletionMessageParam],
agent_id: str | None,
message: conversation.Content
| conversation.NativeContent[ChatCompletionMessageParam],
) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format."""
if message.native is not None and message.agent_id == agent_id:
return message.native
if message.role == "native":
# mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
return cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
@ -157,14 +158,15 @@ class OpenAIConversationEntity(
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_call_api(user_input, session)
return await self._async_handle_message(user_input, session)
async def _async_call_api(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[ChatCompletionMessageParam],
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options
try:
@ -185,8 +187,7 @@ class OpenAIConversationEntity(
]
messages = [
_chat_message_convert(message, user_input.agent_id)
for message in session.async_get_messages()
_chat_message_convert(message) for message in session.async_get_messages()
]
client = self.entry.runtime_data
@ -212,11 +213,10 @@ class OpenAIConversationEntity(
messages.append(_message_convert(response))
session.async_add_message(
conversation.ChatMessage(
conversation.Content(
role=response.role,
agent_id=user_input.agent_id,
content=response.content or "",
native=messages[-1],
),
)
@ -237,11 +237,9 @@ class OpenAIConversationEntity(
)
)
session.async_add_message(
conversation.ChatMessage(
role="native",
conversation.NativeContent(
agent_id=user_input.agent_id,
content="",
native=messages[-1],
content=messages[-1],
)
)

View File

@ -82,7 +82,7 @@ async def test_cleanup(
assert chat_session.conversation_id != conversation_id
conversation_id = chat_session.conversation_id
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
@ -127,12 +127,6 @@ async def test_cleanup(
assert len(chat_session.messages) == 2
def test_chat_message() -> None:
"""Test chat message."""
with pytest.raises(ValueError):
session.ChatMessage(role="native", agent_id=None, content="", native=None)
async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
@ -144,7 +138,7 @@ async def test_add_message(
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="system", agent_id=None, content="")
session.Content(role="system", agent_id=None, content="")
)
# No 2 user messages in a row
@ -152,19 +146,19 @@ async def test_add_message(
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="user", agent_id=None, content="")
session.Content(role="user", agent_id=None, content="")
)
# No 2 assistant messages in a row
chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="")
session.Content(role="assistant", agent_id=None, content="")
)
assert len(chat_session.messages) == 3
assert chat_session.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="")
session.Content(role="assistant", agent_id=None, content="")
)
@ -177,12 +171,12 @@ async def test_message_filtering(
) as chat_session:
messages = chat_session.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == session.ChatMessage(
assert messages[0] == session.Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == session.ChatMessage(
assert messages[1] == session.Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
@ -190,7 +184,7 @@ async def test_message_filtering(
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
@ -198,31 +192,25 @@ async def test_message_filtering(
)
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
)
# Different agent, native messages will be filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="another-mock-agent-id", content="", native=1
)
session.NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
session.NativeContent(agent_id="mock-agent-id", content=1)
)
# A non-native message from another agent is not filtered out.
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
native=1,
)
)
@ -231,17 +219,14 @@ async def test_message_filtering(
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == session.ChatMessage(
assert messages[2] == session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
assert messages[3] == session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
assert messages[4] == session.ChatMessage(
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == session.Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
@ -361,7 +346,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None,
)
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
@ -401,7 +386,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None,
)
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",