Update anthropic to use the streaming API (#138256)

pull/137668/head
Allen Porter 2025-02-11 16:05:23 -08:00 committed by GitHub
parent 117a71cb67
commit da1e3c29ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 262 additions and 127 deletions

View File

@ -1,16 +1,23 @@
"""Conversation support for Anthropic."""
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
from typing import Any, Literal
import anthropic
from anthropic import AsyncStream
from anthropic._types import NOT_GIVEN
from anthropic.types import (
InputJSONDelta,
Message,
MessageParam,
MessageStreamEvent,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
TextBlock,
TextBlockParam,
TextDelta,
ToolParam,
ToolResultBlockParam,
ToolUseBlock,
@ -109,7 +116,7 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
type="tool_use",
id=tool_call.id,
name=tool_call.tool_name,
input=json.dumps(tool_call.tool_args),
input=tool_call.tool_args,
)
for tool_call in chat_content.tool_calls or ()
],
@ -124,6 +131,66 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
raise ValueError(f"Unexpected content type: {type(chat_content)}")
async def _transform_stream(
    result: AsyncStream[MessageStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the response stream into HA format.

    A typical stream of responses might look something like the following:
    - RawMessageStartEvent with no content
    - RawContentBlockStartEvent with an empty TextBlock
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - ...
    - RawContentBlockStopEvent
    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - ...
    - RawContentBlockStopEvent
    - RawMessageDeltaEvent with a stop_reason='tool_use'
    - RawMessageStopEvent(type='message_stop')

    Text is forwarded as it arrives: a ``{"role": "assistant"}`` delta when a
    text block starts, then one ``{"content": ...}`` delta per TextDelta.
    Tool calls are buffered: the partial JSON fragments are accumulated and a
    single ``{"tool_calls": [...]}`` delta is yielded when the block stops.
    Event types not listed in the branches below are ignored.

    Raises:
        TypeError: If ``result`` is None.
        ValueError: If an InputJSONDelta arrives while no tool-use block is
            open.
        json.JSONDecodeError: If the accumulated tool input is non-empty but
            not valid JSON.
    """
    if result is None:
        raise TypeError("Expected a stream of messages")

    # Identity of the tool-use block currently being streamed, if any.
    current_tool_call: dict[str, str] | None = None
    # JSON fragments of the open tool-use block; joined once at block stop.
    tool_input_parts: list[str] = []

    async for response in result:
        LOGGER.debug("Received response: %s", response)

        if isinstance(response, RawContentBlockStartEvent):
            if isinstance(response.content_block, ToolUseBlock):
                current_tool_call = {
                    "id": response.content_block.id,
                    "name": response.content_block.name,
                }
                tool_input_parts = []
            elif isinstance(response.content_block, TextBlock):
                yield {"role": "assistant"}
        elif isinstance(response, RawContentBlockDeltaEvent):
            if isinstance(response.delta, InputJSONDelta):
                if current_tool_call is None:
                    raise ValueError("Unexpected delta without a tool call")
                tool_input_parts.append(response.delta.partial_json)
            elif isinstance(response.delta, TextDelta):
                LOGGER.debug("yielding delta: %s", response.delta.text)
                yield {"content": response.delta.text}
        elif isinstance(response, RawContentBlockStopEvent):
            if current_tool_call is not None:
                raw_tool_input = "".join(tool_input_parts)
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=current_tool_call["id"],
                            tool_name=current_tool_call["name"],
                            # A tool-use block may close without any non-empty
                            # JSON fragment (e.g. a tool called with no
                            # arguments); json.loads("") would raise, so treat
                            # an empty buffer as "no arguments".
                            tool_args=(
                                json.loads(raw_tool_input) if raw_tool_input else {}
                            ),
                        )
                    ]
                }
                current_tool_call = None
                tool_input_parts = []
class AnthropicConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -206,58 +273,30 @@ class AnthropicConversationEntity(
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response = await client.messages.create(
stream = await client.messages.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
system=system.content,
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
stream=True,
)
except anthropic.AnthropicError as err:
raise HomeAssistantError(
f"Sorry, I had a problem talking to Anthropic: {err}"
) from err
LOGGER.debug("Response %s", response)
messages.append(_message_convert(response))
text = "".join(
messages.extend(
[
content.text
for content in response.content
if isinstance(content, TextBlock)
_convert_content(content)
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_stream(stream)
)
]
)
tool_inputs = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.name,
tool_args=cast(dict[str, Any], tool_call.input),
)
for tool_call in response.content
if isinstance(tool_call, ToolUseBlock)
]
tool_results = [
ToolResultBlockParam(
type="tool_result",
tool_use_id=tool_response.tool_call_id,
content=json.dumps(tool_response.tool_result),
)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=text,
tool_calls=tool_inputs or None,
)
)
]
if tool_results:
messages.append(MessageParam(role="user", content=tool_results))
if not tool_inputs:
if not chat_log.unresponded_tool_results:
break
response_content = chat_log.content[-1]

View File

@ -1,9 +1,24 @@
"""Tests for the Anthropic integration."""
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from anthropic import RateLimitError
from anthropic.types import Message, TextBlock, ToolUseBlock, Usage
from anthropic.types import (
InputJSONDelta,
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
TextBlock,
TextDelta,
ToolUseBlock,
Usage,
)
from freezegun import freeze_time
from httpx import URL, Request, Response
from syrupy.assertion import SnapshotAssertion
@ -20,6 +35,81 @@ from homeassistant.util import ulid as ulid_util
from tests.common import MockConfigEntry
async def stream_generator(
    responses: list[RawMessageStreamEvent],
) -> AsyncGenerator[RawMessageStreamEvent]:
    """Yield the given stream events one at a time, in their original order."""
    for event in responses:
        yield event
def create_messages(
    content_blocks: list[RawMessageStreamEvent],
) -> list[RawMessageStreamEvent]:
    """Wrap content block events in a message_start / message_stop envelope.

    The returned list starts with a RawMessageStartEvent carrying an empty
    assistant message, followed by ``content_blocks`` in order, and ends with
    a RawMessageStopEvent.
    """
    empty_message = Message(
        type="message",
        id="msg_1234567890ABCDEFGHIJKLMN",
        content=[],
        role="assistant",
        model="claude-3-5-sonnet-20240620",
        usage=Usage(input_tokens=0, output_tokens=0),
    )
    events: list[RawMessageStreamEvent] = [
        RawMessageStartEvent(message=empty_message, type="message_start")
    ]
    events.extend(content_blocks)
    events.append(RawMessageStopEvent(type="message_stop"))
    return events
def create_content_block(
    index: int, text_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events for one streamed text block.

    Produces a content_block_start event with an empty TextBlock, one
    content_block_delta event per entry in ``text_parts``, and a closing
    content_block_stop event, all tagged with ``index``.
    """
    events: list[RawMessageStreamEvent] = [
        RawContentBlockStartEvent(
            type="content_block_start",
            content_block=TextBlock(text="", type="text"),
            index=index,
        )
    ]
    for fragment in text_parts:
        text_delta = TextDelta(text=fragment, type="text_delta")
        events.append(
            RawContentBlockDeltaEvent(
                delta=text_delta,
                index=index,
                type="content_block_delta",
            )
        )
    events.append(RawContentBlockStopEvent(index=index, type="content_block_stop"))
    return events
def create_tool_use_block(
    index: int, tool_id: str, tool_name: str, json_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events for one streamed tool-use block.

    Produces a content_block_start event with a ToolUseBlock (empty input)
    naming the tool, one content_block_delta event with an InputJSONDelta per
    entry in ``json_parts``, and a closing content_block_stop event, all
    tagged with ``index``.
    """
    tool_block = ToolUseBlock(id=tool_id, name=tool_name, input={}, type="tool_use")
    events: list[RawMessageStreamEvent] = [
        RawContentBlockStartEvent(
            type="content_block_start",
            content_block=tool_block,
            index=index,
        )
    ]
    for fragment in json_parts:
        json_delta = InputJSONDelta(partial_json=fragment, type="input_json_delta")
        events.append(
            RawContentBlockDeltaEvent(
                delta=json_delta,
                index=index,
                type="content_block_delta",
            )
        )
    events.append(RawContentBlockStopEvent(index=index, type="content_block_stop"))
    return events
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@ -120,6 +210,13 @@ async def test_template_variables(
) as mock_create,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
mock_create.return_value = stream_generator(
create_messages(
create_content_block(
0, ["Okay, let", " me take care of that for you", "."]
)
)
)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
@ -129,6 +226,10 @@ async def test_template_variables(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
result
)
assert (
result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you."
)
assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"]
assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"]
@ -168,39 +269,26 @@ async def test_function_call(
for message in messages:
for content in message["content"]:
if not isinstance(content, str) and content["type"] == "tool_use":
return Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[
TextBlock(
type="text",
text="I have successfully called the function",
)
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
return stream_generator(
create_messages(
create_content_block(
0, ["I have ", "successfully called ", "the function"]
),
)
)
return Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[
TextBlock(type="text", text="Certainly, calling it now!"),
ToolUseBlock(
type="tool_use",
id="toolu_0123456789AbCdEfGhIjKlM",
name="test_tool",
input={"param1": "test_value"},
),
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
return stream_generator(
create_messages(
[
*create_content_block(0, ["Certainly, calling it now!"]),
*create_tool_use_block(
1,
"toolu_0123456789AbCdEfGhIjKlM",
"test_tool",
['{"para', 'm1": "test_valu', 'e"}'],
),
]
)
)
with (
@ -222,6 +310,10 @@ async def test_function_call(
assert "Today's date is 2024-06-03." in mock_create.mock_calls[1][2]["system"]
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I have successfully called the function"
)
assert mock_create.mock_calls[1][2]["messages"][2] == {
"role": "user",
"content": [
@ -275,39 +367,27 @@ async def test_function_exception(
for message in messages:
for content in message["content"]:
if not isinstance(content, str) and content["type"] == "tool_use":
return Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[
TextBlock(
type="text",
text="There was an error calling the function",
return stream_generator(
create_messages(
create_content_block(
0,
["There was an error calling the function"],
)
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
)
)
return Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[
TextBlock(type="text", text="Certainly, calling it now!"),
ToolUseBlock(
type="tool_use",
id="toolu_0123456789AbCdEfGhIjKlM",
name="test_tool",
input={"param1": "test_value"},
),
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
return stream_generator(
create_messages(
[
*create_content_block(0, "Certainly, calling it now!"),
*create_tool_use_block(
1,
"toolu_0123456789AbCdEfGhIjKlM",
"test_tool",
['{"param1": "test_value"}'],
),
]
)
)
with patch(
@ -324,6 +404,10 @@ async def test_function_exception(
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "There was an error calling the function"
)
assert mock_create.mock_calls[1][2]["messages"][2] == {
"role": "user",
"content": [
@ -376,15 +460,10 @@ async def test_assist_api_tools_conversion(
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
return_value=Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[TextBlock(type="text", text="Hello, how can I help you?")],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
return_value=stream_generator(
create_messages(
create_content_block(0, "Hello, how can I help you?"),
),
),
) as mock_create:
await conversation.async_converse(
@ -425,28 +504,45 @@ async def test_conversation_id(
mock_init_component,
) -> None:
"""Test conversation ID is honored."""
result = await conversation.async_converse(
hass, "hello", None, None, agent_id="conversation.claude"
)
conversation_id = result.conversation_id
def create_stream_generator(*args, **kwargs) -> Any:
return stream_generator(
create_messages(
create_content_block(0, "Hello, how can I help you?"),
),
)
result = await conversation.async_converse(
hass, "hello", conversation_id, None, agent_id="conversation.claude"
)
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
side_effect=create_stream_generator,
):
result = await conversation.async_converse(
hass, "hello", "1234", Context(), agent_id="conversation.claude"
)
assert result.conversation_id == conversation_id
result = await conversation.async_converse(
hass, "hello", None, None, agent_id="conversation.claude"
)
unknown_id = ulid_util.ulid()
conversation_id = result.conversation_id
result = await conversation.async_converse(
hass, "hello", unknown_id, None, agent_id="conversation.claude"
)
result = await conversation.async_converse(
hass, "hello", conversation_id, None, agent_id="conversation.claude"
)
assert result.conversation_id != unknown_id
assert result.conversation_id == conversation_id
result = await conversation.async_converse(
hass, "hello", "koala", None, agent_id="conversation.claude"
)
unknown_id = ulid_util.ulid()
assert result.conversation_id == "koala"
result = await conversation.async_converse(
hass, "hello", unknown_id, None, agent_id="conversation.claude"
)
assert result.conversation_id != unknown_id
result = await conversation.async_converse(
hass, "hello", "koala", None, agent_id="conversation.claude"
)
assert result.conversation_id == "koala"