async def _transform_stream(
    result: AsyncStream[MessageStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the Anthropic response stream into HA delta dictionaries.

    A typical stream of responses might look something like the following:
    - RawMessageStartEvent with no content
    - RawContentBlockStartEvent with an empty TextBlock
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - ...
    - RawContentBlockStopEvent
    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - ...
    - RawContentBlockStopEvent
    - RawMessageDeltaEvent with a stop_reason='tool_use'
    - RawMessageStopEvent(type='message_stop')

    Yields:
        ``{"role": "assistant"}`` when a text block starts,
        ``{"content": <text>}`` for every text delta, and
        ``{"tool_calls": [llm.ToolInput(...)]}`` once a tool-use block is
        complete. Event types not listed above are ignored.

    Raises:
        TypeError: If ``result`` is None.
        ValueError: If an InputJSONDelta arrives while no tool-use block is open.
        json.JSONDecodeError: If the accumulated, non-empty tool input is not
            valid JSON.
    """
    if result is None:
        raise TypeError("Expected a stream of messages")

    # State of the tool-use block currently being streamed, if any. "input"
    # accumulates the partial JSON fragments until the block is closed.
    current_tool_call: dict[str, str] | None = None

    async for response in result:
        LOGGER.debug("Received response: %s", response)

        if isinstance(response, RawContentBlockStartEvent):
            if isinstance(response.content_block, ToolUseBlock):
                current_tool_call = {
                    "id": response.content_block.id,
                    "name": response.content_block.name,
                    "input": "",
                }
            elif isinstance(response.content_block, TextBlock):
                yield {"role": "assistant"}
        elif isinstance(response, RawContentBlockDeltaEvent):
            if isinstance(response.delta, InputJSONDelta):
                if current_tool_call is None:
                    raise ValueError("Unexpected delta without a tool call")
                current_tool_call["input"] += response.delta.partial_json
            elif isinstance(response.delta, TextDelta):
                LOGGER.debug("yielding delta: %s", response.delta.text)
                yield {"content": response.delta.text}
        elif isinstance(response, RawContentBlockStopEvent):
            if current_tool_call is not None:
                raw_args = current_tool_call["input"]
                # A tool invoked without arguments streams no JSON (or only
                # empty fragments); json.loads("") would raise, so treat an
                # empty buffer as "no arguments".
                tool_args = json.loads(raw_args) if raw_args else {}
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=current_tool_call["id"],
                            tool_name=current_tool_call["name"],
                            tool_args=tool_args,
                        )
                    ]
                }
                current_tool_call = None
@@ class AnthropicConversationEntity( # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): try: - response = await client.messages.create( + stream = await client.messages.create( model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), messages=messages, tools=tools or NOT_GIVEN, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), system=system.content, temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + stream=True, ) except anthropic.AnthropicError as err: raise HomeAssistantError( f"Sorry, I had a problem talking to Anthropic: {err}" ) from err - LOGGER.debug("Response %s", response) - - messages.append(_message_convert(response)) - - text = "".join( + messages.extend( [ - content.text - for content in response.content - if isinstance(content, TextBlock) + _convert_content(content) + async for content in chat_log.async_add_delta_content_stream( + user_input.agent_id, _transform_stream(stream) + ) ] ) - tool_inputs = [ - llm.ToolInput( - id=tool_call.id, - tool_name=tool_call.name, - tool_args=cast(dict[str, Any], tool_call.input), - ) - for tool_call in response.content - if isinstance(tool_call, ToolUseBlock) - ] - tool_results = [ - ToolResultBlockParam( - type="tool_result", - tool_use_id=tool_response.tool_call_id, - content=json.dumps(tool_response.tool_result), - ) - async for tool_response in chat_log.async_add_assistant_content( - conversation.AssistantContent( - agent_id=user_input.agent_id, - content=text, - tool_calls=tool_inputs or None, - ) - ) - ] - if tool_results: - messages.append(MessageParam(role="user", content=tool_results)) - - if not tool_inputs: + if not chat_log.unresponded_tool_results: break response_content = chat_log.content[-1] diff --git a/tests/components/anthropic/test_conversation.py b/tests/components/anthropic/test_conversation.py index 2f1de3a2db9..bda9ca32b34 100644 --- a/tests/components/anthropic/test_conversation.py +++ 
async def stream_generator(
    responses: list[RawMessageStreamEvent],
) -> AsyncGenerator[RawMessageStreamEvent]:
    """Yield the given stream events one by one, in order."""
    for event in responses:
        yield event


def create_messages(
    content_blocks: list[RawMessageStreamEvent],
) -> list[RawMessageStreamEvent]:
    """Wrap content block events in message start and message stop events."""
    opening = RawMessageStartEvent(
        type="message_start",
        message=Message(
            type="message",
            id="msg_1234567890ABCDEFGHIJKLMN",
            content=[],
            role="assistant",
            model="claude-3-5-sonnet-20240620",
            usage=Usage(input_tokens=0, output_tokens=0),
        ),
    )
    closing = RawMessageStopEvent(type="message_stop")
    return [opening, *content_blocks, closing]


def create_content_block(
    index: int, text_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events of one text block: start, one delta per part, stop."""
    events: list[RawMessageStreamEvent] = [
        RawContentBlockStartEvent(
            type="content_block_start",
            content_block=TextBlock(text="", type="text"),
            index=index,
        )
    ]
    # One text delta is emitted for every element of text_parts.
    events.extend(
        RawContentBlockDeltaEvent(
            type="content_block_delta",
            index=index,
            delta=TextDelta(text=chunk, type="text_delta"),
        )
        for chunk in text_parts
    )
    events.append(RawContentBlockStopEvent(index=index, type="content_block_stop"))
    return events


def create_tool_use_block(
    index: int, tool_id: str, tool_name: str, json_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events of one tool-use block: start, one JSON delta per part, stop."""
    events: list[RawMessageStreamEvent] = [
        RawContentBlockStartEvent(
            type="content_block_start",
            content_block=ToolUseBlock(
                id=tool_id, name=tool_name, input={}, type="tool_use"
            ),
            index=index,
        )
    ]
    # The tool input is delivered as partial JSON fragments, one per delta.
    events.extend(
        RawContentBlockDeltaEvent(
            type="content_block_delta",
            index=index,
            delta=InputJSONDelta(partial_json=fragment, type="input_json_delta"),
        )
        for fragment in json_parts
    )
    events.append(RawContentBlockStopEvent(index=index, type="content_block_stop"))
    return events
in mock_create.mock_calls[1][2]["system"] @@ -168,39 +269,26 @@ async def test_function_call( for message in messages: for content in message["content"]: if not isinstance(content, str) and content["type"] == "tool_use": - return Message( - type="message", - id="msg_1234567890ABCDEFGHIJKLMN", - content=[ - TextBlock( - type="text", - text="I have successfully called the function", - ) - ], - model="claude-3-5-sonnet-20240620", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - usage=Usage(input_tokens=8, output_tokens=12), + return stream_generator( + create_messages( + create_content_block( + 0, ["I have ", "successfully called ", "the function"] + ), + ) ) - return Message( - type="message", - id="msg_1234567890ABCDEFGHIJKLMN", - content=[ - TextBlock(type="text", text="Certainly, calling it now!"), - ToolUseBlock( - type="tool_use", - id="toolu_0123456789AbCdEfGhIjKlM", - name="test_tool", - input={"param1": "test_value"}, - ), - ], - model="claude-3-5-sonnet-20240620", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - usage=Usage(input_tokens=8, output_tokens=12), + return stream_generator( + create_messages( + [ + *create_content_block(0, ["Certainly, calling it now!"]), + *create_tool_use_block( + 1, + "toolu_0123456789AbCdEfGhIjKlM", + "test_tool", + ['{"para', 'm1": "test_valu', 'e"}'], + ), + ] + ) ) with ( @@ -222,6 +310,10 @@ async def test_function_call( assert "Today's date is 2024-06-03." 
in mock_create.mock_calls[1][2]["system"] assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "I have successfully called the function" + ) assert mock_create.mock_calls[1][2]["messages"][2] == { "role": "user", "content": [ @@ -275,39 +367,27 @@ async def test_function_exception( for message in messages: for content in message["content"]: if not isinstance(content, str) and content["type"] == "tool_use": - return Message( - type="message", - id="msg_1234567890ABCDEFGHIJKLMN", - content=[ - TextBlock( - type="text", - text="There was an error calling the function", + return stream_generator( + create_messages( + create_content_block( + 0, + ["There was an error calling the function"], ) - ], - model="claude-3-5-sonnet-20240620", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - usage=Usage(input_tokens=8, output_tokens=12), + ) ) - return Message( - type="message", - id="msg_1234567890ABCDEFGHIJKLMN", - content=[ - TextBlock(type="text", text="Certainly, calling it now!"), - ToolUseBlock( - type="tool_use", - id="toolu_0123456789AbCdEfGhIjKlM", - name="test_tool", - input={"param1": "test_value"}, - ), - ], - model="claude-3-5-sonnet-20240620", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - usage=Usage(input_tokens=8, output_tokens=12), + return stream_generator( + create_messages( + [ + *create_content_block(0, "Certainly, calling it now!"), + *create_tool_use_block( + 1, + "toolu_0123456789AbCdEfGhIjKlM", + "test_tool", + ['{"param1": "test_value"}'], + ), + ] + ) ) with patch( @@ -324,6 +404,10 @@ async def test_function_exception( ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "There was an error calling the function" + ) assert mock_create.mock_calls[1][2]["messages"][2] == { "role": "user", "content": [ @@ -376,15 +460,10 @@ async def 
test_assist_api_tools_conversion( with patch( "anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock, - return_value=Message( - type="message", - id="msg_1234567890ABCDEFGHIJKLMN", - content=[TextBlock(type="text", text="Hello, how can I help you?")], - model="claude-3-5-sonnet-20240620", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - usage=Usage(input_tokens=8, output_tokens=12), + return_value=stream_generator( + create_messages( + create_content_block(0, "Hello, how can I help you?"), + ), ), ) as mock_create: await conversation.async_converse( @@ -425,28 +504,45 @@ async def test_conversation_id( mock_init_component, ) -> None: """Test conversation ID is honored.""" - result = await conversation.async_converse( - hass, "hello", None, None, agent_id="conversation.claude" - ) - conversation_id = result.conversation_id + def create_stream_generator(*args, **kwargs) -> Any: + return stream_generator( + create_messages( + create_content_block(0, "Hello, how can I help you?"), + ), + ) - result = await conversation.async_converse( - hass, "hello", conversation_id, None, agent_id="conversation.claude" - ) + with patch( + "anthropic.resources.messages.AsyncMessages.create", + new_callable=AsyncMock, + side_effect=create_stream_generator, + ): + result = await conversation.async_converse( + hass, "hello", "1234", Context(), agent_id="conversation.claude" + ) - assert result.conversation_id == conversation_id + result = await conversation.async_converse( + hass, "hello", None, None, agent_id="conversation.claude" + ) - unknown_id = ulid_util.ulid() + conversation_id = result.conversation_id - result = await conversation.async_converse( - hass, "hello", unknown_id, None, agent_id="conversation.claude" - ) + result = await conversation.async_converse( + hass, "hello", conversation_id, None, agent_id="conversation.claude" + ) - assert result.conversation_id != unknown_id + assert result.conversation_id == conversation_id - result 
= await conversation.async_converse( - hass, "hello", "koala", None, agent_id="conversation.claude" - ) + unknown_id = ulid_util.ulid() - assert result.conversation_id == "koala" + result = await conversation.async_converse( + hass, "hello", unknown_id, None, agent_id="conversation.claude" + ) + + assert result.conversation_id != unknown_id + + result = await conversation.async_converse( + hass, "hello", "koala", None, agent_id="conversation.claude" + ) + + assert result.conversation_id == "koala"