Improve conversation agent tracing to help with eval and data collection (#122542)
parent 4f5eab4646
commit 8d0e998e54
@@ -47,6 +47,7 @@ from homeassistant.util.json import JsonObjectType, json_loads_object
 from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
 from .entity import ConversationEntity
 from .models import ConversationInput, ConversationResult
+from .trace import ConversationTraceEventType, async_conversation_trace_append

 _LOGGER = logging.getLogger(__name__)
 _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@@ -348,6 +349,16 @@ class DefaultAgent(ConversationEntity):
             }
             for entity in result.entities_list
         }
+        async_conversation_trace_append(
+            ConversationTraceEventType.TOOL_CALL,
+            {
+                "intent_name": result.intent.name,
+                "slots": {
+                    entity.name: entity.value or entity.text
+                    for entity in result.entities_list
+                },
+            },
+        )

         try:
             intent_response = await intent.async_handle(
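
For context (not part of the diff): with the hunk above, every intent the default agent handles is now recorded in the conversation trace as a TOOL_CALL event. A sketch of roughly what the recorded data looks like for the shopping-list command exercised by the test further down, with values taken from that test:

# Sketch only: approximate TOOL_CALL event data appended by the default agent.
# The values mirror the expectations in test_converation_trace below.
recorded_event_data = {
    "intent_name": "HassListAddItem",
    "slots": {"name": "Shopping List", "item": "apples "},
}
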
@@ -22,8 +22,8 @@ class ConversationTraceEventType(enum.StrEnum):
     AGENT_DETAIL = "agent_detail"
     """Event detail added by a conversation agent."""

-    LLM_TOOL_CALL = "llm_tool_call"
-    """An LLM Tool call"""
+    TOOL_CALL = "tool_call"
+    """A conversation agent Tool call or default agent intent call."""


 @dataclass(frozen=True)
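
Since the aim of this change is eval and data collection, a consumer can now pull both default-agent intent calls and LLM tool calls out of a trace under the single TOOL_CALL type. A minimal sketch, assuming ConversationTraceEventType is importable from homeassistant.components.conversation.trace (as the relative import in the first hunk suggests) and that a trace dict has the "events" list shape used by the tests below:

from homeassistant.components.conversation.trace import ConversationTraceEventType


def extract_tool_calls(trace_dict: dict) -> list[dict]:
    """Return the data payload of every TOOL_CALL event in one trace dict."""
    return [
        event["data"]
        for event in trace_dict.get("events", [])
        if event["event_type"] == ConversationTraceEventType.TOOL_CALL
    ]
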
@@ -286,6 +286,7 @@ class GoogleGenerativeAIConversationEntity(
                 if supports_system_instruction
                 else messages[2:],
                 "prompt": prompt,
+                "tools": [*llm_api.tools] if llm_api else None,
             },
         )

@@ -225,7 +225,8 @@ class OpenAIConversationEntity(
         LOGGER.debug("Prompt: %s", messages)
         LOGGER.debug("Tools: %s", tools)
         trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
+            trace.ConversationTraceEventType.AGENT_DETAIL,
+            {"messages": messages, "tools": llm_api.tools if llm_api else None},
         )

         client = self.entry.runtime_data
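
The Google Generative AI and OpenAI hunks above extend the AGENT_DETAIL event so the trace captures the available tools next to the prompt. Illustration only: the keys follow the two hunks, the values are invented:

# Rough shape of the AGENT_DETAIL event data after this change.
google_agent_detail = {
    "messages": [],  # chat history handed to the model
    "prompt": "You are a voice assistant ...",
    "tools": None,  # or the list from llm_api.tools when an LLM API is attached
}
openai_agent_detail = {
    "messages": [{"role": "system", "content": "You are a voice assistant ..."}],
    "tools": None,
}
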
@@ -167,7 +167,7 @@ class APIInstance:
     async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
         """Call a LLM tool, validate args and return the response."""
         async_conversation_trace_append(
-            ConversationTraceEventType.LLM_TOOL_CALL,
+            ConversationTraceEventType.TOOL_CALL,
             {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
         )

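
After the rename, default-agent intent calls and LLM tool calls share ConversationTraceEventType.TOOL_CALL, but their payloads keep different keys, so a trace consumer can still tell them apart. A small sketch, not part of the diff:

def describe_tool_call(data: dict) -> str:
    """Summarize a TOOL_CALL event payload from a conversation trace."""
    if "intent_name" in data:
        # Default agent payload: {"intent_name": ..., "slots": {...}}
        return f"intent {data['intent_name']} slots={data['slots']}"
    # LLM tool call via APIInstance.async_call_tool: {"tool_name": ..., "tool_args": {...}}
    return f"tool {data['tool_name']} args={data['tool_args']}"
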
@@ -33,7 +33,7 @@ async def test_converation_trace(
     assert traces
     last_trace = traces[-1].as_dict()
     assert last_trace.get("events")
-    assert len(last_trace.get("events")) == 1
+    assert len(last_trace.get("events")) == 2
     trace_event = last_trace["events"][0]
     assert (
         trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
@@ -50,6 +50,16 @@ async def test_converation_trace(
         == "Added apples"
     )

+    trace_event = last_trace["events"][1]
+    assert trace_event.get("event_type") == trace.ConversationTraceEventType.TOOL_CALL
+    assert trace_event.get("data") == {
+        "intent_name": "HassListAddItem",
+        "slots": {
+            "name": "Shopping List",
+            "item": "apples ",
+        },
+    }
+

 async def test_converation_trace_error(
     hass: HomeAssistant,
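
The test hunks do not show where traces comes from; presumably it is fetched from the conversation trace module before these assertions, along these lines (helper name assumed, not confirmed by this diff):

# Assumed setup for the assertions above; async_get_traces is not shown in
# this diff and its exact name may differ.
from homeassistant.components.conversation import trace

traces = trace.async_get_traces()
last_trace = traces[-1].as_dict()
tool_call_events = [
    e for e in last_trace["events"]
    if e["event_type"] == trace.ConversationTraceEventType.TOOL_CALL
]
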
@@ -269,11 +269,12 @@ async def test_function_call(
     assert [event["event_type"] for event in trace_events] == [
         trace.ConversationTraceEventType.ASYNC_PROCESS,
         trace.ConversationTraceEventType.AGENT_DETAIL,
-        trace.ConversationTraceEventType.LLM_TOOL_CALL,
+        trace.ConversationTraceEventType.TOOL_CALL,
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
     assert "Answer in plain text" in detail_event["data"]["prompt"]
+    assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]


 @patch(
@@ -294,7 +294,7 @@ async def test_function_call(
     assert [event["event_type"] for event in trace_events] == [
         trace.ConversationTraceEventType.ASYNC_PROCESS,
         trace.ConversationTraceEventType.AGENT_DETAIL,
-        trace.ConversationTraceEventType.LLM_TOOL_CALL,
+        trace.ConversationTraceEventType.TOOL_CALL,
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
@@ -303,6 +303,7 @@ async def test_function_call(
         "Today's date is 2024-06-03."
         in trace_events[1]["data"]["messages"][0]["content"]
     )
+    assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]

     # Call it again, make sure we have updated prompt
     with (