292 lines
9.1 KiB
Python
292 lines
9.1 KiB
Python
##########################################################################
|
|
#
|
|
# pgAdmin 4 - PostgreSQL Tools
|
|
#
|
|
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
|
|
# This software is released under the PostgreSQL Licence
|
|
#
|
|
##########################################################################
|
|
|
|
"""High-level report generation functions using the pipeline."""
|
|
|
|
import json
from typing import Any, Callable, Generator, Optional

from flask import Response, stream_with_context
from flask_babel import gettext

from pgadmin.llm.client import get_llm_client, LLMClient
from pgadmin.llm.reports.pipeline import ReportPipeline
from pgadmin.llm.reports.queries import QUERIES
from pgadmin.llm.reports.sections import get_sections_for_scope
|
|
|
|
|
|
def create_query_executor(conn) -> Callable[[str, dict], dict[str, Any]]:
    """Create a query executor function for the pipeline.

    Args:
        conn: Database connection object exposing ``execute_dict``.

    Returns:
        A callable that executes queries by ID and returns a result dict
        (always contains a 'rows' list; may contain 'error' or 'note').
    """
    def executor(query_id: str, context: dict) -> dict[str, Any]:
        """Execute a query by ID.

        Args:
            query_id: The query identifier from the QUERIES registry.
            context: Execution context (may contain schema_id for filtering).

        Returns:
            Dictionary with query results.
        """
        query_def = QUERIES.get(query_id)
        if not query_def:
            return {'error': f'Unknown query: {query_id}', 'rows': []}

        sql = query_def['sql']

        # Queries may require a PostgreSQL extension; skip gracefully
        # (with a note) when the extension is not installed.
        required_ext = query_def.get('requires_extension')
        if required_ext:
            check_sql = """
            SELECT EXISTS (
                SELECT 1 FROM pg_extension WHERE extname = %s
            ) as available
            """
            status, result = conn.execute_dict(check_sql, [required_ext])
            # BUGFIX: the previous result.get('rows', [{}])[0] raised
            # IndexError when the key existed but rows was an empty list;
            # `or [{}]` also covers that case.
            rows = (result or {}).get('rows') or [{}]
            if not (status and rows[0].get('available', False)):
                return {
                    'note': f"Extension '{required_ext}' not installed",
                    'rows': []
                }

        # Schema-scoped queries carry a %s placeholder that is bound to
        # the schema id supplied via the context.
        schema_id = context.get('schema_id')
        if schema_id and '%s' in sql:
            status, result = conn.execute_dict(sql, [schema_id])
        else:
            status, result = conn.execute_dict(sql)

        if status and result:
            return {'rows': result.get('rows', [])}
        return {'error': 'Query failed', 'rows': []}

    return executor
|
|
|
|
|
|
def generate_report_streaming(
    report_type: str,
    scope: str,
    conn,
    manager,
    context: dict,
    client: Optional[LLMClient] = None
) -> Generator[bytes, None, None]:
    """Generate a report with streaming progress updates.

    Yields Server-Sent Events (SSE) formatted messages. Note the yield
    type is ``bytes`` — every value comes from ``_sse_event``, which
    returns UTF-8 encoded bytes (the old ``Generator[str, ...]``
    annotation was wrong).

    Args:
        report_type: One of 'security', 'performance', 'design'.
        scope: One of 'server', 'database', 'schema'.
        conn: Database connection.
        manager: Connection manager.
        context: Report context dict with keys like:
            - server_version
            - database_name
            - schema_name
            - schema_id (for schema-scoped reports)
        client: Optional LLM client (will create one if not provided).

    Yields:
        SSE-formatted events as UTF-8 encoded bytes.
    """
    # Get or create LLM client. The availability check is intentionally
    # only applied to a freshly created client, matching existing callers.
    if client is None:
        client = get_llm_client()
        if not client:
            yield _sse_event({
                'type': 'error',
                'message': gettext('Failed to initialize LLM client.')
            })
            return

    # Get sections for this report type and scope
    sections = get_sections_for_scope(report_type, scope)
    if not sections:
        yield _sse_event({
            'type': 'error',
            'message': gettext('No sections available for this report type.')
        })
        return

    # Add server version to context
    context['server_version'] = manager.ver

    # Create the pipeline
    query_executor = create_query_executor(conn)
    pipeline = ReportPipeline(
        report_type=report_type,
        sections=sections,
        client=client,
        query_executor=query_executor
    )

    # Execute pipeline and stream events
    try:
        for event in pipeline.execute_with_progress(context):
            if event.get('type') == 'complete':
                # Prepend the AI disclaimer to the final report text
                report = event.get('report', '')
                disclaimer = gettext(
                    '> **Note:** This report was generated by '
                    '%(provider)s / %(model)s. '
                    'AI systems can make mistakes. Please verify all findings '
                    'and recommendations before taking action.\n\n'
                ) % {
                    'provider': client.provider_name,
                    'model': client.model_name
                }
                event['report'] = disclaimer + report

            yield _sse_event(event)

    except Exception as e:
        # Boundary handler: convert any pipeline failure into an SSE
        # error event instead of breaking the stream mid-response.
        yield _sse_event({
            'type': 'error',
            'message': gettext('Failed to generate report: ') + str(e)
        })
|
|
|
|
|
|
def generate_report_sync(
    report_type: str,
    scope: str,
    conn,
    manager,
    context: dict,
    client: Optional[LLMClient] = None
) -> tuple[bool, str]:
    """Generate a report synchronously (non-streaming).

    Args:
        report_type: One of 'security', 'performance', 'design'.
        scope: One of 'server', 'database', 'schema'.
        conn: Database connection.
        manager: Connection manager.
        context: Report context dict.
        client: Optional LLM client.

    Returns:
        Tuple of (success, report_or_error_message).
    """
    # Only a freshly created client is availability-checked; a caller
    # that passes its own client is trusted to have validated it.
    if client is None:
        client = get_llm_client()
        if not client:
            return False, gettext('Failed to initialize LLM client.')

    # Resolve the section list; without sections there is nothing to run.
    sections = get_sections_for_scope(report_type, scope)
    if not sections:
        return False, gettext('No sections available for this report type.')

    # The pipeline needs the server version available in the context.
    context['server_version'] = manager.ver

    # Assemble the pipeline with a query executor bound to this connection.
    pipeline = ReportPipeline(
        report_type=report_type,
        sections=sections,
        client=client,
        query_executor=create_query_executor(conn)
    )

    try:
        body = pipeline.execute(context)

        # Prefix the generated report with the AI disclaimer.
        header = gettext(
            '> **Note:** This report was generated by '
            '%(provider)s / %(model)s. '
            'AI systems can make mistakes. Please verify all findings '
            'and recommendations before taking action.\n\n'
        ) % {
            'provider': client.provider_name,
            'model': client.model_name
        }

        return True, header + body

    except Exception as e:
        # Top-level boundary: report the failure as a (False, message) pair.
        return False, gettext('Failed to generate report: ') + str(e)
|
|
|
|
|
|
def _sse_event(data: dict) -> bytes:
|
|
"""Format data as an SSE event.
|
|
|
|
Args:
|
|
data: Event data dictionary.
|
|
|
|
Returns:
|
|
SSE-formatted bytes with padding to help flush buffers.
|
|
"""
|
|
# Add padding comment to help flush buffers in some WSGI servers
|
|
# Some servers buffer until a certain amount of data is received
|
|
json_data = json.dumps(data)
|
|
# Minimum 2KB total to help flush various buffer sizes
|
|
padding_needed = max(0, 2048 - len(json_data) - 20)
|
|
padding = f": {'.' * padding_needed}\n" if padding_needed > 0 else ""
|
|
return f"{padding}data: {json_data}\n\n".encode('utf-8')
|
|
|
|
|
|
def _wrap_generator_with_keepalive(generator: Generator) -> Generator:
|
|
"""Wrap a generator to add SSE keepalive and initial flush.
|
|
|
|
Args:
|
|
generator: Original event generator.
|
|
|
|
Yields:
|
|
SSE events (as bytes) with initial connection event.
|
|
"""
|
|
# Send initial comment to establish connection and flush headers
|
|
# The retry directive tells browser to reconnect after 3s if disconnected
|
|
yield b": SSE stream connected\nretry: 3000\n\n"
|
|
|
|
# Yield all events from the original generator
|
|
for event in generator:
|
|
yield event
|
|
|
|
|
|
def create_sse_response(generator: Generator) -> Response:
    """Create a Flask Response for SSE streaming.

    Args:
        generator: Generator that yields SSE event strings.

    Returns:
        Flask Response configured for SSE.
    """
    # Prepend the connection/retry preamble so the stream flushes early.
    stream = _wrap_generator_with_keepalive(generator)

    # Headers that defeat caching and intermediary buffering for SSE.
    sse_headers = {
        'Cache-Control': 'no-cache, no-store, must-revalidate',
        'Pragma': 'no-cache',
        'Expires': '0',
        'Connection': 'keep-alive',
        'X-Accel-Buffering': 'no',  # Disable nginx buffering
    }

    # stream_with_context keeps Flask's request context alive for the
    # generator's whole lifetime, which streaming responses require.
    response = Response(
        stream_with_context(stream),
        mimetype='text/event-stream',
        headers=sse_headers
    )
    # Werkzeug must not buffer the response body - critical for SSE to work.
    response.direct_passthrough = True
    return response
|