From 5690516852a4134a5445d5b2d888d0d1cca284da Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 28 Jan 2025 00:12:42 -0500 Subject: [PATCH] ChatSession: Split native content out of message class (#136668) Split native content out of message class --- .../components/assist_pipeline/pipeline.py | 3 +- .../components/conversation/__init__.py | 11 +++- .../components/conversation/default_agent.py | 5 +- .../components/conversation/session.py | 36 +++++++------ .../openai_conversation/conversation.py | 26 +++++----- tests/components/conversation/test_session.py | 51 +++++++------------ 6 files changed, 59 insertions(+), 73 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 9353bbe0007..9fdcc2bf690 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1101,11 +1101,10 @@ class PipelineRun: "speech", "" ) chat_session.async_add_message( - conversation.ChatMessage( + conversation.Content( role="assistant", agent_id=agent_id, content=speech, - native=intent_response, ) ) conversation_result = conversation.ConversationResult( diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 9c1db128f15..b110d53540c 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -48,21 +48,28 @@ from .default_agent import DefaultAgent, async_setup_default_agent from .entity import ConversationEntity from .http import async_setup as async_setup_conversation_http from .models import AbstractConversationAgent, ConversationInput, ConversationResult -from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session +from .session import ( + ChatSession, + Content, + ConverseError, + NativeContent, + async_get_chat_session, +) from .trace import ConversationTraceEventType, async_conversation_trace_append __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT", - "ChatMessage", "ChatSession", + "Content", "ConversationEntity", "ConversationEntityFeature", "ConversationInput", "ConversationResult", "ConversationTraceEventType", "ConverseError", + "NativeContent", "async_conversation_trace_append", "async_converse", "async_get_agent_info", diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index bb815698941..be0387555dc 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -62,7 +62,7 @@ from .const import ( ) from .entity import ConversationEntity from .models import ConversationInput, ConversationResult -from .session import ChatMessage, async_get_chat_session +from .session import Content, async_get_chat_session from .trace import ConversationTraceEventType, async_conversation_trace_append _LOGGER = logging.getLogger(__name__) @@ -374,11 +374,10 @@ class DefaultAgent(ConversationEntity): speech: str = response.speech.get("plain", {}).get("speech", "") chat_session.async_add_message( - ChatMessage( + Content( role="assistant", agent_id=user_input.agent_id, content=speech, - native=response, ) ) diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py index 2235459954f..43f4cbf427c 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/session.py @@ -126,7 +126,7 @@ async def async_get_chat_session( else: history = ChatSession(hass, conversation_id, user_input.agent_id) - message: ChatMessage = ChatMessage( + message: Content = Content( role="user", agent_id=user_input.agent_id, content=user_input.text, @@ -169,23 +169,21 @@ class ConverseError(HomeAssistantError): @dataclass -class ChatMessage[_NativeT]: - """Base class for chat messages. +class Content: + """Base class for chat messages.""" - When role is native, the content is to be ignored and message - is only meant for storing the native object. - """ - - role: Literal["system", "assistant", "user", "native"] + role: Literal["system", "assistant", "user"] agent_id: str | None content: str - native: _NativeT | None = field(default=None) - # Validate in post-init that if role is native, there is no content and a native object exists - def __post_init__(self) -> None: - """Validate native message.""" - if self.role == "native" and self.native is None: - raise ValueError("Native message must have a native object") + +@dataclass(frozen=True) +class NativeContent[_NativeT]: + """Native content.""" + + role: str = field(init=False, default="native") + agent_id: str + content: _NativeT @dataclass @@ -196,15 +194,15 @@ class ChatSession[_NativeT]: conversation_id: str agent_id: str | None user_name: str | None = None - messages: list[ChatMessage[_NativeT]] = field( - default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")] + messages: list[Content | NativeContent[_NativeT]] = field( + default_factory=lambda: [Content(role="system", agent_id=None, content="")] ) extra_system_prompt: str | None = None llm_api: llm.APIInstance | None = None last_updated: datetime = field(default_factory=dt_util.utcnow) @callback - def async_add_message(self, message: ChatMessage[_NativeT]) -> None: + def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None: """Process intent.""" if message.role == "system": raise ValueError("Cannot add system messages to history") @@ -216,7 +214,7 @@ class ChatSession[_NativeT]: @callback def async_get_messages( self, agent_id: str | None = None - ) -> list[ChatMessage[_NativeT]]: + ) -> list[Content | NativeContent[_NativeT]]: """Get messages for a specific agent ID. This will filter out any native message tied to other agent IDs. @@ -328,7 +326,7 @@ class ChatSession[_NativeT]: self.llm_api = llm_api self.user_name = user_name self.extra_system_prompt = extra_system_prompt - self.messages[0] = ChatMessage( + self.messages[0] = Content( role="system", agent_id=user_input.agent_id, content=prompt, diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 1464f4224d7..2f35bea97e2 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar def _chat_message_convert( - message: conversation.ChatMessage[ChatCompletionMessageParam], - agent_id: str | None, + message: conversation.Content + | conversation.NativeContent[ChatCompletionMessageParam], ) -> ChatCompletionMessageParam: """Convert any native chat message for this agent to the native format.""" - if message.native is not None and message.agent_id == agent_id: - return message.native + if message.role == "native": + # mypy doesn't understand that checking role ensures content type + return message.content # type: ignore[return-value] return cast( ChatCompletionMessageParam, {"role": message.role, "content": message.content}, @@ -157,14 +158,15 @@ class OpenAIConversationEntity( async with conversation.async_get_chat_session( self.hass, user_input ) as session: - return await self._async_call_api(user_input, session) + return await self._async_handle_message(user_input, session) - async def _async_call_api( + async def _async_handle_message( self, user_input: conversation.ConversationInput, session: conversation.ChatSession[ChatCompletionMessageParam], ) -> conversation.ConversationResult: """Call the API.""" + assert user_input.agent_id options = self.entry.options try: @@ -185,8 +187,7 @@ class OpenAIConversationEntity( ] messages = [ - _chat_message_convert(message, user_input.agent_id) - for message in session.async_get_messages() + _chat_message_convert(message) for message in session.async_get_messages() ] client = self.entry.runtime_data @@ -212,11 +213,10 @@ class OpenAIConversationEntity( messages.append(_message_convert(response)) session.async_add_message( - conversation.ChatMessage( + conversation.Content( role=response.role, agent_id=user_input.agent_id, content=response.content or "", - native=messages[-1], ), ) @@ -237,11 +237,9 @@ class OpenAIConversationEntity( ) ) session.async_add_message( - conversation.ChatMessage( - role="native", + conversation.NativeContent( agent_id=user_input.agent_id, - content="", - native=messages[-1], + content=messages[-1], ) ) diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py index bca19b3b06a..60c7f2957b8 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_session.py @@ -82,7 +82,7 @@ async def test_cleanup( assert chat_session.conversation_id != conversation_id conversation_id = chat_session.conversation_id chat_session.async_add_message( - session.ChatMessage( + session.Content( role="assistant", agent_id="mock-agent-id", content="Hey!", @@ -127,12 +127,6 @@ async def test_cleanup( assert len(chat_session.messages) == 2 -def test_chat_message() -> None: - """Test chat message.""" - with pytest.raises(ValueError): - session.ChatMessage(role="native", agent_id=None, content="", native=None) - - async def test_add_message( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: @@ -144,7 +138,7 @@ async def test_add_message( with pytest.raises(ValueError): chat_session.async_add_message( - session.ChatMessage(role="system", agent_id=None, content="") + session.Content(role="system", agent_id=None, content="") ) # No 2 user messages in a row @@ -152,19 +146,19 @@ async def test_add_message( with pytest.raises(ValueError): chat_session.async_add_message( - session.ChatMessage(role="user", agent_id=None, content="") + session.Content(role="user", agent_id=None, content="") ) # No 2 assistant messages in a row chat_session.async_add_message( - session.ChatMessage(role="assistant", agent_id=None, content="") + session.Content(role="assistant", agent_id=None, content="") ) assert len(chat_session.messages) == 3 assert chat_session.messages[-1].role == "assistant" with pytest.raises(ValueError): chat_session.async_add_message( - session.ChatMessage(role="assistant", agent_id=None, content="") + session.Content(role="assistant", agent_id=None, content="") ) @@ -177,12 +171,12 @@ async def test_message_filtering( ) as chat_session: messages = chat_session.async_get_messages(agent_id=None) assert len(messages) == 2 - assert messages[0] == session.ChatMessage( + assert messages[0] == session.Content( role="system", agent_id=None, content="", ) - assert messages[1] == session.ChatMessage( + assert messages[1] == session.Content( role="user", agent_id="mock-agent-id", content=mock_conversation_input.text, @@ -190,7 +184,7 @@ async def test_message_filtering( # Cannot add a second user message in a row with pytest.raises(ValueError): chat_session.async_add_message( - session.ChatMessage( + session.Content( role="user", agent_id="mock-agent-id", content="Hey!", @@ -198,31 +192,25 @@ async def test_message_filtering( ) chat_session.async_add_message( - session.ChatMessage( + session.Content( role="assistant", agent_id="mock-agent-id", content="Hey!", - native="assistant-reply-native", ) ) # Different agent, native messages will be filtered out. chat_session.async_add_message( - session.ChatMessage( - role="native", agent_id="another-mock-agent-id", content="", native=1 - ) + session.NativeContent(agent_id="another-mock-agent-id", content=1) ) chat_session.async_add_message( - session.ChatMessage( - role="native", agent_id="mock-agent-id", content="", native=1 - ) + session.NativeContent(agent_id="mock-agent-id", content=1) ) # A non-native message from another agent is not filtered out. chat_session.async_add_message( - session.ChatMessage( + session.Content( role="assistant", agent_id="another-mock-agent-id", content="Hi!", - native=1, ) ) @@ -231,17 +219,14 @@ async def test_message_filtering( messages = chat_session.async_get_messages(agent_id="mock-agent-id") assert len(messages) == 5 - assert messages[2] == session.ChatMessage( + assert messages[2] == session.Content( role="assistant", agent_id="mock-agent-id", content="Hey!", - native="assistant-reply-native", ) - assert messages[3] == session.ChatMessage( - role="native", agent_id="mock-agent-id", content="", native=1 - ) - assert messages[4] == session.ChatMessage( - role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1 + assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1) + assert messages[4] == session.Content( + role="assistant", agent_id="another-mock-agent-id", content="Hi!" ) @@ -361,7 +346,7 @@ async def test_extra_systen_prompt( user_llm_prompt=None, ) chat_session.async_add_message( - session.ChatMessage( + session.Content( role="assistant", agent_id="mock-agent-id", content="Hey!", @@ -401,7 +386,7 @@ async def test_extra_systen_prompt( user_llm_prompt=None, ) chat_session.async_add_message( - session.ChatMessage( + session.Content( role="assistant", agent_id="mock-agent-id", content="Hey!",