ChatSession: Split native content out of message class (#136668)
Split native content out of message classpull/136685/head
parent
48a91540e1
commit
5690516852
homeassistant/components
assist_pipeline
conversation
openai_conversation
tests/components/conversation
|
@ -1101,11 +1101,10 @@ class PipelineRun:
|
|||
"speech", ""
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
conversation.Content(
|
||||
role="assistant",
|
||||
agent_id=agent_id,
|
||||
content=speech,
|
||||
native=intent_response,
|
||||
)
|
||||
)
|
||||
conversation_result = conversation.ConversationResult(
|
||||
|
|
|
@ -48,21 +48,28 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
|||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
|
||||
from .session import (
|
||||
ChatSession,
|
||||
Content,
|
||||
ConverseError,
|
||||
NativeContent,
|
||||
async_get_chat_session,
|
||||
)
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"OLD_HOME_ASSISTANT_AGENT",
|
||||
"ChatMessage",
|
||||
"ChatSession",
|
||||
"Content",
|
||||
"ConversationEntity",
|
||||
"ConversationEntityFeature",
|
||||
"ConversationInput",
|
||||
"ConversationResult",
|
||||
"ConversationTraceEventType",
|
||||
"ConverseError",
|
||||
"NativeContent",
|
||||
"async_conversation_trace_append",
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
|
|
|
@ -62,7 +62,7 @@ from .const import (
|
|||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
from .session import ChatMessage, async_get_chat_session
|
||||
from .session import Content, async_get_chat_session
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -374,11 +374,10 @@ class DefaultAgent(ConversationEntity):
|
|||
|
||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||
chat_session.async_add_message(
|
||||
ChatMessage(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
content=speech,
|
||||
native=response,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ async def async_get_chat_session(
|
|||
else:
|
||||
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
||||
|
||||
message: ChatMessage = ChatMessage(
|
||||
message: Content = Content(
|
||||
role="user",
|
||||
agent_id=user_input.agent_id,
|
||||
content=user_input.text,
|
||||
|
@ -169,23 +169,21 @@ class ConverseError(HomeAssistantError):
|
|||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage[_NativeT]:
|
||||
"""Base class for chat messages.
|
||||
class Content:
|
||||
"""Base class for chat messages."""
|
||||
|
||||
When role is native, the content is to be ignored and message
|
||||
is only meant for storing the native object.
|
||||
"""
|
||||
|
||||
role: Literal["system", "assistant", "user", "native"]
|
||||
role: Literal["system", "assistant", "user"]
|
||||
agent_id: str | None
|
||||
content: str
|
||||
native: _NativeT | None = field(default=None)
|
||||
|
||||
# Validate in post-init that if role is native, there is no content and a native object exists
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate native message."""
|
||||
if self.role == "native" and self.native is None:
|
||||
raise ValueError("Native message must have a native object")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NativeContent[_NativeT]:
|
||||
"""Native content."""
|
||||
|
||||
role: str = field(init=False, default="native")
|
||||
agent_id: str
|
||||
content: _NativeT
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -196,15 +194,15 @@ class ChatSession[_NativeT]:
|
|||
conversation_id: str
|
||||
agent_id: str | None
|
||||
user_name: str | None = None
|
||||
messages: list[ChatMessage[_NativeT]] = field(
|
||||
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
||||
messages: list[Content | NativeContent[_NativeT]] = field(
|
||||
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
|
||||
)
|
||||
extra_system_prompt: str | None = None
|
||||
llm_api: llm.APIInstance | None = None
|
||||
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||
|
||||
@callback
|
||||
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
|
||||
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
|
||||
"""Process intent."""
|
||||
if message.role == "system":
|
||||
raise ValueError("Cannot add system messages to history")
|
||||
|
@ -216,7 +214,7 @@ class ChatSession[_NativeT]:
|
|||
@callback
|
||||
def async_get_messages(
|
||||
self, agent_id: str | None = None
|
||||
) -> list[ChatMessage[_NativeT]]:
|
||||
) -> list[Content | NativeContent[_NativeT]]:
|
||||
"""Get messages for a specific agent ID.
|
||||
|
||||
This will filter out any native message tied to other agent IDs.
|
||||
|
@ -328,7 +326,7 @@ class ChatSession[_NativeT]:
|
|||
self.llm_api = llm_api
|
||||
self.user_name = user_name
|
||||
self.extra_system_prompt = extra_system_prompt
|
||||
self.messages[0] = ChatMessage(
|
||||
self.messages[0] = Content(
|
||||
role="system",
|
||||
agent_id=user_input.agent_id,
|
||||
content=prompt,
|
||||
|
|
|
@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
|
|||
|
||||
|
||||
def _chat_message_convert(
|
||||
message: conversation.ChatMessage[ChatCompletionMessageParam],
|
||||
agent_id: str | None,
|
||||
message: conversation.Content
|
||||
| conversation.NativeContent[ChatCompletionMessageParam],
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
if message.native is not None and message.agent_id == agent_id:
|
||||
return message.native
|
||||
if message.role == "native":
|
||||
# mypy doesn't understand that checking role ensures content type
|
||||
return message.content # type: ignore[return-value]
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": message.role, "content": message.content},
|
||||
|
@ -157,14 +158,15 @@ class OpenAIConversationEntity(
|
|||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as session:
|
||||
return await self._async_call_api(user_input, session)
|
||||
return await self._async_handle_message(user_input, session)
|
||||
|
||||
async def _async_call_api(
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatSession[ChatCompletionMessageParam],
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
assert user_input.agent_id
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
|
@ -185,8 +187,7 @@ class OpenAIConversationEntity(
|
|||
]
|
||||
|
||||
messages = [
|
||||
_chat_message_convert(message, user_input.agent_id)
|
||||
for message in session.async_get_messages()
|
||||
_chat_message_convert(message) for message in session.async_get_messages()
|
||||
]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
@ -212,11 +213,10 @@ class OpenAIConversationEntity(
|
|||
messages.append(_message_convert(response))
|
||||
|
||||
session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
conversation.Content(
|
||||
role=response.role,
|
||||
agent_id=user_input.agent_id,
|
||||
content=response.content or "",
|
||||
native=messages[-1],
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -237,11 +237,9 @@ class OpenAIConversationEntity(
|
|||
)
|
||||
)
|
||||
session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
role="native",
|
||||
conversation.NativeContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content="",
|
||||
native=messages[-1],
|
||||
content=messages[-1],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ async def test_cleanup(
|
|||
assert chat_session.conversation_id != conversation_id
|
||||
conversation_id = chat_session.conversation_id
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
|
@ -127,12 +127,6 @@ async def test_cleanup(
|
|||
assert len(chat_session.messages) == 2
|
||||
|
||||
|
||||
def test_chat_message() -> None:
|
||||
"""Test chat message."""
|
||||
with pytest.raises(ValueError):
|
||||
session.ChatMessage(role="native", agent_id=None, content="", native=None)
|
||||
|
||||
|
||||
async def test_add_message(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
|
@ -144,7 +138,7 @@ async def test_add_message(
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(role="system", agent_id=None, content="")
|
||||
session.Content(role="system", agent_id=None, content="")
|
||||
)
|
||||
|
||||
# No 2 user messages in a row
|
||||
|
@ -152,19 +146,19 @@ async def test_add_message(
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(role="user", agent_id=None, content="")
|
||||
session.Content(role="user", agent_id=None, content="")
|
||||
)
|
||||
|
||||
# No 2 assistant messages in a row
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(role="assistant", agent_id=None, content="")
|
||||
session.Content(role="assistant", agent_id=None, content="")
|
||||
)
|
||||
assert len(chat_session.messages) == 3
|
||||
assert chat_session.messages[-1].role == "assistant"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(role="assistant", agent_id=None, content="")
|
||||
session.Content(role="assistant", agent_id=None, content="")
|
||||
)
|
||||
|
||||
|
||||
|
@ -177,12 +171,12 @@ async def test_message_filtering(
|
|||
) as chat_session:
|
||||
messages = chat_session.async_get_messages(agent_id=None)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == session.ChatMessage(
|
||||
assert messages[0] == session.Content(
|
||||
role="system",
|
||||
agent_id=None,
|
||||
content="",
|
||||
)
|
||||
assert messages[1] == session.ChatMessage(
|
||||
assert messages[1] == session.Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content=mock_conversation_input.text,
|
||||
|
@ -190,7 +184,7 @@ async def test_message_filtering(
|
|||
# Cannot add a second user message in a row
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
|
@ -198,31 +192,25 @@ async def test_message_filtering(
|
|||
)
|
||||
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
native="assistant-reply-native",
|
||||
)
|
||||
)
|
||||
# Different agent, native messages will be filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="native", agent_id="another-mock-agent-id", content="", native=1
|
||||
)
|
||||
session.NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
||||
session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||
)
|
||||
# A non-native message from another agent is not filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="assistant",
|
||||
agent_id="another-mock-agent-id",
|
||||
content="Hi!",
|
||||
native=1,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -231,17 +219,14 @@ async def test_message_filtering(
|
|||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||
assert len(messages) == 5
|
||||
|
||||
assert messages[2] == session.ChatMessage(
|
||||
assert messages[2] == session.Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
native="assistant-reply-native",
|
||||
)
|
||||
assert messages[3] == session.ChatMessage(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
||||
assert messages[4] == session.ChatMessage(
|
||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
|
||||
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||
assert messages[4] == session.Content(
|
||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
||||
)
|
||||
|
||||
|
||||
|
@ -361,7 +346,7 @@ async def test_extra_systen_prompt(
|
|||
user_llm_prompt=None,
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
|
@ -401,7 +386,7 @@ async def test_extra_systen_prompt(
|
|||
user_llm_prompt=None,
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
session.Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
|
|
Loading…
Reference in New Issue