feat(autogpt_server): Expose rest api via websocket (#7350)

* Add in websocket event types

* adding in api endpoints

* Updated ws messages
pull/7363/head^2
Swifty 2024-07-10 11:54:18 +02:00 committed by GitHub
parent f94e81f48b
commit 3789b00479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 212 additions and 29 deletions

View File

@ -34,7 +34,7 @@ class ConnectionManager:
graph_id = result.graph_id
if graph_id in self.subscriptions:
message = WsMessage(
method=Methods.UPDATE,
method=Methods.EXECUTION_EVENT,
channel=graph_id,
data=result.model_dump()
).model_dump_json()

View File

@ -6,13 +6,24 @@ import pydantic
class Methods(enum.Enum):
SUBSCRIBE = "subscribe"
UNSUBSCRIBE = "unsubscribe"
UPDATE = "update"
EXECUTION_EVENT = "execution_event"
GET_BLOCKS = "get_blocks"
EXECUTE_BLOCK = "execute_block"
GET_GRAPHS = "get_graphs"
GET_GRAPH = "get_graph"
CREATE_GRAPH = "create_graph"
RUN_GRAPH = "run_graph"
GET_GRAPH_RUNS = "get_graph_runs"
CREATE_SCHEDULED_RUN = "create_scheduled_run"
GET_SCHEDULED_RUNS = "get_scheduled_runs"
UPDATE_SCHEDULED_RUN = "update_scheduled_run"
UPDATE_CONFIG = "update_config"
ERROR = "error"
class WsMessage(pydantic.BaseModel):
method: Methods
data: typing.Dict[str, typing.Any] | None = None
data: typing.Dict[str, typing.Any] | list[typing.Any] | None = None
success: bool | None = None
channel: str | None = None
error: str | None = None

View File

@ -3,12 +3,18 @@ import uuid
from typing import Annotated, Any, Dict
import uvicorn
from fastapi import WebSocket
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from fastapi import APIRouter, Body, FastAPI, HTTPException
from fastapi import (
APIRouter,
Body,
FastAPI,
HTTPException,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from autogpt_server.data import db, execution, block
@ -20,11 +26,12 @@ from autogpt_server.data.graph import (
)
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.ws_api import websocket_router as ws_impl
import autogpt_server.server.ws_api
from autogpt_server.util.data import get_frontend_path
from autogpt_server.util.service import expose # type: ignore
from autogpt_server.util.service import AppService, get_service_client
from autogpt_server.util.settings import Settings
from autogpt_server.server.model import WsMessage, Methods
class AgentServer(AppService):
@ -73,7 +80,7 @@ class AgentServer(AppService):
)
router.add_api_route(
path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block,
endpoint=self.execute_graph_block, # type: ignore
methods=["POST"],
)
router.add_api_route(
@ -128,7 +135,7 @@ class AgentServer(AppService):
methods=["POST"],
)
app.add_exception_handler(500, self.handle_internal_error)
app.add_exception_handler(500, self.handle_internal_error) # type: ignore
app.mount(
path="/frontend",
@ -140,7 +147,7 @@ class AgentServer(AppService):
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): # type: ignore
await ws_impl(websocket, self.manager)
await self.websocket_router(websocket)
uvicorn.run(app, host="0.0.0.0", port=8000)
@ -153,22 +160,191 @@ class AgentServer(AppService):
return get_service_client(ExecutionScheduler)
@classmethod
def handle_internal_error(cls, request, exc):
def handle_internal_error(cls, request, exc): # type: ignore
return JSONResponse(
content={
"message": f"{request.url.path} call failure",
"error": str(exc),
"message": f"{request.url.path} call failure", # type: ignore
"error": str(exc), # type: ignore
},
status_code=500,
)
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()]
async def websocket_router(self, websocket: WebSocket):
await self.manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
message = WsMessage.model_validate_json(data)
if message.method == Methods.SUBSCRIBE:
await autogpt_server.server.ws_api.handle_subscribe(
websocket, self.manager, message
)
elif message.method == Methods.UNSUBSCRIBE:
await autogpt_server.server.ws_api.handle_unsubscribe(
websocket, self.manager, message
)
elif message.method == Methods.EXECUTION_EVENT:
print("Execution event received")
elif message.method == Methods.GET_BLOCKS:
data = self.get_graph_blocks()
await websocket.send_text(
WsMessage(
method=Methods.GET_BLOCKS,
success=True,
data=data,
).model_dump_json()
)
elif message.method == Methods.EXECUTE_BLOCK:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.execute_graph_block(
message.data["block_id"], message.data["data"]
)
await websocket.send_text(
WsMessage(
method=Methods.EXECUTE_BLOCK,
success=True,
data=data,
).model_dump_json()
)
elif message.method == Methods.GET_GRAPHS:
data = await self.get_graphs()
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPHS,
success=True,
data=data,
).model_dump_json()
)
print("Get graphs request received")
elif message.method == Methods.GET_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.get_graph(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH,
success=True,
data=data.model_dump(),
).model_dump_json()
)
print("Get graph request received")
elif message.method == Methods.CREATE_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
graph = Graph.model_validate(message.data)
data = await self.create_new_graph(graph)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_GRAPH,
success=True,
data=data.model_dump(),
).model_dump_json()
)
print("Create graph request received")
elif message.method == Methods.RUN_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.execute_graph(
message.data["graph_id"], message.data["data"]
)
await websocket.send_text(
WsMessage(
method=Methods.RUN_GRAPH,
success=True,
data=data,
).model_dump_json()
)
print("Run graph request received")
elif message.method == Methods.GET_GRAPH_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.list_graph_runs(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH_RUNS,
success=True,
data=data,
).model_dump_json()
)
print("Get graph runs request received")
elif message.method == Methods.CREATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.create_schedule(
message.data["graph_id"],
message.data["cron"],
message.data["data"],
)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_SCHEDULED_RUN,
success=True,
data=data,
).model_dump_json()
)
print("Create scheduled run request received")
elif message.method == Methods.GET_SCHEDULED_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.get_execution_schedules(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_SCHEDULED_RUNS,
success=True,
data=data,
).model_dump_json()
)
print("Get scheduled runs request received")
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_schedule(
message.data["schedule_id"], message.data
)
await websocket.send_text(
WsMessage(
method=Methods.UPDATE_SCHEDULED_RUN,
success=True,
data=data,
).model_dump_json()
)
print("Update scheduled run request received")
elif message.method == Methods.UPDATE_CONFIG:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_configuration(message.data)
await websocket.send_text(
WsMessage(
method=Methods.UPDATE_CONFIG,
success=True,
data=data,
).model_dump_json()
)
print("Update config request received")
elif message.method == Methods.ERROR:
print("Error message received")
else:
print("Message type is not processed by the server")
await websocket.send_text(
WsMessage(
method=Methods.ERROR,
success=False,
error="Message type is not processed by the server",
).model_dump_json()
)
except WebSocketDisconnect:
self.manager.disconnect(websocket)
print("Client Disconnected")
@classmethod
def execute_graph_block(cls, block_id: str, data: dict[str, Any]) -> list:
obj = block.get_block(block_id)
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
@classmethod
def execute_graph_block(
cls, block_id: str, data: dict[str, Any]
) -> list[dict[str, Any]]:
obj = block.get_block(block_id) # type: ignore
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
return [{name: data} for name, data in obj.execute(data)]
@ -252,9 +428,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = execution.ExecutionResult(**execution_result_dict)
self.run_and_wait(
self.event_queue.put(execution_result)
)
self.run_and_wait(self.event_queue.put(execution_result))
@classmethod
def update_configuration(
@ -275,12 +449,9 @@ class AgentServer(AppService):
setattr(settings.secrets, key, value) # type: ignore
updated_fields["secrets"].append(key)
settings.save()
return JSONResponse(
content={
"message": "Settings updated successfully",
"updated_fields": updated_fields,
},
status_code=200,
)
return {
"message": "Settings updated successfully",
"updated_fields": updated_fields,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@ -84,7 +84,7 @@ async def test_send_execution_result(
mock_websocket.send_text.assert_called_once_with(
WsMessage(
method=Methods.UPDATE,
method=Methods.EXECUTION_EVENT,
channel="test_graph",
data=result.model_dump(),
).model_dump_json()

View File

@ -75,7 +75,8 @@ async def test_websocket_router_invalid_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
mock_websocket.receive_text.side_effect = [
WsMessage(method=Methods.UPDATE).model_dump_json(),
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
WebSocketDisconnect(),
]