feat(autogpt_server): Expose rest api via websocket (#7350)
* Add in websocket event types * adding in api endpoints * Updated ws messagespull/7363/head^2
parent
f94e81f48b
commit
3789b00479
|
@ -34,7 +34,7 @@ class ConnectionManager:
|
|||
graph_id = result.graph_id
|
||||
if graph_id in self.subscriptions:
|
||||
message = WsMessage(
|
||||
method=Methods.UPDATE,
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
channel=graph_id,
|
||||
data=result.model_dump()
|
||||
).model_dump_json()
|
||||
|
|
|
@ -6,13 +6,24 @@ import pydantic
|
|||
class Methods(enum.Enum):
|
||||
SUBSCRIBE = "subscribe"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
UPDATE = "update"
|
||||
EXECUTION_EVENT = "execution_event"
|
||||
GET_BLOCKS = "get_blocks"
|
||||
EXECUTE_BLOCK = "execute_block"
|
||||
GET_GRAPHS = "get_graphs"
|
||||
GET_GRAPH = "get_graph"
|
||||
CREATE_GRAPH = "create_graph"
|
||||
RUN_GRAPH = "run_graph"
|
||||
GET_GRAPH_RUNS = "get_graph_runs"
|
||||
CREATE_SCHEDULED_RUN = "create_scheduled_run"
|
||||
GET_SCHEDULED_RUNS = "get_scheduled_runs"
|
||||
UPDATE_SCHEDULED_RUN = "update_scheduled_run"
|
||||
UPDATE_CONFIG = "update_config"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class WsMessage(pydantic.BaseModel):
|
||||
method: Methods
|
||||
data: typing.Dict[str, typing.Any] | None = None
|
||||
data: typing.Dict[str, typing.Any] | list[typing.Any] | None = None
|
||||
success: bool | None = None
|
||||
channel: str | None = None
|
||||
error: str | None = None
|
||||
|
|
|
@ -3,12 +3,18 @@ import uuid
|
|||
from typing import Annotated, Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import APIRouter, Body, FastAPI, HTTPException
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from autogpt_server.data import db, execution, block
|
||||
|
@ -20,11 +26,12 @@ from autogpt_server.data.graph import (
|
|||
)
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
from autogpt_server.server.ws_api import websocket_router as ws_impl
|
||||
import autogpt_server.server.ws_api
|
||||
from autogpt_server.util.data import get_frontend_path
|
||||
from autogpt_server.util.service import expose # type: ignore
|
||||
from autogpt_server.util.service import AppService, get_service_client
|
||||
from autogpt_server.util.settings import Settings
|
||||
from autogpt_server.server.model import WsMessage, Methods
|
||||
|
||||
|
||||
class AgentServer(AppService):
|
||||
|
@ -73,7 +80,7 @@ class AgentServer(AppService):
|
|||
)
|
||||
router.add_api_route(
|
||||
path="/blocks/{block_id}/execute",
|
||||
endpoint=self.execute_graph_block,
|
||||
endpoint=self.execute_graph_block, # type: ignore
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
|
@ -128,7 +135,7 @@ class AgentServer(AppService):
|
|||
methods=["POST"],
|
||||
)
|
||||
|
||||
app.add_exception_handler(500, self.handle_internal_error)
|
||||
app.add_exception_handler(500, self.handle_internal_error) # type: ignore
|
||||
|
||||
app.mount(
|
||||
path="/frontend",
|
||||
|
@ -140,7 +147,7 @@ class AgentServer(AppService):
|
|||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket): # type: ignore
|
||||
await ws_impl(websocket, self.manager)
|
||||
await self.websocket_router(websocket)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
|
@ -153,22 +160,191 @@ class AgentServer(AppService):
|
|||
return get_service_client(ExecutionScheduler)
|
||||
|
||||
@classmethod
|
||||
def handle_internal_error(cls, request, exc):
|
||||
def handle_internal_error(cls, request, exc): # type: ignore
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": f"{request.url.path} call failure",
|
||||
"error": str(exc),
|
||||
"message": f"{request.url.path} call failure", # type: ignore
|
||||
"error": str(exc), # type: ignore
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||
return [v.to_dict() for v in block.get_blocks().values()]
|
||||
async def websocket_router(self, websocket: WebSocket):
|
||||
await self.manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = WsMessage.model_validate_json(data)
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
await autogpt_server.server.ws_api.handle_subscribe(
|
||||
websocket, self.manager, message
|
||||
)
|
||||
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
await autogpt_server.server.ws_api.handle_unsubscribe(
|
||||
websocket, self.manager, message
|
||||
)
|
||||
elif message.method == Methods.EXECUTION_EVENT:
|
||||
print("Execution event received")
|
||||
elif message.method == Methods.GET_BLOCKS:
|
||||
data = self.get_graph_blocks()
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_BLOCKS,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
elif message.method == Methods.EXECUTE_BLOCK:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.execute_graph_block(
|
||||
message.data["block_id"], message.data["data"]
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.EXECUTE_BLOCK,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
elif message.method == Methods.GET_GRAPHS:
|
||||
data = await self.get_graphs()
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPHS,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
print("Get graphs request received")
|
||||
elif message.method == Methods.GET_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.get_graph(message.data["graph_id"])
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH,
|
||||
success=True,
|
||||
data=data.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
print("Get graph request received")
|
||||
elif message.method == Methods.CREATE_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
graph = Graph.model_validate(message.data)
|
||||
data = await self.create_new_graph(graph)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.CREATE_GRAPH,
|
||||
success=True,
|
||||
data=data.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Create graph request received")
|
||||
elif message.method == Methods.RUN_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.execute_graph(
|
||||
message.data["graph_id"], message.data["data"]
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.RUN_GRAPH,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Run graph request received")
|
||||
elif message.method == Methods.GET_GRAPH_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.list_graph_runs(message.data["graph_id"])
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH_RUNS,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Get graph runs request received")
|
||||
elif message.method == Methods.CREATE_SCHEDULED_RUN:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.create_schedule(
|
||||
message.data["graph_id"],
|
||||
message.data["cron"],
|
||||
message.data["data"],
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.CREATE_SCHEDULED_RUN,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Create scheduled run request received")
|
||||
elif message.method == Methods.GET_SCHEDULED_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.get_execution_schedules(message.data["graph_id"])
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_SCHEDULED_RUNS,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
print("Get scheduled runs request received")
|
||||
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.update_schedule(
|
||||
message.data["schedule_id"], message.data
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UPDATE_SCHEDULED_RUN,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Update scheduled run request received")
|
||||
elif message.method == Methods.UPDATE_CONFIG:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.update_configuration(message.data)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UPDATE_CONFIG,
|
||||
success=True,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
print("Update config request received")
|
||||
elif message.method == Methods.ERROR:
|
||||
print("Error message received")
|
||||
else:
|
||||
print("Message type is not processed by the server")
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.ERROR,
|
||||
success=False,
|
||||
error="Message type is not processed by the server",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
self.manager.disconnect(websocket)
|
||||
print("Client Disconnected")
|
||||
|
||||
@classmethod
|
||||
def execute_graph_block(cls, block_id: str, data: dict[str, Any]) -> list:
|
||||
obj = block.get_block(block_id)
|
||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
|
||||
|
||||
@classmethod
|
||||
def execute_graph_block(
|
||||
cls, block_id: str, data: dict[str, Any]
|
||||
) -> list[dict[str, Any]]:
|
||||
obj = block.get_block(block_id) # type: ignore
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
return [{name: data} for name, data in obj.execute(data)]
|
||||
|
@ -252,9 +428,7 @@ class AgentServer(AppService):
|
|||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
execution_result = execution.ExecutionResult(**execution_result_dict)
|
||||
self.run_and_wait(
|
||||
self.event_queue.put(execution_result)
|
||||
)
|
||||
self.run_and_wait(self.event_queue.put(execution_result))
|
||||
|
||||
@classmethod
|
||||
def update_configuration(
|
||||
|
@ -275,12 +449,9 @@ class AgentServer(AppService):
|
|||
setattr(settings.secrets, key, value) # type: ignore
|
||||
updated_fields["secrets"].append(key)
|
||||
settings.save()
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": "Settings updated successfully",
|
||||
"updated_fields": updated_fields,
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
return {
|
||||
"message": "Settings updated successfully",
|
||||
"updated_fields": updated_fields,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
|
|
@ -84,7 +84,7 @@ async def test_send_execution_result(
|
|||
|
||||
mock_websocket.send_text.assert_called_once_with(
|
||||
WsMessage(
|
||||
method=Methods.UPDATE,
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
channel="test_graph",
|
||||
data=result.model_dump(),
|
||||
).model_dump_json()
|
||||
|
|
|
@ -75,7 +75,8 @@ async def test_websocket_router_invalid_method(
|
|||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WsMessage(method=Methods.UPDATE).model_dump_json(),
|
||||
|
||||
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue