Improve conversation agent tracing to help with eval and data collection (#122542)

pull/122936/head
Allen Porter 2024-07-31 05:38:44 -07:00 committed by GitHub
parent 4f5eab4646
commit 8d0e998e54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 32 additions and 7 deletions

View File

@ -47,6 +47,7 @@ from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@ -348,6 +349,16 @@ class DefaultAgent(ConversationEntity):
}
for entity in result.entities_list
}
async_conversation_trace_append(
ConversationTraceEventType.TOOL_CALL,
{
"intent_name": result.intent.name,
"slots": {
entity.name: entity.value or entity.text
for entity in result.entities_list
},
},
)
try:
intent_response = await intent.async_handle(

View File

@ -22,8 +22,8 @@ class ConversationTraceEventType(enum.StrEnum):
AGENT_DETAIL = "agent_detail"
"""Event detail added by a conversation agent."""
LLM_TOOL_CALL = "llm_tool_call"
"""An LLM Tool call"""
TOOL_CALL = "tool_call"
"""A conversation agent Tool call or default agent intent call."""
@dataclass(frozen=True)

View File

@ -286,6 +286,7 @@ class GoogleGenerativeAIConversationEntity(
if supports_system_instruction
else messages[2:],
"prompt": prompt,
"tools": [*llm_api.tools] if llm_api else None,
},
)

View File

@ -225,7 +225,8 @@ class OpenAIConversationEntity(
LOGGER.debug("Prompt: %s", messages)
LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": messages, "tools": llm_api.tools if llm_api else None},
)
client = self.entry.runtime_data

View File

@ -167,7 +167,7 @@ class APIInstance:
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL,
ConversationTraceEventType.TOOL_CALL,
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
)

View File

@ -33,7 +33,7 @@ async def test_converation_trace(
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
assert len(last_trace.get("events")) == 2
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
@ -50,6 +50,16 @@ async def test_converation_trace(
== "Added apples"
)
trace_event = last_trace["events"][1]
assert trace_event.get("event_type") == trace.ConversationTraceEventType.TOOL_CALL
assert trace_event.get("data") == {
"intent_name": "HassListAddItem",
"slots": {
"name": "Shopping List",
"item": "apples ",
},
}
async def test_converation_trace_error(
hass: HomeAssistant,

View File

@ -269,11 +269,12 @@ async def test_function_call(
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
trace.ConversationTraceEventType.TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["prompt"]
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
@patch(

View File

@ -294,7 +294,7 @@ async def test_function_call(
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
trace.ConversationTraceEventType.TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
@ -303,6 +303,7 @@ async def test_function_call(
"Today's date is 2024-06-03."
in trace_events[1]["data"]["messages"][0]["content"]
)
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
# Call it again, make sure we have updated prompt
with (