diff --git a/homeassistant/components/anthropic/conversation.py b/homeassistant/components/anthropic/conversation.py index 259d1295809..b479ee4409c 100644 --- a/homeassistant/components/anthropic/conversation.py +++ b/homeassistant/components/anthropic/conversation.py @@ -16,18 +16,15 @@ from anthropic.types import ( ToolUseBlock, ToolUseBlockParam, ) -import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation -from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.helpers import device_registry as dr, intent, llm, template +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import chat_session, device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.util import ulid as ulid_util from . import AnthropicConfigEntry from .const import ( @@ -89,6 +86,44 @@ def _message_convert( return MessageParam(role=message.role, content=param_content) +def _convert_content(chat_content: conversation.Content) -> MessageParam: + """Create tool response content.""" + if isinstance(chat_content, conversation.ToolResultContent): + return MessageParam( + role="user", + content=[ + ToolResultBlockParam( + type="tool_result", + tool_use_id=chat_content.tool_call_id, + content=json.dumps(chat_content.tool_result), + ) + ], + ) + if isinstance(chat_content, conversation.AssistantContent): + return MessageParam( + role="assistant", + content=[ + TextBlockParam(type="text", text=chat_content.content or ""), + *[ + ToolUseBlockParam( + type="tool_use", + id=tool_call.id, + name=tool_call.tool_name, + input=json.dumps(tool_call.tool_args), + ) + for tool_call in chat_content.tool_calls or () + ], + ], + ) + if isinstance(chat_content, conversation.UserContent): + return MessageParam( + role="user", + content=chat_content.content, + ) + # Note: We don't pass SystemContent here as its passed to the API as the prompt + raise ValueError(f"Unexpected content type: {type(chat_content)}") + + class AnthropicConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -100,7 +135,6 @@ class AnthropicConversationEntity( def __init__(self, entry: AnthropicConfigEntry) -> None: """Initialize the agent.""" self.entry = entry - self.history: dict[str, list[MessageParam]] = {} self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, @@ -129,110 +163,43 @@ class AnthropicConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - options = self.entry.options - intent_response = intent.IntentResponse(language=user_input.language) - llm_api: llm.APIInstance | None = None - tools: list[ToolParam] | None = None - user_name: str | None = None - llm_context = llm.LLMContext( - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - - if options.get(CONF_LLM_HASS_API): - try: - llm_api = await llm.async_get_api( - self.hass, - options[CONF_LLM_HASS_API], - llm_context, - ) - except HomeAssistantError as err: - LOGGER.error("Error getting LLM API: %s", err) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Error preparing LLM API: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=user_input.conversation_id - ) - tools = [ - _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools - ] - - if user_input.conversation_id is None: - conversation_id = ulid_util.ulid_now() - messages = [] - - elif user_input.conversation_id in self.history: - conversation_id = user_input.conversation_id - messages = self.history[conversation_id] - - else: - # Conversation IDs are ULIDs. We generate a new one if not provided. - # If an old OLID is passed in, we will generate a new one to indicate - # a new conversation was started. If the user picks their own, they - # want to track a conversation and we respect it. - try: - ulid_util.ulid_to_bytes(user_input.conversation_id) - conversation_id = ulid_util.ulid_now() - except ValueError: - conversation_id = user_input.conversation_id - - messages = [] - - if ( - user_input.context - and user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user(user_input.context.user_id) - ) + with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + conversation.async_get_chat_log(self.hass, session, user_input) as chat_log, ): - user_name = user.name + return await self._async_handle_message(user_input, chat_log) + + async def _async_handle_message( + self, + user_input: conversation.ConversationInput, + chat_log: conversation.ChatLog, + ) -> conversation.ConversationResult: + """Call the API.""" + options = self.entry.options try: - prompt_parts = [ - template.Template( - llm.BASE_PROMPT - + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), - self.hass, - ).async_render( - { - "ha_name": self.hass.config.location_name, - "user_name": user_name, - "llm_context": llm_context, - }, - parse_result=False, - ) + await chat_log.async_update_llm_data( + DOMAIN, + user_input, + options.get(CONF_LLM_HASS_API), + options.get(CONF_PROMPT), + ) + except conversation.ConverseError as err: + return err.as_conversation_result() + + tools: list[ToolParam] | None = None + if chat_log.llm_api: + tools = [ + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools ] - except TemplateError as err: - LOGGER.error("Error rendering prompt: %s", err) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - if llm_api: - prompt_parts.append(llm_api.api_prompt) - - prompt = "\n".join(prompt_parts) - - # Create a copy of the variable because we attach it to the trace - messages = [*messages, MessageParam(role="user", content=user_input.text)] - - LOGGER.debug("Prompt: %s", messages) - LOGGER.debug("Tools: %s", tools) - trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, - {"system": prompt, "messages": messages}, - ) + system = chat_log.content[0] + if not isinstance(system, conversation.SystemContent): + raise TypeError("First message must be a system message") + messages = [_convert_content(content) for content in chat_log.content[1:]] client = self.entry.runtime_data @@ -244,69 +211,62 @@ class AnthropicConversationEntity( messages=messages, tools=tools or NOT_GIVEN, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), - system=prompt, + system=system.content, temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), ) except anthropic.AnthropicError as err: - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to Anthropic: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) + raise HomeAssistantError( + f"Sorry, I had a problem talking to Anthropic: {err}" + ) from err LOGGER.debug("Response %s", response) messages.append(_message_convert(response)) - if response.stop_reason != "tool_use" or not llm_api: - break - - tool_results: list[ToolResultBlockParam] = [] - for tool_call in response.content: - if isinstance(tool_call, TextBlock): - LOGGER.info(tool_call.text) - - if not isinstance(tool_call, ToolUseBlock): - continue - - tool_input = llm.ToolInput( + text = "".join( + [ + content.text + for content in response.content + if isinstance(content, TextBlock) + ] + ) + tool_inputs = [ + llm.ToolInput( id=tool_call.id, tool_name=tool_call.name, tool_args=cast(dict[str, Any], tool_call.input), ) - LOGGER.debug( - "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args + for tool_call in response.content + if isinstance(tool_call, ToolUseBlock) + ] + + tool_results = [ + ToolResultBlockParam( + type="tool_result", + tool_use_id=tool_response.tool_call_id, + content=json.dumps(tool_response.tool_result), ) - - try: - tool_response = await llm_api.async_call_tool(tool_input) - except (HomeAssistantError, vol.Invalid) as e: - tool_response = {"error": type(e).__name__} - if str(e): - tool_response["error_text"] = str(e) - - LOGGER.debug("Tool response: %s", tool_response) - tool_results.append( - ToolResultBlockParam( - type="tool_result", - tool_use_id=tool_call.id, - content=json.dumps(tool_response), + async for tool_response in chat_log.async_add_assistant_content( + conversation.AssistantContent( + agent_id=user_input.agent_id, + content=text, + tool_calls=tool_inputs or None, ) ) + ] + if tool_results: + messages.append(MessageParam(role="user", content=tool_results)) - messages.append(MessageParam(role="user", content=tool_results)) - - self.history[conversation_id] = messages - - for content in response.content: - if isinstance(content, TextBlock): - intent_response.async_set_speech(content.text) + if not tool_inputs: break + response_content = chat_log.content[-1] + if not isinstance(response_content, conversation.AssistantContent): + raise TypeError("Last message must be an assistant message") + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(response_content.content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + response=intent_response, conversation_id=chat_log.conversation_id ) async def _async_entry_update_listener( diff --git a/tests/components/anthropic/snapshots/test_conversation.ambr b/tests/components/anthropic/snapshots/test_conversation.ambr index e4dd7cd00bb..93f3b03d9af 100644 --- a/tests/components/anthropic/snapshots/test_conversation.ambr +++ b/tests/components/anthropic/snapshots/test_conversation.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_unknown_hass_api dict({ - 'conversation_id': None, + 'conversation_id': '1234', 'response': IntentResponse( card=dict({ }), @@ -20,7 +20,7 @@ speech=dict({ 'plain': dict({ 'extra_data': None, - 'speech': 'Error preparing LLM API: API non-existing not found', + 'speech': 'Error preparing LLM API', }), }), speech_slots=dict({ diff --git a/tests/components/anthropic/test_conversation.py b/tests/components/anthropic/test_conversation.py index bb77e2ff926..2f1de3a2db9 100644 --- a/tests/components/anthropic/test_conversation.py +++ b/tests/components/anthropic/test_conversation.py @@ -10,7 +10,6 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation -from homeassistant.components.conversation import trace from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -250,42 +249,6 @@ async def test_function_call( ), ) - # Test Conversation tracing - traces = trace.async_get_traces() - assert traces - last_trace = traces[-1].as_dict() - trace_events = last_trace.get("events", []) - assert [event["event_type"] for event in trace_events] == [ - trace.ConversationTraceEventType.ASYNC_PROCESS, - trace.ConversationTraceEventType.AGENT_DETAIL, - trace.ConversationTraceEventType.TOOL_CALL, - ] - # AGENT_DETAIL event contains the raw prompt passed to the model - detail_event = trace_events[1] - assert "Answer in plain text" in detail_event["data"]["system"] - assert "Today's date is 2024-06-03." in trace_events[1]["data"]["system"] - - # Call it again, make sure we have updated prompt - with ( - patch( - "anthropic.resources.messages.AsyncMessages.create", - new_callable=AsyncMock, - side_effect=completion_result, - ) as mock_create, - freeze_time("2024-06-04 23:00:00"), - ): - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - ) - - assert "Today's date is 2024-06-04." in mock_create.mock_calls[1][2]["system"] - # Test old assert message not updated - assert "Today's date is 2024-06-03." in trace_events[1]["data"]["system"] - @patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools") async def test_function_exception( @@ -448,7 +411,7 @@ async def test_unknown_hass_api( ) result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id="conversation.claude" + hass, "hello", "1234", Context(), agent_id="conversation.claude" ) assert result == snapshot