feat(autogpt_server): Expose rest api via websocket (#7350)

* Add in websocket event types

* adding in api endpoints

* Updated ws messages
pull/7363/head^2
Swifty 2024-07-10 11:54:18 +02:00 committed by GitHub
parent f94e81f48b
commit 3789b00479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 212 additions and 29 deletions

View File

@ -34,7 +34,7 @@ class ConnectionManager:
graph_id = result.graph_id graph_id = result.graph_id
if graph_id in self.subscriptions: if graph_id in self.subscriptions:
message = WsMessage( message = WsMessage(
method=Methods.UPDATE, method=Methods.EXECUTION_EVENT,
channel=graph_id, channel=graph_id,
data=result.model_dump() data=result.model_dump()
).model_dump_json() ).model_dump_json()

View File

@ -6,13 +6,24 @@ import pydantic
class Methods(enum.Enum): class Methods(enum.Enum):
SUBSCRIBE = "subscribe" SUBSCRIBE = "subscribe"
UNSUBSCRIBE = "unsubscribe" UNSUBSCRIBE = "unsubscribe"
UPDATE = "update" EXECUTION_EVENT = "execution_event"
GET_BLOCKS = "get_blocks"
EXECUTE_BLOCK = "execute_block"
GET_GRAPHS = "get_graphs"
GET_GRAPH = "get_graph"
CREATE_GRAPH = "create_graph"
RUN_GRAPH = "run_graph"
GET_GRAPH_RUNS = "get_graph_runs"
CREATE_SCHEDULED_RUN = "create_scheduled_run"
GET_SCHEDULED_RUNS = "get_scheduled_runs"
UPDATE_SCHEDULED_RUN = "update_scheduled_run"
UPDATE_CONFIG = "update_config"
ERROR = "error" ERROR = "error"
class WsMessage(pydantic.BaseModel): class WsMessage(pydantic.BaseModel):
method: Methods method: Methods
data: typing.Dict[str, typing.Any] | None = None data: typing.Dict[str, typing.Any] | list[typing.Any] | None = None
success: bool | None = None success: bool | None = None
channel: str | None = None channel: str | None = None
error: str | None = None error: str | None = None

View File

@ -3,12 +3,18 @@ import uuid
from typing import Annotated, Any, Dict from typing import Annotated, Any, Dict
import uvicorn import uvicorn
from fastapi import WebSocket
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import APIRouter, Body, FastAPI, HTTPException from fastapi import (
APIRouter,
Body,
FastAPI,
HTTPException,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from autogpt_server.data import db, execution, block from autogpt_server.data import db, execution, block
@ -20,11 +26,12 @@ from autogpt_server.data.graph import (
) )
from autogpt_server.executor import ExecutionManager, ExecutionScheduler from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.ws_api import websocket_router as ws_impl import autogpt_server.server.ws_api
from autogpt_server.util.data import get_frontend_path from autogpt_server.util.data import get_frontend_path
from autogpt_server.util.service import expose # type: ignore from autogpt_server.util.service import expose # type: ignore
from autogpt_server.util.service import AppService, get_service_client from autogpt_server.util.service import AppService, get_service_client
from autogpt_server.util.settings import Settings from autogpt_server.util.settings import Settings
from autogpt_server.server.model import WsMessage, Methods
class AgentServer(AppService): class AgentServer(AppService):
@ -73,7 +80,7 @@ class AgentServer(AppService):
) )
router.add_api_route( router.add_api_route(
path="/blocks/{block_id}/execute", path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block, endpoint=self.execute_graph_block, # type: ignore
methods=["POST"], methods=["POST"],
) )
router.add_api_route( router.add_api_route(
@ -128,7 +135,7 @@ class AgentServer(AppService):
methods=["POST"], methods=["POST"],
) )
app.add_exception_handler(500, self.handle_internal_error) app.add_exception_handler(500, self.handle_internal_error) # type: ignore
app.mount( app.mount(
path="/frontend", path="/frontend",
@ -140,7 +147,7 @@ class AgentServer(AppService):
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): # type: ignore async def websocket_endpoint(websocket: WebSocket): # type: ignore
await ws_impl(websocket, self.manager) await self.websocket_router(websocket)
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)
@ -153,22 +160,191 @@ class AgentServer(AppService):
return get_service_client(ExecutionScheduler) return get_service_client(ExecutionScheduler)
@classmethod @classmethod
def handle_internal_error(cls, request, exc): def handle_internal_error(cls, request, exc): # type: ignore
return JSONResponse( return JSONResponse(
content={ content={
"message": f"{request.url.path} call failure", "message": f"{request.url.path} call failure", # type: ignore
"error": str(exc), "error": str(exc), # type: ignore
}, },
status_code=500, status_code=500,
) )
@classmethod async def websocket_router(self, websocket: WebSocket):
def get_graph_blocks(cls) -> list[dict[Any, Any]]: await self.manager.connect(websocket)
return [v.to_dict() for v in block.get_blocks().values()] try:
while True:
data = await websocket.receive_text()
message = WsMessage.model_validate_json(data)
if message.method == Methods.SUBSCRIBE:
await autogpt_server.server.ws_api.handle_subscribe(
websocket, self.manager, message
)
elif message.method == Methods.UNSUBSCRIBE:
await autogpt_server.server.ws_api.handle_unsubscribe(
websocket, self.manager, message
)
elif message.method == Methods.EXECUTION_EVENT:
print("Execution event received")
elif message.method == Methods.GET_BLOCKS:
data = self.get_graph_blocks()
await websocket.send_text(
WsMessage(
method=Methods.GET_BLOCKS,
success=True,
data=data,
).model_dump_json()
)
elif message.method == Methods.EXECUTE_BLOCK:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.execute_graph_block(
message.data["block_id"], message.data["data"]
)
await websocket.send_text(
WsMessage(
method=Methods.EXECUTE_BLOCK,
success=True,
data=data,
).model_dump_json()
)
elif message.method == Methods.GET_GRAPHS:
data = await self.get_graphs()
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPHS,
success=True,
data=data,
).model_dump_json()
)
print("Get graphs request received")
elif message.method == Methods.GET_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.get_graph(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH,
success=True,
data=data.model_dump(),
).model_dump_json()
)
print("Get graph request received")
elif message.method == Methods.CREATE_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
graph = Graph.model_validate(message.data)
data = await self.create_new_graph(graph)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_GRAPH,
success=True,
data=data.model_dump(),
).model_dump_json()
)
print("Create graph request received")
elif message.method == Methods.RUN_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.execute_graph(
message.data["graph_id"], message.data["data"]
)
await websocket.send_text(
WsMessage(
method=Methods.RUN_GRAPH,
success=True,
data=data,
).model_dump_json()
)
print("Run graph request received")
elif message.method == Methods.GET_GRAPH_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.list_graph_runs(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH_RUNS,
success=True,
data=data,
).model_dump_json()
)
print("Get graph runs request received")
elif message.method == Methods.CREATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.create_schedule(
message.data["graph_id"],
message.data["cron"],
message.data["data"],
)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_SCHEDULED_RUN,
success=True,
data=data,
).model_dump_json()
)
print("Create scheduled run request received")
elif message.method == Methods.GET_SCHEDULED_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.get_execution_schedules(message.data["graph_id"])
await websocket.send_text(
WsMessage(
method=Methods.GET_SCHEDULED_RUNS,
success=True,
data=data,
).model_dump_json()
)
print("Get scheduled runs request received")
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_schedule(
message.data["schedule_id"], message.data
)
await websocket.send_text(
WsMessage(
method=Methods.UPDATE_SCHEDULED_RUN,
success=True,
data=data,
).model_dump_json()
)
print("Update scheduled run request received")
elif message.method == Methods.UPDATE_CONFIG:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_configuration(message.data)
await websocket.send_text(
WsMessage(
method=Methods.UPDATE_CONFIG,
success=True,
data=data,
).model_dump_json()
)
print("Update config request received")
elif message.method == Methods.ERROR:
print("Error message received")
else:
print("Message type is not processed by the server")
await websocket.send_text(
WsMessage(
method=Methods.ERROR,
success=False,
error="Message type is not processed by the server",
).model_dump_json()
)
except WebSocketDisconnect:
self.manager.disconnect(websocket)
print("Client Disconnected")
@classmethod @classmethod
def execute_graph_block(cls, block_id: str, data: dict[str, Any]) -> list: def get_graph_blocks(cls) -> list[dict[Any, Any]]:
obj = block.get_block(block_id) return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
@classmethod
def execute_graph_block(
cls, block_id: str, data: dict[str, Any]
) -> list[dict[str, Any]]:
obj = block.get_block(block_id) # type: ignore
if not obj: if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
return [{name: data} for name, data in obj.execute(data)] return [{name: data} for name, data in obj.execute(data)]
@ -252,9 +428,7 @@ class AgentServer(AppService):
@expose @expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]): def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = execution.ExecutionResult(**execution_result_dict) execution_result = execution.ExecutionResult(**execution_result_dict)
self.run_and_wait( self.run_and_wait(self.event_queue.put(execution_result))
self.event_queue.put(execution_result)
)
@classmethod @classmethod
def update_configuration( def update_configuration(
@ -275,12 +449,9 @@ class AgentServer(AppService):
setattr(settings.secrets, key, value) # type: ignore setattr(settings.secrets, key, value) # type: ignore
updated_fields["secrets"].append(key) updated_fields["secrets"].append(key)
settings.save() settings.save()
return JSONResponse( return {
content={ "message": "Settings updated successfully",
"message": "Settings updated successfully", "updated_fields": updated_fields,
"updated_fields": updated_fields, }
},
status_code=200,
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@ -84,7 +84,7 @@ async def test_send_execution_result(
mock_websocket.send_text.assert_called_once_with( mock_websocket.send_text.assert_called_once_with(
WsMessage( WsMessage(
method=Methods.UPDATE, method=Methods.EXECUTION_EVENT,
channel="test_graph", channel="test_graph",
data=result.model_dump(), data=result.model_dump(),
).model_dump_json() ).model_dump_json()

View File

@ -75,7 +75,8 @@ async def test_websocket_router_invalid_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None: ) -> None:
mock_websocket.receive_text.side_effect = [ mock_websocket.receive_text.side_effect = [
WsMessage(method=Methods.UPDATE).model_dump_json(),
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
WebSocketDisconnect(), WebSocketDisconnect(),
] ]