Add conversation agent debug tracing (#118124)

* Add debug tracing for conversation agents

* Minor cleanup
pull/118136/head
Allen Porter 2024-05-25 11:16:51 -07:00 committed by GitHub
parent 2f16c3aa80
commit 89e2c57da6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 294 additions and 9 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
import logging
from typing import Any
@ -20,6 +21,11 @@ from .models import (
ConversationInput,
ConversationResult,
)
from .trace import (
ConversationTraceEvent,
ConversationTraceEventType,
async_conversation_trace,
)
_LOGGER = logging.getLogger(__name__)
@ -84,15 +90,23 @@ async def async_converse(
language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
return await method(
ConversationInput(
conversation_input = ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
with async_conversation_trace() as trace:
trace.add_event(
ConversationTraceEvent(
ConversationTraceEventType.ASYNC_PROCESS,
dataclasses.asdict(conversation_input),
)
)
result = await method(conversation_input)
trace.set_result(**result.as_dict())
return result
class AgentManager:

View File

@ -0,0 +1,118 @@
"""Debug traces for conversation."""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field
import enum
from typing import Any
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
# Maximum number of recent conversation traces kept in memory for debugging.
STORED_TRACES = 3
class ConversationTraceEventType(enum.StrEnum):
"""Type of an event emitted during a conversation."""
ASYNC_PROCESS = "async_process"
"""The conversation is started from user input."""
AGENT_DETAIL = "agent_detail"
"""Event detail added by a conversation agent."""
LLM_TOOL_CALL = "llm_tool_call"
"""An LLM Tool call"""
@dataclass(frozen=True)
class ConversationTraceEvent:
    """Event emitted during a conversation."""

    # Kind of event (see ConversationTraceEventType).
    event_type: ConversationTraceEventType
    # Optional event payload, e.g. agent input or tool call arguments.
    data: dict[str, Any] | None = None
    # ISO-8601 UTC time the event was created, captured at construction.
    timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
class ConversationTrace:
    """Stores debug data related to a conversation.

    A trace accumulates events, an optional error, and the final
    conversation result, and can be serialized with ``as_dict``.
    """

    def __init__(self) -> None:
        """Initialize ConversationTrace with a unique id and empty state."""
        self._trace_id = ulid_util.ulid_now()
        self._events: list[ConversationTraceEvent] = []
        self._error: Exception | None = None
        self._result: dict[str, Any] = {}

    @property
    def trace_id(self) -> str:
        """Identifier for this trace."""
        return self._trace_id

    def add_event(self, event: "ConversationTraceEvent") -> None:
        """Add an event to the trace."""
        self._events.append(event)

    def set_error(self, ex: Exception) -> None:
        """Record the exception that aborted the conversation."""
        self._error = ex

    def set_result(self, **kwargs: Any) -> None:
        """Record the conversation result (stored as a shallow copy)."""
        self._result = {**kwargs}

    def as_dict(self) -> dict[str, Any]:
        """Return dictionary version of this ConversationTrace."""
        result: dict[str, Any] = {
            "id": self._trace_id,
            "events": [asdict(event) for event in self._events],
        }
        if self._error is not None:
            # Fall back to the exception class name when there is no message.
            result["error"] = str(self._error) or self._error.__class__.__name__
        # _result is always a dict (never None), so the previous
        # `is not None` guard was dead code; include it unconditionally.
        result["result"] = self._result
        return result
# Trace for the conversation currently being processed, if any. A ContextVar
# keeps concurrently running conversations isolated from each other.
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
    "current_trace", default=None
)
# Recently completed (or in-flight) traces keyed by trace id, bounded to
# STORED_TRACES entries by LimitedSizeDict.
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
    size_limit=STORED_TRACES
)
def async_conversation_trace_append(
    event_type: ConversationTraceEventType, event_data: dict[str, Any]
) -> None:
    """Append a ConversationTraceEvent to the current active trace.

    No-op when no trace is active in the current context.
    """
    if (trace := _current_trace.get()) is not None:
        trace.add_event(ConversationTraceEvent(event_type, event_data))
@contextmanager
def async_conversation_trace() -> Generator[ConversationTrace, None]:
    """Create a new active ConversationTrace.

    Makes the new trace the current one for the duration of the ``with``
    block and records any exception raised inside before re-raising it.
    """
    trace = ConversationTrace()
    token = _current_trace.set(trace)
    # Register immediately so the trace is retrievable even if the
    # conversation fails part-way through.
    _recent_traces[trace.trace_id] = trace
    try:
        yield trace
    except Exception as ex:
        # Capture the error on the trace, then let it propagate to the caller.
        trace.set_error(ex)
        raise
    finally:
        # Always restore the previously active trace (possibly None).
        _current_trace.reset(token)
def async_get_traces() -> list[ConversationTrace]:
    """Return a snapshot list of the most recently stored traces."""
    return [*_recent_traces.values()]
def async_clear_traces() -> None:
    """Clear all stored traces from the recent-traces store."""
    _recent_traces.clear()

View File

@ -12,6 +12,7 @@ import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
@ -250,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
messages[1] = {"role": "model", "parts": "Ok"}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
chat = model.start_chat(history=messages)
chat_request = user_input.text

View File

@ -9,6 +9,7 @@ from typing import Literal
import ollama
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
@ -138,6 +139,11 @@ class OllamaConversationEntity(
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": message_history.messages},
)
# Get response
try:
response = await client.chat(

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
@ -169,6 +170,9 @@ class OpenAIConversationEntity(
messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client = self.hass.data[DOMAIN][self.entry.entry_id]

View File

@ -3,12 +3,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any
import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@ -116,6 +120,10 @@ class API(ABC):
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.async_get_tools():
if tool.name == tool_input.tool_name:
break

View File

@ -2,7 +2,9 @@
from unittest.mock import patch
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now),
):
intent_response = intent.IntentResponse(language="en")
intent_response.async_set_speech("response text")
mock_process.return_value = conversation.ConversationResult(
response=intent_response,
)
await hass.services.async_call(
"conversation",
"process",

View File

@ -0,0 +1,80 @@
"""Test for the conversation traces."""
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@pytest.fixture
async def init_components(hass: HomeAssistant):
    """Initialize relevant components with empty configs."""
    for domain in ("homeassistant", "conversation", "intent"):
        assert await async_setup_component(hass, domain, {})
async def test_converation_trace(
    hass: HomeAssistant,
    init_components: None,
    sl_setup: None,
) -> None:
    """Test that a successful conversation produces a trace with a result."""
    await conversation.async_converse(
        hass, "add apples to my shopping list", None, Context()
    )

    all_traces = trace.async_get_traces()
    assert all_traces
    trace_data = all_traces[-1].as_dict()

    events = trace_data.get("events")
    assert events
    assert len(events) == 1
    first_event = events[0]
    assert (
        first_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
    )
    assert first_event.get("data")
    assert first_event["data"].get("text") == "add apples to my shopping list"

    result = trace_data.get("result")
    assert result
    speech = result.get("response", {}).get("speech", {}).get("plain", {})
    assert speech.get("speech") == "Added apples"
async def test_converation_trace_error(
    hass: HomeAssistant,
    init_components: None,
    sl_setup: None,
) -> None:
    """Test that an agent failure is recorded on the trace and re-raised."""
    agent_process = (
        "homeassistant.components.conversation.default_agent.DefaultAgent.async_process"
    )
    with (
        patch(agent_process, side_effect=HomeAssistantError("Failed to talk to agent")),
        pytest.raises(HomeAssistantError),
    ):
        await conversation.async_converse(
            hass, "add apples to my shopping list", None, Context()
        )

    all_traces = trace.async_get_traces()
    assert all_traces
    trace_data = all_traces[-1].as_dict()

    events = trace_data.get("events")
    assert events
    assert len(events) == 1
    assert (
        events[0].get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
    )
    assert trace_data.get("error") == "Failed to talk to agent"

View File

@ -9,6 +9,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -285,6 +286,20 @@ async def test_function_call(
),
)
# Test conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"

View File

@ -6,6 +6,7 @@ from ollama import Message, ResponseError
import pytest
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
@ -110,6 +111,19 @@ async def test_chat(
), result
assert result.response.speech["plain"]["speech"] == "test response"
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_message_history_trimming(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component

View File

@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -200,6 +201,20 @@ async def test_function_call(
),
)
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
@patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"