refactor(platform/rest): Update REST API to use standard FastAPI structure (#8519)

pull/8546/head
Swifty 2024-11-04 12:12:21 +01:00 committed by GitHub
parent e26513f5e4
commit 594aa996d7
11 changed files with 716 additions and 710 deletions
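
In short: the class-based AgentServer, which registered every endpoint imperatively via api_router.add_api_route(...) inside run_service(), is replaced by a module-level v1_router in backend/server/routers/v1.py whose endpoints are plain decorated functions, with rest_api.py reduced to a thin FastAPI app that includes that router. A minimal sketch of the before/after pattern, using a hypothetical placeholder handler rather than any of the real endpoints:

from fastapi import APIRouter, FastAPI

# Before: routes registered imperatively on a router built inside the service class.
class OldStyleServer:
    async def get_blocks(self) -> list[dict]:  # placeholder handler, not the real endpoint
        return []

    def build_app(self) -> FastAPI:
        app = FastAPI()
        api_router = APIRouter(prefix="/api")
        api_router.add_api_route(
            path="/blocks",
            endpoint=self.get_blocks,
            methods=["GET"],
            tags=["blocks"],
        )
        app.include_router(api_router)
        return app

# After: endpoints are module-level functions decorated onto a shared, versioned router,
# and the app simply includes that router once.
v1_router = APIRouter(prefix="/api")

@v1_router.get("/blocks", tags=["blocks"])
async def get_blocks() -> list[dict]:  # placeholder handler, not the real endpoint
    return []

app = FastAPI()
app.include_router(v1_router, tags=["v1"])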

@@ -1,695 +1,117 @@
import asyncio
import inspect
import logging
from collections import defaultdict
from contextlib import asynccontextmanager
from functools import wraps
from typing import Annotated, Any, Dict
import contextlib
import typing
import fastapi
import fastapi.middleware.cors
import fastapi.responses
import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing_extensions import TypedDict
from backend.data import block, db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user, migrate_and_encrypt_user_integrations
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
import backend.data.block
import backend.data.db
import backend.data.user
import backend.server.routers.v1
import backend.util.service
import backend.util.settings
from .utils import get_user_id
settings = Settings()
logger = logging.getLogger(__name__)
settings = backend.util.settings.Settings()
class AgentServer(AppService):
_test_dependency_overrides = {}
_user_credit_model = get_user_credit_model()
@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
yield
await backend.data.db.disconnect()
def __init__(self):
super().__init__()
self.use_redis = True
@classmethod
def get_port(cls) -> int:
return Config().agent_server_port
docs_url = (
"/docs"
if settings.config.app_env == backend.util.settings.AppEnvironment.LOCAL
else None
)
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
await block.initialize_blocks()
await migrate_and_encrypt_user_integrations()
yield
await db.disconnect()
app = fastapi.FastAPI(
title="AutoGPT Agent Server",
description=(
"This server is used to execute agents that are created by the "
"AutoGPT system."
),
summary="AutoGPT Agent Server",
version="0.1",
lifespan=lifespan_context,
docs_url=docs_url,
)
def run_service(self):
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(
title="AutoGPT Agent Server",
description=(
"This server is used to execute agents that are created by the "
"AutoGPT system."
),
summary="AutoGPT Agent Server",
version="0.1",
lifespan=self.lifespan,
docs_url=docs_url,
)
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"])
app.add_middleware(
fastapi.middleware.cors.CORSMiddleware,
allow_origins=settings.config.backend_cors_allow_origins,
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
if self._test_dependency_overrides:
app.dependency_overrides.update(self._test_dependency_overrides)
logger.debug(
f"FastAPI CORS allow origins: {Config().backend_cors_allow_origins}"
)
@app.get(path="/health", tags=["health"], dependencies=[])
async def health():
return {"status": "healthy"}
app.add_middleware(
CORSMiddleware,
allow_origins=Config().backend_cors_allow_origins,
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
health_router = APIRouter()
health_router.add_api_route(
path="/health",
endpoint=self.health,
methods=["GET"],
tags=["health"],
)
@app.exception_handler(Exception)
def handle_internal_http_error(request: fastapi.Request, exc: Exception):
return fastapi.responses.JSONResponse(
content={
"message": f"{request.method} {request.url.path} failed",
"error": str(exc),
},
status_code=500,
)
# Define the API routes
api_router = APIRouter(prefix="/api")
api_router.dependencies.append(Depends(auth_middleware))
# Import & Attach sub-routers
import backend.server.integrations.router
import backend.server.routers.analytics
api_router.include_router(
backend.server.integrations.router.router,
prefix="/integrations",
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
api_router.include_router(
backend.server.routers.analytics.router,
prefix="/analytics",
tags=["analytics"],
dependencies=[Depends(auth_middleware)],
)
api_router.add_api_route(
path="/auth/user",
endpoint=self.get_or_create_user_route,
methods=["POST"],
tags=["auth"],
)
api_router.add_api_route(
path="/blocks",
endpoint=self.get_graph_blocks,
methods=["GET"],
tags=["blocks"],
)
api_router.add_api_route(
path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block,
methods=["POST"],
tags=["blocks"],
)
api_router.add_api_route(
path="/graphs",
endpoint=self.get_graphs,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/templates",
endpoint=self.get_templates,
methods=["GET"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs",
endpoint=self.create_new_graph,
methods=["POST"],
tags=["graphs"],
)
api_router.add_api_route(
path="/templates",
endpoint=self.create_new_template,
methods=["POST"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}",
endpoint=self.get_graph,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/templates/{graph_id}",
endpoint=self.get_template,
methods=["GET"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}",
endpoint=self.update_graph,
methods=["PUT"],
tags=["graphs"],
)
api_router.add_api_route(
path="/templates/{graph_id}",
endpoint=self.update_graph,
methods=["PUT"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}",
endpoint=self.delete_graph,
methods=["DELETE"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/versions",
endpoint=self.get_graph_all_versions,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/templates/{graph_id}/versions",
endpoint=self.get_graph_all_versions,
methods=["GET"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/versions/{version}",
endpoint=self.get_graph,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/versions/active",
endpoint=self.set_graph_active_version,
methods=["PUT"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/input_schema",
endpoint=self.get_graph_input_schema,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/execute",
endpoint=self.execute_graph,
methods=["POST"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/executions",
endpoint=self.list_graph_runs,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}",
endpoint=self.get_graph_run_node_execution_results,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
endpoint=self.stop_graph_run,
methods=["POST"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule,
methods=["POST"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.get_execution_schedules,
methods=["GET"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/schedules/{schedule_id}",
endpoint=self.update_schedule,
methods=["PUT"],
tags=["graphs"],
)
api_router.add_api_route(
path="/credits",
endpoint=self.get_user_credits,
methods=["GET"],
)
api_router.add_api_route(
path="/settings",
endpoint=self.update_configuration,
methods=["POST"],
tags=["settings"],
)
app.add_exception_handler(500, self.handle_internal_http_error)
app.include_router(api_router)
app.include_router(health_router)
class AgentServer(backend.util.service.AppProcess):
def run(self):
uvicorn.run(
app,
host=Config().agent_api_host,
port=Config().agent_api_port,
log_config=None,
host=backend.util.settings.Config().agent_api_host,
port=backend.util.settings.Config().agent_api_port,
)
@staticmethod
async def test_execute_graph(
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
):
return await backend.server.routers.v1.execute_graph(
graph_id, node_input, user_id
)
@staticmethod
async def test_create_graph(
create_graph: backend.server.routers.v1.CreateGraph,
user_id: str,
is_template=False,
):
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)
@staticmethod
async def test_get_graph_run_status(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_status(
graph_id, graph_exec_id, user_id
)
@staticmethod
async def test_get_graph_run_node_execution_results(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
def set_test_dependency_overrides(self, overrides: dict):
self._test_dependency_overrides = overrides
def _apply_overrides_to_methods(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr) and hasattr(attr, "__annotations__"):
setattr(self, attr_name, self._override_method(attr))
# TODO: fix this with some proper refactoring of the server
def _override_method(self, method):
@wraps(method)
async def wrapper(*args, **kwargs):
sig = inspect.signature(method)
for param_name, param in sig.parameters.items():
if param.annotation is inspect.Parameter.empty:
continue
if isinstance(param.annotation, Depends) or ( # type: ignore
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
):
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
if dependency in self._test_dependency_overrides:
kwargs[param_name] = self._test_dependency_overrides[
dependency
]()
return await method(*args, **kwargs)
return wrapper
@property
@thread_cached
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
@classmethod
def handle_internal_http_error(cls, request: Request, exc: Exception):
return JSONResponse(
content={
"message": f"{request.method} {request.url.path} failed",
"error": str(exc),
},
status_code=500,
)
@classmethod
async def get_or_create_user_route(cls, user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
blocks = [cls() for cls in block.get_blocks().values()]
costs = get_block_costs()
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
@classmethod
def execute_graph_block(
cls, block_id: str, data: BlockInput
) -> CompletedBlockOutput:
obj = block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
output = defaultdict(list)
for name, data in obj.execute(data):
output[name].append(data)
return output
@classmethod
async def get_graphs(
cls,
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(
include_executions=with_runs, filter_by="active", user_id=user_id
)
@classmethod
async def get_templates(
cls, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
@classmethod
async def get_graph(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
) -> graph_db.Graph:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@classmethod
async def get_template(
cls, graph_id: str, version: int | None = None
) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(
status_code=404, detail=f"Template #{graph_id} not found."
)
return graph
@classmethod
async def get_graph_all_versions(
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@classmethod
async def create_new_graph(
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
@classmethod
async def create_new_template(
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
class DeleteGraphResponse(TypedDict):
version_counts: int
@classmethod
async def delete_graph(
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
return {
"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)
}
@classmethod
async def create_graph(
cls,
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
else:
raise HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = is_template
graph.is_active = not is_template
graph.reassign_ids(reassign_graph_id=True)
return await graph_db.create_graph(graph, user_id=user_id)
@classmethod
async def update_graph(
cls,
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(
graph_id, user_id=user_id
)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
if latest_version_graph.is_template != graph.is_template:
raise HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph.reassign_ids()
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
return new_graph_version
@classmethod
async def set_graph_active_version(
cls,
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: Annotated[str, Depends(get_user_id)],
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
raise HTTPException(
404, f"Graph #{graph_id} v{new_active_version} not found"
)
await graph_db.set_graph_active_version(
graph_id=graph_id,
version=request_body.active_graph_version,
user_id=user_id,
)
def execute_graph(
self,
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
async def stop_graph_run(
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
raise HTTPException(
404, detail=f"Agent execution #{graph_exec_id} not found"
)
await asyncio.to_thread(
lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
)
# Retrieve & return canceled graph execution in its final state
return await execution_db.get_execution_results(graph_exec_id)
@classmethod
async def get_graph_input_schema(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.InputSchemaItem]:
try:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
return graph.get_input_schema() if graph else []
except Exception:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@classmethod
async def list_graph_runs(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> list[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await execution_db.list_executions(graph_id, graph_version)
@classmethod
async def get_graph_run_status(
cls,
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)
return execution.executionStatus
@classmethod
async def get_graph_run_node_execution_results(
cls,
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await execution_db.get_execution_results(graph_exec_id)
async def create_schedule(
self,
graph_id: str,
cron: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return {
"id": await asyncio.to_thread(
lambda: self.execution_scheduler_client.add_execution_schedule(
graph_id=graph_id,
graph_version=graph.version,
cron=cron,
input_data=input_data,
user_id=user_id,
)
)
}
def update_schedule(
self,
schedule_id: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler = self.execution_scheduler_client
is_enabled = input_data.get("is_enabled", False)
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
return {"id": schedule_id}
async def get_user_credits(
self, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, int]:
return {"credits": await self._user_credit_model.get_or_refill_credit(user_id)}
def get_execution_schedules(
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id, user_id)
async def health(self):
return {"status": "healthy"}
@classmethod
def update_configuration(
cls,
updated_settings: Annotated[
Dict[str, Any],
Body(
examples=[
{
"config": {
"num_graph_workers": 10,
"num_node_workers": 10,
}
}
]
),
],
):
settings = Settings()
try:
updated_fields: dict[Any, Any] = {"config": [], "secrets": []}
for key, value in updated_settings.get("config", {}).items():
if hasattr(settings.config, key):
setattr(settings.config, key, value)
updated_fields["config"].append(key)
for key, value in updated_settings.get("secrets", {}).items():
if hasattr(settings.secrets, key):
setattr(settings.secrets, key, value)
updated_fields["secrets"].append(key)
settings.save()
return {
"message": "Settings updated successfully",
"updated_fields": updated_fields,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
app.dependency_overrides.update(overrides)

@@ -0,0 +1,539 @@
import asyncio
import logging
from collections import defaultdict
from typing import Annotated, Any, Dict
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from typing_extensions import TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.server.utils import get_user_id
from backend.util.service import get_service_client
from backend.util.settings import Settings
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def execution_scheduler_client() -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
settings = Settings()
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
# Define the API routes
v1_router = APIRouter(prefix="/api")
v1_router.dependencies.append(Depends(auth_middleware))
v1_router.include_router(
backend.server.integrations.router.router,
prefix="/integrations",
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
v1_router.include_router(
backend.server.routers.analytics.router,
prefix="/analytics",
tags=["analytics"],
dependencies=[Depends(auth_middleware)],
)
########################################################
##################### Auth #############################
########################################################
@v1_router.post("/auth/user", tags=["auth"], dependencies=[Depends(auth_middleware)])
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
########################################################
##################### Blocks ###########################
########################################################
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> list[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
costs = get_block_costs()
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
@v1_router.post(path="/blocks/{block_id}/execute", tags=["blocks"])
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
output = defaultdict(list)
for name, data in obj.execute(data):
output[name].append(data)
return output
########################################################
##################### Credits ##########################
########################################################
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
async def get_user_credits(
user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, int]:
return {"credits": await _user_credit_model.get_or_refill_credit(user_id)}
########################################################
##################### Graphs ###########################
########################################################
class DeleteGraphResponse(TypedDict):
version_counts: int
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(
include_executions=with_runs, filter_by="active", user_id=user_id
)
@v1_router.get(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.get(
path="/graphs/{graph_id}/versions/{version}",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
) -> graph_db.Graph:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@v1_router.get(
path="/graphs/{graph_id}/versions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(
path="/templates/{graph_id}/versions",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@v1_router.delete(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
@v1_router.put(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.put(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def update_graph(
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
if latest_version_graph.is_template != graph.is_template:
raise HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph.reassign_ids()
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
return new_graph_version
@v1_router.post(
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
@v1_router.put(
path="/graphs/{graph_id}/versions/active",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def set_graph_active_version(
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: Annotated[str, Depends(get_user_id)],
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
raise HTTPException(404, f"Graph #{graph_id} v{new_active_version} not found")
await graph_db.set_graph_active_version(
graph_id=graph_id,
version=request_body.active_graph_version,
user_id=user_id,
)
@v1_router.post(
path="/graphs/{graph_id}/execute",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def execute_graph(
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id
)
return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@v1_router.post(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def stop_graph_run(
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
lambda: execution_manager_client().cancel_execution(graph_exec_id)
)
# Retrieve & return canceled graph execution in its final state
return await execution_db.get_execution_results(graph_exec_id)
@v1_router.get(
path="/graphs/{graph_id}/input_schema",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_input_schema(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.InputSchemaItem]:
try:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
return graph.get_input_schema() if graph else []
except Exception:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@v1_router.get(
path="/graphs/{graph_id}/executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_runs(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> list[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await execution_db.list_executions(graph_id, graph_version)
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_run_node_execution_results(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await execution_db.get_execution_results(graph_exec_id)
# NOTE: This is used for testing
async def get_graph_run_status(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)
return execution.executionStatus
########################################################
##################### Templates ########################
########################################################
@v1_router.get(
path="/templates",
tags=["graphs", "templates"],
dependencies=[Depends(auth_middleware)],
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
@v1_router.get(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(graph_id: str, version: int | None = None) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
else:
raise HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = is_template
graph.is_active = not is_template
graph.reassign_ids(reassign_graph_id=True)
return await graph_db.create_graph(graph, user_id=user_id)
@v1_router.post(
path="/templates",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
########################################################
##################### Schedules ########################
########################################################
@v1_router.post(
path="/graphs/{graph_id}/schedules",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def create_schedule(
graph_id: str,
cron: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return {
"id": await asyncio.to_thread(
lambda: execution_scheduler_client().add_execution_schedule(
graph_id=graph_id,
graph_version=graph.version,
cron=cron,
input_data=input_data,
user_id=user_id,
)
)
}
@v1_router.put(
path="/graphs/schedules/{schedule_id}",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def update_schedule(
schedule_id: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
is_enabled = input_data.get("is_enabled", False)
execution_scheduler_client().update_schedule(
schedule_id, is_enabled, user_id=user_id
)
return {"id": schedule_id}
@v1_router.get(
path="/graphs/{graph_id}/schedules",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_execution_schedules(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:
return execution_scheduler_client().get_execution_schedules(graph_id, user_id)
########################################################
##################### Settings ########################
########################################################
@v1_router.post(
path="/settings", tags=["settings"], dependencies=[Depends(auth_middleware)]
)
async def update_configuration(
updated_settings: Annotated[
Dict[str, Any],
Body(
examples=[
{
"config": {
"num_graph_workers": 10,
"num_node_workers": 10,
}
}
]
),
],
):
settings = Settings()
try:
updated_fields: dict[Any, Any] = {"config": [], "secrets": []}
for key, value in updated_settings.get("config", {}).items():
if hasattr(settings.config, key):
setattr(settings.config, key, value)
updated_fields["config"].append(key)
for key, value in updated_settings.get("secrets", {}).items():
if hasattr(settings.secrets, key):
setattr(settings.secrets, key, value)
updated_fields["secrets"].append(key)
settings.save()
return {
"message": "Settings updated successfully",
"updated_fields": updated_fields,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

@@ -252,7 +252,7 @@ async def block_autogen_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = server.agent_server.execute_graph(
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

@@ -156,7 +156,7 @@ async def reddit_marketing_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = server.agent_server.execute_graph(
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

@@ -78,7 +78,7 @@ async def sample_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = server.agent_server.execute_graph(
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

@@ -1,3 +1,4 @@
import logging
import time
from backend.data import db
@@ -6,9 +7,10 @@ from backend.data.execution import ExecutionStatus
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.server.rest_api import AgentServer, get_user_id
from backend.server.rest_api import AgentServer
from backend.server.utils import get_user_id
log = print
log = logging.getLogger(__name__)
class SpinTestServer:
@@ -57,17 +59,19 @@ async def wait_execution(
timeout: int = 20,
) -> list:
async def is_execution_completed():
status = await AgentServer().get_graph_run_status(
status = await AgentServer().test_get_graph_run_status(
graph_id, graph_exec_id, user_id
)
log.info(f"Execution status: {status}")
if status == ExecutionStatus.FAILED:
log.info("Execution failed")
raise Exception("Execution failed")
return status == ExecutionStatus.COMPLETED
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().get_graph_run_node_execution_results(
return await AgentServer().test_get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
time.sleep(1)
@@ -79,7 +83,7 @@ def execute_block_test(block: Block):
prefix = f"[Test-{block.name}]"
if not block.test_input or not block.test_output:
log(f"{prefix} No test data provided")
log.info(f"{prefix} No test data provided")
return
if not isinstance(block.test_input, list):
block.test_input = [block.test_input]
@@ -87,15 +91,15 @@ def execute_block_test(block: Block):
block.test_output = [block.test_output]
output_index = 0
log(f"{prefix} Executing {len(block.test_input)} tests...")
log.info(f"{prefix} Executing {len(block.test_input)} tests...")
prefix = " " * 4 + prefix
for mock_name, mock_obj in (block.test_mock or {}).items():
log(f"{prefix} mocking {mock_name}...")
log.info(f"{prefix} mocking {mock_name}...")
if hasattr(block, mock_name):
setattr(block, mock_name, mock_obj)
else:
log(f"{prefix} mock {mock_name} not found in block")
log.info(f"{prefix} mock {mock_name} not found in block")
extra_exec_kwargs = {}
@@ -107,7 +111,7 @@ def execute_block_test(block: Block):
extra_exec_kwargs[CREDENTIALS_FIELD_NAME] = block.test_credentials
for input_data in block.test_input:
log(f"{prefix} in: {input_data}")
log.info(f"{prefix} in: {input_data}")
for output_name, output_data in block.execute(input_data, **extra_exec_kwargs):
if output_index >= len(block.test_output):
@@ -125,7 +129,7 @@ def execute_block_test(block: Block):
is_matching = False
mark = "" if is_matching else ""
log(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
log.info(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
if not is_matching:
raise ValueError(
f"{prefix}: wrong output {data} vs {expected_data}"

@@ -1,7 +1,20 @@
import logging
import pytest
from backend.util.test import SpinTestServer
# NOTE: You can run tests with the --log-cli-level=INFO flag to see the logs
# Set up logging
logger = logging.getLogger(__name__)
# Create console handler with formatting
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)
@pytest.fixture(scope="session")
async def server():
@@ -12,7 +25,7 @@ async def server():
@pytest.fixture(scope="session", autouse=True)
async def graph_cleanup(server):
created_graph_ids = []
original_create_graph = server.agent_server.create_graph
original_create_graph = server.agent_server.test_create_graph
async def create_graph_wrapper(*args, **kwargs):
created_graph = await original_create_graph(*args, **kwargs)
@@ -22,13 +35,14 @@ async def graph_cleanup(server):
return created_graph
try:
server.agent_server.create_graph = create_graph_wrapper
server.agent_server.test_create_graph = create_graph_wrapper
yield # This runs the test function
finally:
server.agent_server.create_graph = original_create_graph
server.agent_server.test_create_graph = original_create_graph
# Delete the created graphs and assert they were deleted
for graph_id, user_id in created_graph_ids:
resp = await server.agent_server.delete_graph(graph_id, user_id)
num_deleted = resp["version_counts"]
assert num_deleted > 0, f"Graph {graph_id} was not deleted."
if user_id:
resp = await server.agent_server.test_delete_graph(graph_id, user_id)
num_deleted = resp["version_counts"]
assert num_deleted > 0, f"Graph {graph_id} was not deleted."

@@ -47,15 +47,15 @@ async def test_graph_creation(server: SpinTestServer):
create_graph = CreateGraph(graph=graph)
try:
await server.agent_server.create_graph(create_graph, False, DEFAULT_USER_ID)
await server.agent_server.test_create_graph(create_graph, DEFAULT_USER_ID)
assert False, "Should not be able to connect nodes from different subgraphs"
except ValueError as e:
assert "different subgraph" in str(e)
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
graph.links[0].sink_id = "node_2"
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
assert UUID(created_graph.id)
@@ -102,8 +102,8 @@ async def test_get_input_schema(server: SpinTestServer):
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
@@ -138,8 +138,8 @@ async def test_get_input_schema_none_required(server: SpinTestServer):
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
@@ -180,8 +180,8 @@ async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()

@@ -1,3 +1,5 @@
import logging
import pytest
from prisma.models import User
@@ -9,9 +11,12 @@ from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
logger = logging.getLogger(__name__)
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
return await s.agent_server.create_graph(CreateGraph(graph=g), False, u.id)
logger.info(f"Creating graph for user {u.id}")
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
async def execute_graph(
@@ -21,12 +26,20 @@ async def execute_graph(
input_data: dict,
num_execs: int = 4,
) -> str:
logger.info(f"Executing graph {test_graph.id} for user {test_user.id}")
logger.info(f"Input data: {input_data}")
# --- Test adding new executions --- #
response = agent_server.execute_graph(test_graph.id, input_data, test_user.id)
response = await agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
)
graph_exec_id = response["id"]
logger.info(f"Created execution with ID: {graph_exec_id}")
# Execution queue should be empty
logger.info("Waiting for execution to complete...")
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
logger.info(f"Execution completed with {len(result)} results")
assert result and len(result) == num_execs
return graph_exec_id
@@ -37,7 +50,8 @@ async def assert_sample_graph_executions(
test_user: User,
graph_exec_id: str,
):
executions = await agent_server.get_graph_run_node_execution_results(
logger.info(f"Checking execution results for graph {test_graph.id}")
executions = await agent_server.test_get_graph_run_node_execution_results(
test_graph.id,
graph_exec_id,
test_user.id,
@@ -57,6 +71,7 @@ async def assert_sample_graph_executions(
# Executing StoreValueBlock
exec = executions[0]
logger.info(f"Checking first StoreValueBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert (
@@ -69,6 +84,7 @@ async def assert_sample_graph_executions(
# Executing StoreValueBlock
exec = executions[1]
logger.info(f"Checking second StoreValueBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert (
@@ -81,6 +97,7 @@ async def assert_sample_graph_executions(
# Executing FillTextTemplateBlock
exec = executions[2]
logger.info(f"Checking FillTextTemplateBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!!!"]}
@@ -95,6 +112,7 @@ async def assert_sample_graph_executions(
# Executing PrintToConsoleBlock
exec = executions[3]
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"status": ["printed"]}
@@ -104,6 +122,7 @@ async def assert_sample_graph_executions(
@pytest.mark.asyncio(scope="session")
async def test_agent_execution(server: SpinTestServer):
logger.info("Starting test_agent_execution")
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
data = {"input_1": "Hello", "input_2": "World"}
@@ -117,6 +136,7 @@ async def test_agent_execution(server: SpinTestServer):
await assert_sample_graph_executions(
server.agent_server, test_graph, test_user, graph_exec_id
)
logger.info("Completed test_agent_execution")
@pytest.mark.asyncio(scope="session")
@@ -132,6 +152,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
// key
StoreValueBlock2
"""
logger.info("Starting test_input_pin_always_waited")
nodes = [
graph.Node(
block_id=StoreValueBlock().id,
@@ -172,7 +193,8 @@ async def test_input_pin_always_waited(server: SpinTestServer):
server.agent_server, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_graph_run_node_execution_results(
logger.info("Checking execution results")
executions = await server.agent_server.test_get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
@@ -180,6 +202,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[2].status == execution.ExecutionStatus.COMPLETED
assert executions[2].output_data == {"output": ["value2"]}
logger.info("Completed test_input_pin_always_waited")
@pytest.mark.asyncio(scope="session")
@@ -197,6 +220,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
And later, another output is produced on input pin `b`, which is a static link,
this input will complete the input of those three incomplete executions.
"""
logger.info("Starting test_static_input_link_on_graph")
nodes = [
graph.Node(block_id=StoreValueBlock().id, input_default={"input": 4}), # a
graph.Node(block_id=StoreValueBlock().id, input_default={"input": 4}), # a
@@ -252,11 +276,14 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
graph_exec_id = await execute_graph(
server.agent_server, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_graph_run_node_execution_results(
logger.info("Checking execution results")
executions = await server.agent_server.test_get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8
# The last 3 executions will be a+b=4+5=9
for exec_data in executions[-3:]:
for i, exec_data in enumerate(executions[-3:]):
logger.info(f"Checking execution {i+1} of last 3: {exec_data}")
assert exec_data.status == execution.ExecutionStatus.COMPLETED
assert exec_data.output_data == {"result": [9]}
logger.info("Completed test_static_input_link_on_graph")

@@ -12,7 +12,7 @@ from backend.util.test import SpinTestServer
async def test_agent_schedule(server: SpinTestServer):
await db.connect()
test_user = await create_test_user()
test_graph = await server.agent_server.create_graph(
test_graph = await server.agent_server.test_create_graph(
create_graph=CreateGraph(graph=create_test_graph()),
is_template=False,
user_id=test_user.id,

@@ -28,7 +28,7 @@ export default class BaseAutoGPTServerAPI {
constructor(
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
"http://localhost:8006/api",
"http://localhost:8006/api/v1",
wsUrl: string = process.env.NEXT_PUBLIC_AGPT_WS_SERVER_URL ||
"ws://localhost:8001/ws",
supabaseClient: SupabaseClient | null = null,