Update anthropic to use the streaming API (#138256)
parent
117a71cb67
commit
da1e3c29ed
|
@ -1,16 +1,23 @@
|
|||
"""Conversation support for Anthropic."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic._types import NOT_GIVEN
|
||||
from anthropic.types import (
|
||||
InputJSONDelta,
|
||||
Message,
|
||||
MessageParam,
|
||||
MessageStreamEvent,
|
||||
RawContentBlockDeltaEvent,
|
||||
RawContentBlockStartEvent,
|
||||
RawContentBlockStopEvent,
|
||||
TextBlock,
|
||||
TextBlockParam,
|
||||
TextDelta,
|
||||
ToolParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUseBlock,
|
||||
|
@ -109,7 +116,7 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
|
|||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.tool_name,
|
||||
input=json.dumps(tool_call.tool_args),
|
||||
input=tool_call.tool_args,
|
||||
)
|
||||
for tool_call in chat_content.tool_calls or ()
|
||||
],
|
||||
|
@ -124,6 +131,66 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
|
|||
raise ValueError(f"Unexpected content type: {type(chat_content)}")
|
||||
|
||||
|
||||
async def _transform_stream(
    result: AsyncStream[MessageStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the Anthropic response stream into HA content deltas.

    A typical stream of responses might look something like the following:
    - RawMessageStartEvent with no content
    - RawContentBlockStartEvent with an empty TextBlock
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - ...
    - RawContentBlockStopEvent
    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - RawContentBlockDeltaEvent with an InputJSONDelta
    - ...
    - RawContentBlockStopEvent
    - RawMessageDeltaEvent with a stop_reason='tool_use'
    - RawMessageStopEvent(type='message_stop')

    Yields:
        ``{"role": "assistant"}`` when a text block starts,
        ``{"content": <text>}`` for every text delta, and
        ``{"tool_calls": [llm.ToolInput(...)]}`` when a tool use block stops.
        All other event types are ignored.

    Raises:
        TypeError: If ``result`` is None.
        ValueError: If an InputJSONDelta arrives while no tool use block is open.
        json.JSONDecodeError: If the accumulated, non-empty tool input is not
            valid JSON.
    """
    if result is None:
        raise TypeError("Expected a stream of messages")

    # Tool call being assembled; "input" collects the raw JSON fragments
    # until the enclosing content block is closed.
    current_tool_call: dict[str, str] | None = None

    async for response in result:
        LOGGER.debug("Received response: %s", response)

        if isinstance(response, RawContentBlockStartEvent):
            if isinstance(response.content_block, ToolUseBlock):
                current_tool_call = {
                    "id": response.content_block.id,
                    "name": response.content_block.name,
                    "input": "",
                }
            elif isinstance(response.content_block, TextBlock):
                yield {"role": "assistant"}
        elif isinstance(response, RawContentBlockDeltaEvent):
            if isinstance(response.delta, InputJSONDelta):
                if current_tool_call is None:
                    raise ValueError("Unexpected delta without a tool call")
                current_tool_call["input"] += response.delta.partial_json
            elif isinstance(response.delta, TextDelta):
                LOGGER.debug("yielding delta: %s", response.delta.text)
                yield {"content": response.delta.text}
        elif isinstance(response, RawContentBlockStopEvent):
            if current_tool_call is not None:
                # A tool invoked without arguments may stream only empty JSON
                # fragments; json.loads("") would raise, so treat an empty
                # buffer as an empty argument object.
                raw_args = current_tool_call["input"]
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=current_tool_call["id"],
                            tool_name=current_tool_call["name"],
                            tool_args=json.loads(raw_args) if raw_args else {},
                        )
                    ]
                }
                current_tool_call = None
|
||||
|
||||
|
||||
class AnthropicConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
|
@ -206,58 +273,30 @@ class AnthropicConversationEntity(
|
|||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
response = await client.messages.create(
|
||||
stream = await client.messages.create(
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
system=system.content,
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
stream=True,
|
||||
)
|
||||
except anthropic.AnthropicError as err:
|
||||
raise HomeAssistantError(
|
||||
f"Sorry, I had a problem talking to Anthropic: {err}"
|
||||
) from err
|
||||
|
||||
LOGGER.debug("Response %s", response)
|
||||
|
||||
messages.append(_message_convert(response))
|
||||
|
||||
text = "".join(
|
||||
messages.extend(
|
||||
[
|
||||
content.text
|
||||
for content in response.content
|
||||
if isinstance(content, TextBlock)
|
||||
_convert_content(content)
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_stream(stream)
|
||||
)
|
||||
]
|
||||
)
|
||||
tool_inputs = [
|
||||
llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool_args=cast(dict[str, Any], tool_call.input),
|
||||
)
|
||||
for tool_call in response.content
|
||||
if isinstance(tool_call, ToolUseBlock)
|
||||
]
|
||||
|
||||
tool_results = [
|
||||
ToolResultBlockParam(
|
||||
type="tool_result",
|
||||
tool_use_id=tool_response.tool_call_id,
|
||||
content=json.dumps(tool_response.tool_result),
|
||||
)
|
||||
async for tool_response in chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=text,
|
||||
tool_calls=tool_inputs or None,
|
||||
)
|
||||
)
|
||||
]
|
||||
if tool_results:
|
||||
messages.append(MessageParam(role="user", content=tool_results))
|
||||
|
||||
if not tool_inputs:
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
||||
response_content = chat_log.content[-1]
|
||||
|
|
|
@ -1,9 +1,24 @@
|
|||
"""Tests for the Anthropic integration."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from anthropic import RateLimitError
|
||||
from anthropic.types import Message, TextBlock, ToolUseBlock, Usage
|
||||
from anthropic.types import (
|
||||
InputJSONDelta,
|
||||
Message,
|
||||
RawContentBlockDeltaEvent,
|
||||
RawContentBlockStartEvent,
|
||||
RawContentBlockStopEvent,
|
||||
RawMessageStartEvent,
|
||||
RawMessageStopEvent,
|
||||
RawMessageStreamEvent,
|
||||
TextBlock,
|
||||
TextDelta,
|
||||
ToolUseBlock,
|
||||
Usage,
|
||||
)
|
||||
from freezegun import freeze_time
|
||||
from httpx import URL, Request, Response
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
@ -20,6 +35,81 @@ from homeassistant.util import ulid as ulid_util
|
|||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def stream_generator(
    responses: list[RawMessageStreamEvent],
) -> AsyncGenerator[RawMessageStreamEvent]:
    """Replay the given stream events, one at a time, as an async generator.

    Each entry of ``responses`` is yielded exactly once, in list order.
    """
    for event in responses:
        yield event
|
||||
|
||||
|
||||
def create_messages(
    content_blocks: list[RawMessageStreamEvent],
) -> list[RawMessageStreamEvent]:
    """Wrap content block events in a full message event sequence.

    The returned list opens with a ``message_start`` event carrying an empty
    assistant message, continues with ``content_blocks`` unchanged, and closes
    with a ``message_stop`` event.
    """
    opening_event = RawMessageStartEvent(
        type="message_start",
        message=Message(
            id="msg_1234567890ABCDEFGHIJKLMN",
            type="message",
            role="assistant",
            model="claude-3-5-sonnet-20240620",
            content=[],
            usage=Usage(input_tokens=0, output_tokens=0),
        ),
    )
    closing_event = RawMessageStopEvent(type="message_stop")
    return [opening_event, *content_blocks, closing_event]
|
||||
|
||||
|
||||
def create_content_block(
    index: int, text_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events of one streamed text content block.

    The sequence is a ``content_block_start`` event with an empty TextBlock,
    one ``content_block_delta`` event per entry of ``text_parts`` (each holding
    a TextDelta), and a final ``content_block_stop`` event. All events carry
    the given ``index``.
    """
    block_start = RawContentBlockStartEvent(
        type="content_block_start",
        content_block=TextBlock(text="", type="text"),
        index=index,
    )
    text_deltas = [
        RawContentBlockDeltaEvent(
            type="content_block_delta",
            index=index,
            delta=TextDelta(text=chunk, type="text_delta"),
        )
        for chunk in text_parts
    ]
    block_stop = RawContentBlockStopEvent(index=index, type="content_block_stop")
    return [block_start, *text_deltas, block_stop]
|
||||
|
||||
|
||||
def create_tool_use_block(
    index: int, tool_id: str, tool_name: str, json_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the events of one streamed tool use content block.

    The sequence is a ``content_block_start`` event with a ToolUseBlock naming
    the tool (its ``input`` left empty), one ``content_block_delta`` event per
    entry of ``json_parts`` (each holding an InputJSONDelta), and a final
    ``content_block_stop`` event. All events carry the given ``index``.
    """
    block_start = RawContentBlockStartEvent(
        type="content_block_start",
        content_block=ToolUseBlock(
            id=tool_id, name=tool_name, input={}, type="tool_use"
        ),
        index=index,
    )
    json_deltas = [
        RawContentBlockDeltaEvent(
            type="content_block_delta",
            index=index,
            delta=InputJSONDelta(partial_json=fragment, type="input_json_delta"),
        )
        for fragment in json_parts
    ]
    block_stop = RawContentBlockStopEvent(index=index, type="content_block_stop")
    return [block_start, *json_deltas, block_stop]
|
||||
|
||||
|
||||
async def test_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
@ -120,6 +210,13 @@ async def test_template_variables(
|
|||
) as mock_create,
|
||||
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||
):
|
||||
mock_create.return_value = stream_generator(
|
||||
create_messages(
|
||||
create_content_block(
|
||||
0, ["Okay, let", " me take care of that for you", "."]
|
||||
)
|
||||
)
|
||||
)
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(
|
||||
|
@ -129,6 +226,10 @@ async def test_template_variables(
|
|||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
|
||||
result
|
||||
)
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "Okay, let me take care of that for you."
|
||||
)
|
||||
assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"]
|
||||
assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"]
|
||||
|
||||
|
@ -168,39 +269,26 @@ async def test_function_call(
|
|||
for message in messages:
|
||||
for content in message["content"]:
|
||||
if not isinstance(content, str) and content["type"] == "tool_use":
|
||||
return Message(
|
||||
type="message",
|
||||
id="msg_1234567890ABCDEFGHIJKLMN",
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="I have successfully called the function",
|
||||
)
|
||||
],
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
role="assistant",
|
||||
stop_reason="end_turn",
|
||||
stop_sequence=None,
|
||||
usage=Usage(input_tokens=8, output_tokens=12),
|
||||
return stream_generator(
|
||||
create_messages(
|
||||
create_content_block(
|
||||
0, ["I have ", "successfully called ", "the function"]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return Message(
|
||||
type="message",
|
||||
id="msg_1234567890ABCDEFGHIJKLMN",
|
||||
content=[
|
||||
TextBlock(type="text", text="Certainly, calling it now!"),
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
name="test_tool",
|
||||
input={"param1": "test_value"},
|
||||
),
|
||||
],
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
usage=Usage(input_tokens=8, output_tokens=12),
|
||||
return stream_generator(
|
||||
create_messages(
|
||||
[
|
||||
*create_content_block(0, ["Certainly, calling it now!"]),
|
||||
*create_tool_use_block(
|
||||
1,
|
||||
"toolu_0123456789AbCdEfGhIjKlM",
|
||||
"test_tool",
|
||||
['{"para', 'm1": "test_valu', 'e"}'],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
|
@ -222,6 +310,10 @@ async def test_function_call(
|
|||
assert "Today's date is 2024-06-03." in mock_create.mock_calls[1][2]["system"]
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "I have successfully called the function"
|
||||
)
|
||||
assert mock_create.mock_calls[1][2]["messages"][2] == {
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -275,39 +367,27 @@ async def test_function_exception(
|
|||
for message in messages:
|
||||
for content in message["content"]:
|
||||
if not isinstance(content, str) and content["type"] == "tool_use":
|
||||
return Message(
|
||||
type="message",
|
||||
id="msg_1234567890ABCDEFGHIJKLMN",
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="There was an error calling the function",
|
||||
return stream_generator(
|
||||
create_messages(
|
||||
create_content_block(
|
||||
0,
|
||||
["There was an error calling the function"],
|
||||
)
|
||||
],
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
role="assistant",
|
||||
stop_reason="end_turn",
|
||||
stop_sequence=None,
|
||||
usage=Usage(input_tokens=8, output_tokens=12),
|
||||
)
|
||||
)
|
||||
|
||||
return Message(
|
||||
type="message",
|
||||
id="msg_1234567890ABCDEFGHIJKLMN",
|
||||
content=[
|
||||
TextBlock(type="text", text="Certainly, calling it now!"),
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
name="test_tool",
|
||||
input={"param1": "test_value"},
|
||||
),
|
||||
],
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
usage=Usage(input_tokens=8, output_tokens=12),
|
||||
return stream_generator(
|
||||
create_messages(
|
||||
[
|
||||
*create_content_block(0, "Certainly, calling it now!"),
|
||||
*create_tool_use_block(
|
||||
1,
|
||||
"toolu_0123456789AbCdEfGhIjKlM",
|
||||
"test_tool",
|
||||
['{"param1": "test_value"}'],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
|
@ -324,6 +404,10 @@ async def test_function_exception(
|
|||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "There was an error calling the function"
|
||||
)
|
||||
assert mock_create.mock_calls[1][2]["messages"][2] == {
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -376,15 +460,10 @@ async def test_assist_api_tools_conversion(
|
|||
with patch(
|
||||
"anthropic.resources.messages.AsyncMessages.create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Message(
|
||||
type="message",
|
||||
id="msg_1234567890ABCDEFGHIJKLMN",
|
||||
content=[TextBlock(type="text", text="Hello, how can I help you?")],
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
role="assistant",
|
||||
stop_reason="end_turn",
|
||||
stop_sequence=None,
|
||||
usage=Usage(input_tokens=8, output_tokens=12),
|
||||
return_value=stream_generator(
|
||||
create_messages(
|
||||
create_content_block(0, "Hello, how can I help you?"),
|
||||
),
|
||||
),
|
||||
) as mock_create:
|
||||
await conversation.async_converse(
|
||||
|
@ -425,28 +504,45 @@ async def test_conversation_id(
|
|||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test conversation ID is honored."""
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, None, agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
conversation_id = result.conversation_id
|
||||
def create_stream_generator(*args, **kwargs) -> Any:
|
||||
return stream_generator(
|
||||
create_messages(
|
||||
create_content_block(0, "Hello, how can I help you?"),
|
||||
),
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", conversation_id, None, agent_id="conversation.claude"
|
||||
)
|
||||
with patch(
|
||||
"anthropic.resources.messages.AsyncMessages.create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=create_stream_generator,
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", "1234", Context(), agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
assert result.conversation_id == conversation_id
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, None, agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
unknown_id = ulid_util.ulid()
|
||||
conversation_id = result.conversation_id
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", unknown_id, None, agent_id="conversation.claude"
|
||||
)
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", conversation_id, None, agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
assert result.conversation_id != unknown_id
|
||||
assert result.conversation_id == conversation_id
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", "koala", None, agent_id="conversation.claude"
|
||||
)
|
||||
unknown_id = ulid_util.ulid()
|
||||
|
||||
assert result.conversation_id == "koala"
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", unknown_id, None, agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
assert result.conversation_id != unknown_id
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", "koala", None, agent_id="conversation.claude"
|
||||
)
|
||||
|
||||
assert result.conversation_id == "koala"
|
||||
|
|
Loading…
Reference in New Issue