446 lines
14 KiB
Python
446 lines
14 KiB
Python
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
|
"""Conversation history compaction for managing LLM token budgets.
|
|
|
|
This module implements a compaction strategy to keep conversation history
|
|
within token limits. It classifies messages by importance and drops
|
|
lower-value messages first, while preserving tool call/result pairs
|
|
and recent conversation context.
|
|
|
|
Inspired by the approach described at:
|
|
https://www.pgedge.com/blog/lessons-learned-writing-an-mcp-server-for-postgresql
|
|
"""
|
|
|
|
import json
import re
from typing import Optional

from pgadmin.llm.models import Message, Role, ToolCall
|
# Token budget defaults
DEFAULT_MAX_TOKENS = 100000
DEFAULT_RECENT_WINDOW = 10

# Provider-specific characters-per-token ratios used by estimate_tokens().
# Unknown providers fall back to 4.0 chars/token at the call site.
CHARS_PER_TOKEN = {
    'anthropic': 3.8,
    'openai': 4.0,
    'ollama': 4.5,
    'docker': 4.0,
}

# SQL content is tokenized less efficiently, so SQL-looking text gets
# its estimate inflated by this factor.
SQL_TOKEN_MULTIPLIER = 1.2

# Overhead per message (role markers, formatting, etc.)
MESSAGE_OVERHEAD_TOKENS = 10

# Importance tiers assigned by _classify_message(); higher scores are
# dropped later during compaction.
CLASS_ANCHOR = 1.0  # Schema info, corrections - always keep
CLASS_IMPORTANT = 0.8  # Query analysis, errors, insights
CLASS_CONTEXTUAL = 0.6  # Detailed responses, tool results
CLASS_ROUTINE = 0.4  # Short responses, standard messages
CLASS_TRANSIENT = 0.1  # Acknowledgments, short phrases

# Patterns for classification (all case-insensitive).
# DDL / constraint vocabulary -> anchor-tier content.
_SCHEMA_PATTERNS = re.compile(
    r'\b(CREATE|ALTER|DROP)\s+(TABLE|INDEX|VIEW|SCHEMA)\b'
    r'|PRIMARY\s+KEY|FOREIGN\s+KEY|CONSTRAINT\b',
    re.IGNORECASE
)

# Query-analysis vocabulary (plans, timings) -> important-tier content.
_QUERY_PATTERNS = re.compile(
    r'\bEXPLAIN\s+ANALYZE\b|execution\s+time\b'
    r'|seq\s+scan\b|index\s+scan\b|query\s+plan\b',
    re.IGNORECASE
)

# Error vocabulary -> important-tier content.
_ERROR_PATTERNS = re.compile(
    r'\berror\b|\bfailed\b|\bsyntax\s+error\b'
    r'|\bpermission\s+denied\b|\bdoes\s+not\s+exist\b',
    re.IGNORECASE
)
|
|
def estimate_tokens(text: str, provider: str = 'openai') -> int:
    """Estimate how many tokens *text* will consume.

    Uses provider-specific character-per-token ratios and applies
    a multiplier for SQL-heavy content.

    Args:
        text: The text to estimate tokens for.
        provider: The LLM provider name for ratio selection.

    Returns:
        Estimated token count.
    """
    # Empty text carries no per-message overhead either.
    if not text:
        return 0

    ratio = CHARS_PER_TOKEN.get(provider, 4.0)
    estimate = len(text) / ratio

    # SQL keywords suggest content that tokenizes less efficiently,
    # so scale the estimate up.
    looks_like_sql = re.search(
        r'\b(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER)\b',
        text, re.IGNORECASE
    )
    if looks_like_sql:
        estimate *= SQL_TOKEN_MULTIPLIER

    return int(estimate) + MESSAGE_OVERHEAD_TOKENS
|
def estimate_message_tokens(message: Message, provider: str = 'openai') -> int:
    """Estimate token count for a single Message object.

    Counts the message content, the serialized tool-call arguments and
    names, and the content of any tool results.

    Args:
        message: The Message to estimate.
        provider: The LLM provider name.

    Returns:
        Estimated token count.
    """
    total = estimate_tokens(message.content, provider)

    # Tool call arguments travel to the provider as JSON, so estimate
    # against the serialized form. (The json import is now module-level
    # instead of being re-executed inside this loop.)
    for tc in message.tool_calls:
        total += estimate_tokens(json.dumps(tc.arguments), provider)
        total += estimate_tokens(tc.name, provider)

    # Tool results are plain text content.
    for tr in message.tool_results:
        total += estimate_tokens(tr.content, provider)

    return total
|
def estimate_history_tokens(
    messages: list[Message], provider: str = 'openai'
) -> int:
    """Estimate the total token footprint of a conversation history.

    Args:
        messages: List of Message objects.
        provider: The LLM provider name.

    Returns:
        Estimated total token count.
    """
    total = 0
    for message in messages:
        total += estimate_message_tokens(message, provider)
    return total
|
def _classify_message(message: Message) -> float:
    """Score a message's importance for compaction decisions.

    Args:
        message: The message to classify.

    Returns:
        Importance score from 0.0 to 1.0; higher-scored messages are
        dropped later during compaction.
    """
    body = message.content or ''

    # Tool messages are scored from their result payloads, checked in
    # priority order per result: schema details anchor the conversation,
    # errors are important, and large payloads carry context.
    if message.role == Role.TOOL:
        for result in message.tool_results:
            payload = result.content
            if _SCHEMA_PATTERNS.search(payload):
                return CLASS_ANCHOR
            if _ERROR_PATTERNS.search(payload):
                return CLASS_IMPORTANT
            if len(payload) > 500:
                return CLASS_CONTEXTUAL
        return CLASS_ROUTINE

    # Assistant turns that requested tools are important (they reference
    # tools and pair with result messages).
    if message.role == Role.ASSISTANT and message.tool_calls:
        return CLASS_IMPORTANT

    # Content-based classification, highest tier first.
    for pattern, score in (
        (_SCHEMA_PATTERNS, CLASS_ANCHOR),
        (_ERROR_PATTERNS, CLASS_IMPORTANT),
        (_QUERY_PATTERNS, CLASS_IMPORTANT),
    ):
        if pattern.search(body):
            return score

    # Fall back on length: very short messages are disposable chatter,
    # mid-sized ones routine, anything longer is contextual.
    if len(body) < 30:
        return CLASS_TRANSIENT
    if len(body) < 100:
        return CLASS_ROUTINE
    return CLASS_CONTEXTUAL
|
def _find_tool_pair_indices(
    messages: list[Message]
) -> dict[int, frozenset[int]]:
    """Find indices of tool_call/tool_result groups that must stay together.

    An assistant message may contain multiple tool_calls, each with a
    corresponding tool result message. All messages in such a group
    must be dropped or kept together.

    Returns a mapping where every index in a group maps to the full
    set of indices in that group.

    Args:
        messages: The message list.

    Returns:
        Dict mapping index -> frozenset of all indices in the group.
    """
    groups: dict[int, frozenset[int]] = {}

    for idx, msg in enumerate(messages):
        # Only assistant messages that actually issued tool calls start
        # a group; everything else is skipped.
        if msg.role != Role.ASSISTANT or not msg.tool_calls:
            continue

        call_ids = {call.id for call in msg.tool_calls}
        members = {idx}

        # Scan forward for tool messages carrying results for any of
        # this assistant message's call ids.
        for later in range(idx + 1, len(messages)):
            other = messages[later]
            if other.role != Role.TOOL:
                continue
            if any(res.tool_call_id in call_ids
                   for res in other.tool_results):
                members.add(later)

        group = frozenset(members)
        for member in group:
            groups[member] = group

    return groups
|
def compact_history(
    messages: list[Message],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    recent_window: int = DEFAULT_RECENT_WINDOW,
    provider: str = 'openai'
) -> list[Message]:
    """Compact conversation history to fit within a token budget.

    Strategy:
    1. Always keep the first message (provides original context)
    2. Always keep the last `recent_window` messages
    3. Among remaining messages, classify by importance and drop
       lowest-value messages first
    4. Keep tool_call/tool_result pairs together

    Args:
        messages: Full conversation history.
        max_tokens: Maximum token budget for the history.
        recent_window: Number of recent messages to always preserve.
        provider: LLM provider name for token estimation.

    Returns:
        Compacted list of messages that fits within the token budget.
    """
    if not messages:
        return messages

    # Nothing to do if we're already within budget.
    current_tokens = estimate_history_tokens(messages, provider)
    if current_tokens <= max_tokens:
        return messages

    total = len(messages)

    # Protected indices: the first message plus the recent window.
    protected = {0} | set(range(max(1, total - recent_window), total))

    # If protected messages alone exceed the budget, shrink the
    # recent window until we have room for compaction candidates.
    while recent_window > 0:
        protected_tokens = sum(
            estimate_message_tokens(messages[i], provider)
            for i in protected
        )
        if protected_tokens <= max_tokens:
            break
        recent_window -= 1
        protected = {0} | set(range(max(1, total - recent_window), total))

    tool_groups = _find_tool_pair_indices(messages)

    # Expand protected set to include entire tool groups so we never
    # split a tool-use turn (leaving orphaned call or result messages).
    for i in list(protected):
        protected |= set(tool_groups.get(i, ()))

    # Classify all non-protected messages, lowest importance first
    # (those get dropped first).
    candidates = [
        (i, _classify_message(messages[i]))
        for i in range(total)
        if i not in protected
    ]
    candidates.sort(key=lambda item: item[1])

    dropped: set[int] = set()

    def _drop(idx: int) -> int:
        """Drop message idx plus its unprotected tool-group partners.

        Returns the estimated number of tokens freed. Shared by both
        compaction passes below (previously duplicated inline).
        """
        freed = estimate_message_tokens(messages[idx], provider)
        dropped.add(idx)
        for partner in tool_groups.get(idx, ()):
            # Skip the message itself, protected partners, and anything
            # already dropped (guards against double-counting).
            if (partner != idx and partner not in protected
                    and partner not in dropped):
                freed += estimate_message_tokens(messages[partner], provider)
                dropped.add(partner)
        return freed

    # Pass 1: drop lowest-importance messages first, sparing anchors.
    for idx, score in candidates:
        if current_tokens <= max_tokens:
            break
        if idx in dropped:
            continue
        # Don't drop anchor messages unless we absolutely must; the
        # list is sorted, so everything from here on is an anchor.
        if score >= CLASS_ANCHOR:
            break
        current_tokens -= _drop(idx)

    # Pass 2: still over budget — drop anchor messages too.
    if current_tokens > max_tokens:
        for idx, _score in candidates:
            if current_tokens <= max_tokens:
                break
            if idx in dropped:
                continue
            current_tokens -= _drop(idx)

    # Rebuild the compacted list preserving the original order.
    return [msg for i, msg in enumerate(messages) if i not in dropped]
|
def deserialize_history(
    history_data: list[dict]
) -> list[Message]:
    """Deserialize conversation history from JSON request data.

    Converts a list of message dictionaries (from the frontend) into
    Message objects suitable for passing to chat_with_database().

    Args:
        history_data: List of dicts with 'role' and 'content' keys,
            and optionally 'tool_calls' and 'tool_results'.

    Returns:
        List of Message objects. Items that are not dicts or carry an
        unknown role are silently skipped.
    """
    if not isinstance(history_data, list):
        return []

    # Hoisted out of the per-item loop: previously this import ran once
    # per message. Kept function-local to mirror the original scoping.
    from pgadmin.llm.models import ToolResult

    messages = []
    for item in history_data:
        if not isinstance(item, dict):
            continue

        role_str = item.get('role', '')
        content = item.get('content', '')

        try:
            role = Role(role_str)
        except ValueError:
            continue  # Skip unknown roles

        # Reconstruct tool calls if present, ignoring malformed entries.
        tool_calls = [
            ToolCall(
                id=tc_data.get('id', ''),
                name=tc_data.get('name', ''),
                arguments=tc_data.get('arguments', {})
            )
            for tc_data in item.get('tool_calls') or []
            if isinstance(tc_data, dict)
        ]

        # Reconstruct tool results if present, ignoring malformed entries.
        tool_results = [
            ToolResult(
                tool_call_id=tr_data.get('tool_call_id', ''),
                content=tr_data.get('content', ''),
                is_error=tr_data.get('is_error', False)
            )
            for tr_data in item.get('tool_results') or []
            if isinstance(tr_data, dict)
        ]

        messages.append(Message(
            role=role,
            content=content,
            tool_calls=tool_calls,
            tool_results=tool_results
        ))

    return messages
|
def filter_conversational(messages: list[Message]) -> list[Message]:
    """Filter history to only conversational messages for storage.

    Keeps user messages and final assistant responses (those without
    tool calls). Drops intermediate assistant messages that contain
    tool_use requests and all tool result messages, since these are
    internal to each turn and don't need to persist between turns.

    This dramatically reduces history size since tool results often
    contain large schema dumps and query results.

    Args:
        messages: Full message history including tool call internals.

    Returns:
        Filtered list with only user messages and final assistant
        responses.
    """
    # Keep user turns and assistant turns with no pending tool calls;
    # Role.TOOL messages and tool-calling assistant turns are dropped.
    return [
        msg for msg in messages
        if msg.role == Role.USER
        or (msg.role == Role.ASSISTANT and not msg.tool_calls)
    ]