From 277f3e4e4d05f44a7fa004bde104a3e20275c39b Mon Sep 17 00:00:00 2001
From: merwanehamadi
Date: Wed, 16 Aug 2023 09:00:05 -0700
Subject: [PATCH] Add endpoints to power dev tool (#310)

---
 agbenchmark/app.py             | 83 ++++++++++++++++++++++++++++++++++
 agbenchmark/start_benchmark.py | 17 ++++++-
 2 files changed, 99 insertions(+), 1 deletion(-)
 create mode 100644 agbenchmark/app.py

diff --git a/agbenchmark/app.py b/agbenchmark/app.py
new file mode 100644
index 000000000..e4892b867
--- /dev/null
+++ b/agbenchmark/app.py
@@ -0,0 +1,83 @@
+from pathlib import Path
+
+from fastapi import FastAPI
+from fastapi import (
+    HTTPException as FastAPIHTTPException,  # Import HTTPException from FastAPI
+)
+from fastapi.responses import FileResponse
+
+app = FastAPI()
+
+
+@app.get("/skill_tree")
+def get_skill_tree() -> dict:
+    return {
+        "graph": {
+            "nodes": {
+                "TestWriteFile": {
+                    "name": "TestWriteFile",
+                    "input": "Write the word 'Washington' to a .txt file",
+                    "task_id": "fde559f8-3ab8-11ee-be56-0242ac120002",
+                    "category": ["interface"],
+                    "dependencies": [],
+                    "cutoff": 60,
+                    "ground": {
+                        "answer": "The word 'Washington', printed to a .txt file named anything",
+                        "should_contain": ["Washington"],
+                        "should_not_contain": [],
+                        "files": [".txt"],
+                        "eval": {"type": "file"},
+                    },
+                    "info": {
+                        "difficulty": "interface",
+                        "description": "Tests the agents ability to write to a file",
+                        "side_effects": [""],
+                    },
+                },
+                "TestReadFile": {
+                    "name": "TestReadFile",
+                    "category": ["interface"],
+                    "task_id": "fde559f8-3ab8-11ee-be56-0242ac120002",
+                    "input": "Read the file called file_to_read.txt and write its content to a file called output.txt",
+                    "dependencies": ["TestWriteFile"],
+                    "cutoff": 60,
+                    "ground": {
+                        "answer": "The content of output.txt should be 'Hello World!'",
+                        "should_contain": ["Hello World!"],
+                        "files": ["output.txt"],
+                        "eval": {"type": "file"},
+                    },
+                    "info": {
+                        "description": "Tests the ability for an agent to read a file.",
+                        "difficulty": "interface",
+                        "side_effects": [""],
+                    },
+                    "artifacts": [
+                        {
+                            "artifact_id": "a1b259f8-3ab8-11ee-be56-0242ac121234",
+                            "file_name": "file_to_read.txt",
+                            "file_path": "interface/write_file/artifacts_out",
+                        }
+                    ],
+                },
+            },
+            "edges": [{"source": "TestWriteFile", "target": "TestReadFile"}],
+        }
+    }
+
+
+@app.get("/agent/tasks/{challenge_id}/artifacts/{artifact_id}")
+def get_artifact(
+    challenge_id: str, artifact_id: str
+) -> FileResponse:  # Added return type annotation
+    try:
+        # Look up the file path using the challenge ID and artifact ID
+
+        file_path = "challenges/interface/read_file/artifacts_in/file_to_read.txt"
+        current_directory = Path(__file__).resolve().parent
+
+        # Return the file as a response
+        return FileResponse(current_directory / file_path)
+
+    except KeyError:
+        raise FastAPIHTTPException(status_code=404, detail="Artifact not found")
diff --git a/agbenchmark/start_benchmark.py b/agbenchmark/start_benchmark.py
index 1c5ea42fa..b9526e95f 100644
--- a/agbenchmark/start_benchmark.py
+++ b/agbenchmark/start_benchmark.py
@@ -1,6 +1,7 @@
 import glob
 import json
 import os
+import subprocess
 import sys
 from datetime import datetime
 from pathlib import Path
@@ -98,6 +99,7 @@ def cli() -> None:
 )
 @click.option("--nc", is_flag=True, help="Run without cutoff")
 @click.option("--cutoff", default=None, help="Set or override tests cutoff (seconds)")
+@click.option("--server", is_flag=True, help="Starts the server")
 def start(
     category: str,
     skip_category: list[str],
@@ -110,6 +112,7 @@ def start(
     no_dep: bool,
     nc: bool,
     cutoff: Optional[int] = None,
+    server: bool = False,
 ) -> int:
     """Start the benchmark tests. If a category flag is provided, run the categories with that mark."""
     # Check if configuration file exists and is not empty
@@ -228,7 +231,19 @@ def start(
         # when used as a library, the pytest directory to execute is in the CURRENT_DIRECTORY
         pytest_args.append(str(CURRENT_DIRECTORY))
-
+    if server:
+        subprocess.run(
+            [
+                "uvicorn",
+                "agbenchmark.app:app",
+                "--reload",
+                "--host",
+                "0.0.0.0",
+                "--port",
+                "8000",
+            ]
+        )
+        return 0
     return sys.exit(pytest.main(pytest_args))
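
With this patch applied, the new endpoints can be exercised by starting the benchmark with the new --server flag, which launches uvicorn serving agbenchmark.app:app on 0.0.0.0:8000. The snippet below is a minimal client sketch and not part of the patch itself; it assumes the server is reachable at http://localhost:8000 and uses the third-party `requests` library purely for illustration.

    import requests  # illustrative dependency, not required by the patch

    # Assumes the server was started locally, e.g. via `agbenchmark start --server`.
    BASE_URL = "http://localhost:8000"

    # Fetch the static skill tree graph exposed by GET /skill_tree.
    skill_tree = requests.get(f"{BASE_URL}/skill_tree").json()
    print(list(skill_tree["graph"]["nodes"]))  # ['TestWriteFile', 'TestReadFile']

    # Download a challenge artifact via GET /agent/tasks/{challenge_id}/artifacts/{artifact_id}.
    # Note: in this patch the endpoint ignores both IDs and always serves
    # challenges/interface/read_file/artifacts_in/file_to_read.txt.
    artifact = requests.get(
        f"{BASE_URL}/agent/tasks/fde559f8-3ab8-11ee-be56-0242ac120002"
        "/artifacts/a1b259f8-3ab8-11ee-be56-0242ac121234"
    )
    print(artifact.text)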