refactor(platform/rest): Update REST API to use standard FastAPI structure (#8519)
parent
e26513f5e4
commit
594aa996d7
|
@ -1,695 +1,117 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
from typing import Annotated, Any, Dict
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
import fastapi
|
||||
import fastapi.middleware.cors
|
||||
import fastapi.responses
|
||||
import uvicorn
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data import block, db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import get_block_costs, get_user_credit_model
|
||||
from backend.data.user import get_or_create_user, migrate_and_encrypt_user_integrations
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
from backend.server.model import CreateGraph, SetGraphActiveVersion
|
||||
from backend.util.service import AppService, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.user
|
||||
import backend.server.routers.v1
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
|
||||
from .utils import get_user_id
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = backend.util.settings.Settings()
|
||||
|
||||
|
||||
class AgentServer(AppService):
|
||||
_test_dependency_overrides = {}
|
||||
_user_credit_model = get_user_credit_model()
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.db.connect()
|
||||
await backend.data.block.initialize_blocks()
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
yield
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_redis = True
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return Config().agent_server_port
|
||||
docs_url = (
|
||||
"/docs"
|
||||
if settings.config.app_env == backend.util.settings.AppEnvironment.LOCAL
|
||||
else None
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, _: FastAPI):
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
await migrate_and_encrypt_user_integrations()
|
||||
yield
|
||||
await db.disconnect()
|
||||
app = fastapi.FastAPI(
|
||||
title="AutoGPT Agent Server",
|
||||
description=(
|
||||
"This server is used to execute agents that are created by the "
|
||||
"AutoGPT system."
|
||||
),
|
||||
summary="AutoGPT Agent Server",
|
||||
version="0.1",
|
||||
lifespan=lifespan_context,
|
||||
docs_url=docs_url,
|
||||
)
|
||||
|
||||
def run_service(self):
|
||||
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
app = FastAPI(
|
||||
title="AutoGPT Agent Server",
|
||||
description=(
|
||||
"This server is used to execute agents that are created by the "
|
||||
"AutoGPT system."
|
||||
),
|
||||
summary="AutoGPT Agent Server",
|
||||
version="0.1",
|
||||
lifespan=self.lifespan,
|
||||
docs_url=docs_url,
|
||||
)
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"])
|
||||
app.add_middleware(
|
||||
fastapi.middleware.cors.CORSMiddleware,
|
||||
allow_origins=settings.config.backend_cors_allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
if self._test_dependency_overrides:
|
||||
app.dependency_overrides.update(self._test_dependency_overrides)
|
||||
|
||||
logger.debug(
|
||||
f"FastAPI CORS allow origins: {Config().backend_cors_allow_origins}"
|
||||
)
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=Config().backend_cors_allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
health_router = APIRouter()
|
||||
health_router.add_api_route(
|
||||
path="/health",
|
||||
endpoint=self.health,
|
||||
methods=["GET"],
|
||||
tags=["health"],
|
||||
)
|
||||
@app.exception_handler(Exception)
|
||||
def handle_internal_http_error(request: fastapi.Request, exc: Exception):
|
||||
return fastapi.responses.JSONResponse(
|
||||
content={
|
||||
"message": f"{request.method} {request.url.path} failed",
|
||||
"error": str(exc),
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
# Define the API routes
|
||||
api_router = APIRouter(prefix="/api")
|
||||
api_router.dependencies.append(Depends(auth_middleware))
|
||||
|
||||
# Import & Attach sub-routers
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
|
||||
api_router.include_router(
|
||||
backend.server.integrations.router.router,
|
||||
prefix="/integrations",
|
||||
tags=["integrations"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
||||
api_router.include_router(
|
||||
backend.server.routers.analytics.router,
|
||||
prefix="/analytics",
|
||||
tags=["analytics"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
||||
api_router.add_api_route(
|
||||
path="/auth/user",
|
||||
endpoint=self.get_or_create_user_route,
|
||||
methods=["POST"],
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
api_router.add_api_route(
|
||||
path="/blocks",
|
||||
endpoint=self.get_graph_blocks,
|
||||
methods=["GET"],
|
||||
tags=["blocks"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/blocks/{block_id}/execute",
|
||||
endpoint=self.execute_graph_block,
|
||||
methods=["POST"],
|
||||
tags=["blocks"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs",
|
||||
endpoint=self.get_graphs,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/templates",
|
||||
endpoint=self.get_templates,
|
||||
methods=["GET"],
|
||||
tags=["templates", "graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs",
|
||||
endpoint=self.create_new_graph,
|
||||
methods=["POST"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/templates",
|
||||
endpoint=self.create_new_template,
|
||||
methods=["POST"],
|
||||
tags=["templates", "graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}",
|
||||
endpoint=self.get_graph,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}",
|
||||
endpoint=self.get_template,
|
||||
methods=["GET"],
|
||||
tags=["templates", "graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}",
|
||||
endpoint=self.update_graph,
|
||||
methods=["PUT"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}",
|
||||
endpoint=self.update_graph,
|
||||
methods=["PUT"],
|
||||
tags=["templates", "graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}",
|
||||
endpoint=self.delete_graph,
|
||||
methods=["DELETE"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
endpoint=self.get_graph_all_versions,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}/versions",
|
||||
endpoint=self.get_graph_all_versions,
|
||||
methods=["GET"],
|
||||
tags=["templates", "graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
endpoint=self.get_graph,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
endpoint=self.set_graph_active_version,
|
||||
methods=["PUT"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/input_schema",
|
||||
endpoint=self.get_graph_input_schema,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/execute",
|
||||
endpoint=self.execute_graph,
|
||||
methods=["POST"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
endpoint=self.list_graph_runs,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
endpoint=self.get_graph_run_node_execution_results,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
endpoint=self.stop_graph_run,
|
||||
methods=["POST"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
endpoint=self.create_schedule,
|
||||
methods=["POST"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
endpoint=self.get_execution_schedules,
|
||||
methods=["GET"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/graphs/schedules/{schedule_id}",
|
||||
endpoint=self.update_schedule,
|
||||
methods=["PUT"],
|
||||
tags=["graphs"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/credits",
|
||||
endpoint=self.get_user_credits,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
api_router.add_api_route(
|
||||
path="/settings",
|
||||
endpoint=self.update_configuration,
|
||||
methods=["POST"],
|
||||
tags=["settings"],
|
||||
)
|
||||
|
||||
app.add_exception_handler(500, self.handle_internal_http_error)
|
||||
|
||||
app.include_router(api_router)
|
||||
app.include_router(health_router)
|
||||
|
||||
class AgentServer(backend.util.service.AppProcess):
|
||||
def run(self):
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=Config().agent_api_host,
|
||||
port=Config().agent_api_port,
|
||||
log_config=None,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
|
||||
):
|
||||
return await backend.server.routers.v1.execute_graph(
|
||||
graph_id, node_input, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_graph(
|
||||
create_graph: backend.server.routers.v1.CreateGraph,
|
||||
user_id: str,
|
||||
is_template=False,
|
||||
):
|
||||
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_status(
|
||||
graph_id: str, graph_exec_id: str, user_id: str
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph_run_status(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_node_execution_results(
|
||||
graph_id: str, graph_exec_id: str, user_id: str
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
self._test_dependency_overrides = overrides
|
||||
|
||||
def _apply_overrides_to_methods(self):
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if callable(attr) and hasattr(attr, "__annotations__"):
|
||||
setattr(self, attr_name, self._override_method(attr))
|
||||
|
||||
# TODO: fix this with some proper refactoring of the server
|
||||
def _override_method(self, method):
|
||||
@wraps(method)
|
||||
async def wrapper(*args, **kwargs):
|
||||
sig = inspect.signature(method)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
continue
|
||||
if isinstance(param.annotation, Depends) or ( # type: ignore
|
||||
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
|
||||
):
|
||||
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
|
||||
if dependency in self._test_dependency_overrides:
|
||||
kwargs[param_name] = self._test_dependency_overrides[
|
||||
dependency
|
||||
]()
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_scheduler_client(self) -> ExecutionScheduler:
|
||||
return get_service_client(ExecutionScheduler)
|
||||
|
||||
@classmethod
|
||||
def handle_internal_http_error(cls, request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": f"{request.method} {request.url.path} failed",
|
||||
"error": str(exc),
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_or_create_user_route(cls, user_data: dict = Depends(auth_middleware)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
@classmethod
|
||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||
blocks = [cls() for cls in block.get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
|
||||
|
||||
@classmethod
|
||||
def execute_graph_block(
|
||||
cls, block_id: str, data: BlockInput
|
||||
) -> CompletedBlockOutput:
|
||||
obj = block.get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
output = defaultdict(list)
|
||||
for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
async def get_graphs(
|
||||
cls,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
with_runs: bool = False,
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(
|
||||
include_executions=with_runs, filter_by="active", user_id=user_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_templates(
|
||||
cls, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def get_graph(
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
version: int | None = None,
|
||||
hide_credentials: bool = False,
|
||||
) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
async def get_template(
|
||||
cls, graph_id: str, version: int | None = None
|
||||
) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(graph_id, version, template=True)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Template #{graph_id} not found."
|
||||
)
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
async def get_graph_all_versions(
|
||||
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.Graph]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not graphs:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
async def create_new_graph(
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def create_new_template(
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
|
||||
class DeleteGraphResponse(TypedDict):
|
||||
version_counts: int
|
||||
|
||||
@classmethod
|
||||
async def delete_graph(
|
||||
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> DeleteGraphResponse:
|
||||
return {
|
||||
"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def create_graph(
|
||||
cls,
|
||||
create_graph: CreateGraph,
|
||||
is_template: bool,
|
||||
# user_id doesn't have to be annotated like on other endpoints,
|
||||
# because create_graph isn't used directly as an endpoint
|
||||
user_id: str,
|
||||
) -> graph_db.Graph:
|
||||
if create_graph.graph:
|
||||
graph = create_graph.graph
|
||||
elif create_graph.template_id:
|
||||
# Create a new graph from a template
|
||||
graph = await graph_db.get_graph(
|
||||
create_graph.template_id,
|
||||
create_graph.template_version,
|
||||
template=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
400, detail=f"Template #{create_graph.template_id} not found"
|
||||
)
|
||||
graph.version = 1
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either graph or template_id must be provided."
|
||||
)
|
||||
|
||||
graph.is_template = is_template
|
||||
graph.is_active = not is_template
|
||||
graph.reassign_ids(reassign_graph_id=True)
|
||||
|
||||
return await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def update_graph(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.Graph:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
# Determine new version
|
||||
existing_versions = await graph_db.get_graph_all_versions(
|
||||
graph_id, user_id=user_id
|
||||
)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
graph.version = latest_version_number + 1
|
||||
|
||||
latest_version_graph = next(
|
||||
v for v in existing_versions if v.version == latest_version_number
|
||||
)
|
||||
if latest_version_graph.is_template != graph.is_template:
|
||||
raise HTTPException(
|
||||
400, detail="Changing is_template on an existing graph is forbidden"
|
||||
)
|
||||
graph.is_active = not graph.is_template
|
||||
graph.reassign_ids()
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
|
||||
return new_graph_version
|
||||
|
||||
@classmethod
|
||||
async def set_graph_active_version(
|
||||
cls,
|
||||
graph_id: str,
|
||||
request_body: SetGraphActiveVersion,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
):
|
||||
new_active_version = request_body.active_graph_version
|
||||
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
|
||||
raise HTTPException(
|
||||
404, f"Graph #{graph_id} v{new_active_version} not found"
|
||||
)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id,
|
||||
version=request_body.active_graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
def execute_graph(
|
||||
self,
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
try:
|
||||
graph_exec = self.execution_manager_client.add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
)
|
||||
return {"id": graph_exec["graph_exec_id"]}
|
||||
except Exception as e:
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
async def stop_graph_run(
|
||||
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
|
||||
raise HTTPException(
|
||||
404, detail=f"Agent execution #{graph_exec_id} not found"
|
||||
)
|
||||
|
||||
await asyncio.to_thread(
|
||||
lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
|
||||
)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
@classmethod
|
||||
async def get_graph_input_schema(
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.InputSchemaItem]:
|
||||
try:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
return graph.get_input_schema() if graph else []
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
@classmethod
|
||||
async def list_graph_runs(
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_version: int | None = None,
|
||||
) -> list[str]:
|
||||
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
|
||||
if not graph:
|
||||
rev = "" if graph_version is None else f" v{graph_version}"
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
|
||||
)
|
||||
|
||||
return await execution_db.list_executions(graph_id, graph_version)
|
||||
|
||||
@classmethod
|
||||
async def get_graph_run_status(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> execution_db.ExecutionStatus:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
return execution.executionStatus
|
||||
|
||||
@classmethod
|
||||
async def get_graph_run_node_execution_results(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
async def create_schedule(
|
||||
self,
|
||||
graph_id: str,
|
||||
cron: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return {
|
||||
"id": await asyncio.to_thread(
|
||||
lambda: self.execution_scheduler_client.add_execution_schedule(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=cron,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def update_schedule(
|
||||
self,
|
||||
schedule_id: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
is_enabled = input_data.get("is_enabled", False)
|
||||
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
|
||||
return {"id": schedule_id}
|
||||
|
||||
async def get_user_credits(
|
||||
self, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, int]:
|
||||
return {"credits": await self._user_credit_model.get_or_refill_credit(user_id)}
|
||||
|
||||
def get_execution_schedules(
|
||||
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, str]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
return execution_scheduler.get_execution_schedules(graph_id, user_id)
|
||||
|
||||
async def health(self):
|
||||
return {"status": "healthy"}
|
||||
|
||||
@classmethod
|
||||
def update_configuration(
|
||||
cls,
|
||||
updated_settings: Annotated[
|
||||
Dict[str, Any],
|
||||
Body(
|
||||
examples=[
|
||||
{
|
||||
"config": {
|
||||
"num_graph_workers": 10,
|
||||
"num_node_workers": 10,
|
||||
}
|
||||
}
|
||||
]
|
||||
),
|
||||
],
|
||||
):
|
||||
settings = Settings()
|
||||
try:
|
||||
updated_fields: dict[Any, Any] = {"config": [], "secrets": []}
|
||||
for key, value in updated_settings.get("config", {}).items():
|
||||
if hasattr(settings.config, key):
|
||||
setattr(settings.config, key, value)
|
||||
updated_fields["config"].append(key)
|
||||
for key, value in updated_settings.get("secrets", {}).items():
|
||||
if hasattr(settings.secrets, key):
|
||||
setattr(settings.secrets, key, value)
|
||||
updated_fields["secrets"].append(key)
|
||||
settings.save()
|
||||
return {
|
||||
"message": "Settings updated successfully",
|
||||
"updated_fields": updated_fields,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
|
|
@ -0,0 +1,539 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Dict
|
||||
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import get_block_costs, get_user_credit_model
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
from backend.server.model import CreateGraph, SetGraphActiveVersion
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> ExecutionScheduler:
|
||||
return get_service_client(ExecutionScheduler)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
v1_router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
v1_router.dependencies.append(Depends(auth_middleware))
|
||||
|
||||
|
||||
v1_router.include_router(
|
||||
backend.server.integrations.router.router,
|
||||
prefix="/integrations",
|
||||
tags=["integrations"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
||||
v1_router.include_router(
|
||||
backend.server.routers.analytics.router,
|
||||
prefix="/analytics",
|
||||
tags=["analytics"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Auth #############################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.post("/auth/user", tags=["auth"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
def get_graph_blocks() -> list[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
|
||||
|
||||
|
||||
@v1_router.post(path="/blocks/{block_id}/execute", tags=["blocks"])
|
||||
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
output = defaultdict(list)
|
||||
for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Credits ##########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, int]:
|
||||
return {"credits": await _user_credit_model.get_or_refill_credit(user_id)}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Graphs ###########################
|
||||
########################################################
|
||||
|
||||
|
||||
class DeleteGraphResponse(TypedDict):
|
||||
version_counts: int
|
||||
|
||||
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
with_runs: bool = False,
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(
|
||||
include_executions=with_runs, filter_by="active", user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
version: int | None = None,
|
||||
hide_credentials: bool = False,
|
||||
) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graph
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}/versions",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.Graph]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not graphs:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graphs
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def delete_graph(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> DeleteGraphResponse:
|
||||
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
@v1_router.put(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.Graph:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
# Determine new version
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
graph.version = latest_version_number + 1
|
||||
|
||||
latest_version_graph = next(
|
||||
v for v in existing_versions if v.version == latest_version_number
|
||||
)
|
||||
if latest_version_graph.is_template != graph.is_template:
|
||||
raise HTTPException(
|
||||
400, detail="Changing is_template on an existing graph is forbidden"
|
||||
)
|
||||
graph.is_active = not graph.is_template
|
||||
graph.reassign_ids()
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
|
||||
return new_graph_version
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def set_graph_active_version(
|
||||
graph_id: str,
|
||||
request_body: SetGraphActiveVersion,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
):
|
||||
new_active_version = request_body.active_graph_version
|
||||
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
|
||||
raise HTTPException(404, f"Graph #{graph_id} v{new_active_version} not found")
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id,
|
||||
version=request_body.active_graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
try:
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
)
|
||||
return {"id": graph_exec["graph_exec_id"]}
|
||||
except Exception as e:
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def stop_graph_run(
|
||||
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
|
||||
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
|
||||
|
||||
await asyncio.to_thread(
|
||||
lambda: execution_manager_client().cancel_execution(graph_exec_id)
|
||||
)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/input_schema",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_input_schema(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.InputSchemaItem]:
|
||||
try:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
return graph.get_input_schema() if graph else []
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_graph_runs(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_version: int | None = None,
|
||||
) -> list[str]:
|
||||
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
|
||||
if not graph:
|
||||
rev = "" if graph_version is None else f" v{graph_version}"
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
|
||||
)
|
||||
|
||||
return await execution_db.list_executions(graph_id, graph_version)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_run_node_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
# NOTE: This is used for testing
|
||||
async def get_graph_run_status(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> execution_db.ExecutionStatus:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
return execution.executionStatus
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Templates ########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates",
|
||||
tags=["graphs", "templates"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_templates(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_template(graph_id: str, version: int | None = None) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(graph_id, version, template=True)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
|
||||
return graph
|
||||
|
||||
|
||||
async def do_create_graph(
|
||||
create_graph: CreateGraph,
|
||||
is_template: bool,
|
||||
# user_id doesn't have to be annotated like on other endpoints,
|
||||
# because create_graph isn't used directly as an endpoint
|
||||
user_id: str,
|
||||
) -> graph_db.Graph:
|
||||
if create_graph.graph:
|
||||
graph = create_graph.graph
|
||||
elif create_graph.template_id:
|
||||
# Create a new graph from a template
|
||||
graph = await graph_db.get_graph(
|
||||
create_graph.template_id,
|
||||
create_graph.template_version,
|
||||
template=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
400, detail=f"Template #{create_graph.template_id} not found"
|
||||
)
|
||||
graph.version = 1
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either graph or template_id must be provided."
|
||||
)
|
||||
|
||||
graph.is_template = is_template
|
||||
graph.is_active = not is_template
|
||||
graph.reassign_ids(reassign_graph_id=True)
|
||||
|
||||
return await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/templates",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_new_template(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_schedule(
|
||||
graph_id: str,
|
||||
cron: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return {
|
||||
"id": await asyncio.to_thread(
|
||||
lambda: execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=cron,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/schedules/{schedule_id}",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_schedule(
|
||||
schedule_id: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
is_enabled = input_data.get("is_enabled", False)
|
||||
execution_scheduler_client().update_schedule(
|
||||
schedule_id, is_enabled, user_id=user_id
|
||||
)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_execution_schedules(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, str]:
|
||||
return execution_scheduler_client().get_execution_schedules(graph_id, user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Settings ########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/settings", tags=["settings"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def update_configuration(
|
||||
updated_settings: Annotated[
|
||||
Dict[str, Any],
|
||||
Body(
|
||||
examples=[
|
||||
{
|
||||
"config": {
|
||||
"num_graph_workers": 10,
|
||||
"num_node_workers": 10,
|
||||
}
|
||||
}
|
||||
]
|
||||
),
|
||||
],
|
||||
):
|
||||
settings = Settings()
|
||||
try:
|
||||
updated_fields: dict[Any, Any] = {"config": [], "secrets": []}
|
||||
for key, value in updated_settings.get("config", {}).items():
|
||||
if hasattr(settings.config, key):
|
||||
setattr(settings.config, key, value)
|
||||
updated_fields["config"].append(key)
|
||||
for key, value in updated_settings.get("secrets", {}).items():
|
||||
if hasattr(settings.secrets, key):
|
||||
setattr(settings.secrets, key, value)
|
||||
updated_fields["secrets"].append(key)
|
||||
settings.save()
|
||||
return {
|
||||
"message": "Settings updated successfully",
|
||||
"updated_fields": updated_fields,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
|
@ -252,7 +252,7 @@ async def block_autogen_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = server.agent_server.execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -156,7 +156,7 @@ async def reddit_marketing_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = server.agent_server.execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -78,7 +78,7 @@ async def sample_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = server.agent_server.execute_graph(
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from backend.data import db
|
||||
|
@ -6,9 +7,10 @@ from backend.data.execution import ExecutionStatus
|
|||
from backend.data.model import CREDENTIALS_FIELD_NAME
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
|
||||
from backend.server.rest_api import AgentServer, get_user_id
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
log = print
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpinTestServer:
|
||||
|
@ -57,17 +59,19 @@ async def wait_execution(
|
|||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
status = await AgentServer().get_graph_run_status(
|
||||
status = await AgentServer().test_get_graph_run_status(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
log.info(f"Execution status: {status}")
|
||||
if status == ExecutionStatus.FAILED:
|
||||
log.info("Execution failed")
|
||||
raise Exception("Execution failed")
|
||||
return status == ExecutionStatus.COMPLETED
|
||||
|
||||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_graph_run_node_execution_results(
|
||||
return await AgentServer().test_get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
time.sleep(1)
|
||||
|
@ -79,7 +83,7 @@ def execute_block_test(block: Block):
|
|||
prefix = f"[Test-{block.name}]"
|
||||
|
||||
if not block.test_input or not block.test_output:
|
||||
log(f"{prefix} No test data provided")
|
||||
log.info(f"{prefix} No test data provided")
|
||||
return
|
||||
if not isinstance(block.test_input, list):
|
||||
block.test_input = [block.test_input]
|
||||
|
@ -87,15 +91,15 @@ def execute_block_test(block: Block):
|
|||
block.test_output = [block.test_output]
|
||||
|
||||
output_index = 0
|
||||
log(f"{prefix} Executing {len(block.test_input)} tests...")
|
||||
log.info(f"{prefix} Executing {len(block.test_input)} tests...")
|
||||
prefix = " " * 4 + prefix
|
||||
|
||||
for mock_name, mock_obj in (block.test_mock or {}).items():
|
||||
log(f"{prefix} mocking {mock_name}...")
|
||||
log.info(f"{prefix} mocking {mock_name}...")
|
||||
if hasattr(block, mock_name):
|
||||
setattr(block, mock_name, mock_obj)
|
||||
else:
|
||||
log(f"{prefix} mock {mock_name} not found in block")
|
||||
log.info(f"{prefix} mock {mock_name} not found in block")
|
||||
|
||||
extra_exec_kwargs = {}
|
||||
|
||||
|
@ -107,7 +111,7 @@ def execute_block_test(block: Block):
|
|||
extra_exec_kwargs[CREDENTIALS_FIELD_NAME] = block.test_credentials
|
||||
|
||||
for input_data in block.test_input:
|
||||
log(f"{prefix} in: {input_data}")
|
||||
log.info(f"{prefix} in: {input_data}")
|
||||
|
||||
for output_name, output_data in block.execute(input_data, **extra_exec_kwargs):
|
||||
if output_index >= len(block.test_output):
|
||||
|
@ -125,7 +129,7 @@ def execute_block_test(block: Block):
|
|||
is_matching = False
|
||||
|
||||
mark = "✅" if is_matching else "❌"
|
||||
log(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
|
||||
log.info(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
|
||||
if not is_matching:
|
||||
raise ValueError(
|
||||
f"{prefix}: wrong output {data} vs {expected_data}"
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create console handler with formatting
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def server():
|
||||
|
@ -12,7 +25,7 @@ async def server():
|
|||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def graph_cleanup(server):
|
||||
created_graph_ids = []
|
||||
original_create_graph = server.agent_server.create_graph
|
||||
original_create_graph = server.agent_server.test_create_graph
|
||||
|
||||
async def create_graph_wrapper(*args, **kwargs):
|
||||
created_graph = await original_create_graph(*args, **kwargs)
|
||||
|
@ -22,13 +35,14 @@ async def graph_cleanup(server):
|
|||
return created_graph
|
||||
|
||||
try:
|
||||
server.agent_server.create_graph = create_graph_wrapper
|
||||
server.agent_server.test_create_graph = create_graph_wrapper
|
||||
yield # This runs the test function
|
||||
finally:
|
||||
server.agent_server.create_graph = original_create_graph
|
||||
server.agent_server.test_create_graph = original_create_graph
|
||||
|
||||
# Delete the created graphs and assert they were deleted
|
||||
for graph_id, user_id in created_graph_ids:
|
||||
resp = await server.agent_server.delete_graph(graph_id, user_id)
|
||||
num_deleted = resp["version_counts"]
|
||||
assert num_deleted > 0, f"Graph {graph_id} was not deleted."
|
||||
if user_id:
|
||||
resp = await server.agent_server.test_delete_graph(graph_id, user_id)
|
||||
num_deleted = resp["version_counts"]
|
||||
assert num_deleted > 0, f"Graph {graph_id} was not deleted."
|
||||
|
|
|
@ -47,15 +47,15 @@ async def test_graph_creation(server: SpinTestServer):
|
|||
create_graph = CreateGraph(graph=graph)
|
||||
|
||||
try:
|
||||
await server.agent_server.create_graph(create_graph, False, DEFAULT_USER_ID)
|
||||
await server.agent_server.test_create_graph(create_graph, DEFAULT_USER_ID)
|
||||
assert False, "Should not be able to connect nodes from different subgraphs"
|
||||
except ValueError as e:
|
||||
assert "different subgraph" in str(e)
|
||||
|
||||
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
|
||||
graph.links[0].sink_id = "node_2"
|
||||
created_graph = await server.agent_server.create_graph(
|
||||
create_graph, False, DEFAULT_USER_ID
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
assert UUID(created_graph.id)
|
||||
|
@ -102,8 +102,8 @@ async def test_get_input_schema(server: SpinTestServer):
|
|||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.create_graph(
|
||||
create_graph, False, DEFAULT_USER_ID
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
@ -138,8 +138,8 @@ async def test_get_input_schema_none_required(server: SpinTestServer):
|
|||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.create_graph(
|
||||
create_graph, False, DEFAULT_USER_ID
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
@ -180,8 +180,8 @@ async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
|
|||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.create_graph(
|
||||
create_graph, False, DEFAULT_USER_ID
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import logging
|
||||
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
|
@ -9,9 +11,12 @@ from backend.server.rest_api import AgentServer
|
|||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||
return await s.agent_server.create_graph(CreateGraph(graph=g), False, u.id)
|
||||
logger.info(f"Creating graph for user {u.id}")
|
||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
|
@ -21,12 +26,20 @@ async def execute_graph(
|
|||
input_data: dict,
|
||||
num_execs: int = 4,
|
||||
) -> str:
|
||||
logger.info(f"Executing graph {test_graph.id} for user {test_user.id}")
|
||||
logger.info(f"Input data: {input_data}")
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = agent_server.execute_graph(test_graph.id, input_data, test_user.id)
|
||||
response = await agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
graph_exec_id = response["id"]
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert result and len(result) == num_execs
|
||||
return graph_exec_id
|
||||
|
||||
|
@ -37,7 +50,8 @@ async def assert_sample_graph_executions(
|
|||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
executions = await agent_server.get_graph_run_node_execution_results(
|
||||
logger.info(f"Checking execution results for graph {test_graph.id}")
|
||||
executions = await agent_server.test_get_graph_run_node_execution_results(
|
||||
test_graph.id,
|
||||
graph_exec_id,
|
||||
test_user.id,
|
||||
|
@ -57,6 +71,7 @@ async def assert_sample_graph_executions(
|
|||
|
||||
# Executing StoreValueBlock
|
||||
exec = executions[0]
|
||||
logger.info(f"Checking first StoreValueBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert (
|
||||
|
@ -69,6 +84,7 @@ async def assert_sample_graph_executions(
|
|||
|
||||
# Executing StoreValueBlock
|
||||
exec = executions[1]
|
||||
logger.info(f"Checking second StoreValueBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert (
|
||||
|
@ -81,6 +97,7 @@ async def assert_sample_graph_executions(
|
|||
|
||||
# Executing FillTextTemplateBlock
|
||||
exec = executions[2]
|
||||
logger.info(f"Checking FillTextTemplateBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello, World!!!"]}
|
||||
|
@ -95,6 +112,7 @@ async def assert_sample_graph_executions(
|
|||
|
||||
# Executing PrintToConsoleBlock
|
||||
exec = executions[3]
|
||||
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"status": ["printed"]}
|
||||
|
@ -104,6 +122,7 @@ async def assert_sample_graph_executions(
|
|||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_agent_execution(server: SpinTestServer):
|
||||
logger.info("Starting test_agent_execution")
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
|
@ -117,6 +136,7 @@ async def test_agent_execution(server: SpinTestServer):
|
|||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, test_user, graph_exec_id
|
||||
)
|
||||
logger.info("Completed test_agent_execution")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
|
@ -132,6 +152,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
|||
// key
|
||||
StoreValueBlock2
|
||||
"""
|
||||
logger.info("Starting test_input_pin_always_waited")
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
|
@ -172,7 +193,8 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
|||
server.agent_server, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
logger.info("Checking execution results")
|
||||
executions = await server.agent_server.test_get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
|
@ -180,6 +202,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
|||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
assert executions[2].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[2].output_data == {"output": ["value2"]}
|
||||
logger.info("Completed test_input_pin_always_waited")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
|
@ -197,6 +220,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
|||
And later, another output is produced on input pin `b`, which is a static link,
|
||||
this input will complete the input of those three incomplete executions.
|
||||
"""
|
||||
logger.info("Starting test_static_input_link_on_graph")
|
||||
nodes = [
|
||||
graph.Node(block_id=StoreValueBlock().id, input_default={"input": 4}), # a
|
||||
graph.Node(block_id=StoreValueBlock().id, input_default={"input": 4}), # a
|
||||
|
@ -252,11 +276,14 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
|||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
logger.info("Checking execution results")
|
||||
executions = await server.agent_server.test_get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
# The last 3 executions will be a+b=4+5=9
|
||||
for exec_data in executions[-3:]:
|
||||
for i, exec_data in enumerate(executions[-3:]):
|
||||
logger.info(f"Checking execution {i+1} of last 3: {exec_data}")
|
||||
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec_data.output_data == {"result": [9]}
|
||||
logger.info("Completed test_static_input_link_on_graph")
|
||||
|
|
|
@ -12,7 +12,7 @@ from backend.util.test import SpinTestServer
|
|||
async def test_agent_schedule(server: SpinTestServer):
|
||||
await db.connect()
|
||||
test_user = await create_test_user()
|
||||
test_graph = await server.agent_server.create_graph(
|
||||
test_graph = await server.agent_server.test_create_graph(
|
||||
create_graph=CreateGraph(graph=create_test_graph()),
|
||||
is_template=False,
|
||||
user_id=test_user.id,
|
||||
|
|
|
@ -28,7 +28,7 @@ export default class BaseAutoGPTServerAPI {
|
|||
|
||||
constructor(
|
||||
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
|
||||
"http://localhost:8006/api",
|
||||
"http://localhost:8006/api/v1",
|
||||
wsUrl: string = process.env.NEXT_PUBLIC_AGPT_WS_SERVER_URL ||
|
||||
"ws://localhost:8001/ws",
|
||||
supabaseClient: SupabaseClient | null = null,
|
||||
|
|
Loading…
Reference in New Issue