feat(autogpt_server): Expose rest api via websocket (#7350)
* Add in websocket event types * adding in api endpoints * Updated ws messagespull/7363/head^2
parent
f94e81f48b
commit
3789b00479
|
@ -34,7 +34,7 @@ class ConnectionManager:
|
||||||
graph_id = result.graph_id
|
graph_id = result.graph_id
|
||||||
if graph_id in self.subscriptions:
|
if graph_id in self.subscriptions:
|
||||||
message = WsMessage(
|
message = WsMessage(
|
||||||
method=Methods.UPDATE,
|
method=Methods.EXECUTION_EVENT,
|
||||||
channel=graph_id,
|
channel=graph_id,
|
||||||
data=result.model_dump()
|
data=result.model_dump()
|
||||||
).model_dump_json()
|
).model_dump_json()
|
||||||
|
|
|
@ -6,13 +6,24 @@ import pydantic
|
||||||
class Methods(enum.Enum):
|
class Methods(enum.Enum):
|
||||||
SUBSCRIBE = "subscribe"
|
SUBSCRIBE = "subscribe"
|
||||||
UNSUBSCRIBE = "unsubscribe"
|
UNSUBSCRIBE = "unsubscribe"
|
||||||
UPDATE = "update"
|
EXECUTION_EVENT = "execution_event"
|
||||||
|
GET_BLOCKS = "get_blocks"
|
||||||
|
EXECUTE_BLOCK = "execute_block"
|
||||||
|
GET_GRAPHS = "get_graphs"
|
||||||
|
GET_GRAPH = "get_graph"
|
||||||
|
CREATE_GRAPH = "create_graph"
|
||||||
|
RUN_GRAPH = "run_graph"
|
||||||
|
GET_GRAPH_RUNS = "get_graph_runs"
|
||||||
|
CREATE_SCHEDULED_RUN = "create_scheduled_run"
|
||||||
|
GET_SCHEDULED_RUNS = "get_scheduled_runs"
|
||||||
|
UPDATE_SCHEDULED_RUN = "update_scheduled_run"
|
||||||
|
UPDATE_CONFIG = "update_config"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
class WsMessage(pydantic.BaseModel):
|
class WsMessage(pydantic.BaseModel):
|
||||||
method: Methods
|
method: Methods
|
||||||
data: typing.Dict[str, typing.Any] | None = None
|
data: typing.Dict[str, typing.Any] | list[typing.Any] | None = None
|
||||||
success: bool | None = None
|
success: bool | None = None
|
||||||
channel: str | None = None
|
channel: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
|
@ -3,12 +3,18 @@ import uuid
|
||||||
from typing import Annotated, Any, Dict
|
from typing import Annotated, Any, Dict
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import WebSocket
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import APIRouter, Body, FastAPI, HTTPException
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Body,
|
||||||
|
FastAPI,
|
||||||
|
HTTPException,
|
||||||
|
WebSocket,
|
||||||
|
WebSocketDisconnect,
|
||||||
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from autogpt_server.data import db, execution, block
|
from autogpt_server.data import db, execution, block
|
||||||
|
@ -20,11 +26,12 @@ from autogpt_server.data.graph import (
|
||||||
)
|
)
|
||||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||||
from autogpt_server.server.conn_manager import ConnectionManager
|
from autogpt_server.server.conn_manager import ConnectionManager
|
||||||
from autogpt_server.server.ws_api import websocket_router as ws_impl
|
import autogpt_server.server.ws_api
|
||||||
from autogpt_server.util.data import get_frontend_path
|
from autogpt_server.util.data import get_frontend_path
|
||||||
from autogpt_server.util.service import expose # type: ignore
|
from autogpt_server.util.service import expose # type: ignore
|
||||||
from autogpt_server.util.service import AppService, get_service_client
|
from autogpt_server.util.service import AppService, get_service_client
|
||||||
from autogpt_server.util.settings import Settings
|
from autogpt_server.util.settings import Settings
|
||||||
|
from autogpt_server.server.model import WsMessage, Methods
|
||||||
|
|
||||||
|
|
||||||
class AgentServer(AppService):
|
class AgentServer(AppService):
|
||||||
|
@ -73,7 +80,7 @@ class AgentServer(AppService):
|
||||||
)
|
)
|
||||||
router.add_api_route(
|
router.add_api_route(
|
||||||
path="/blocks/{block_id}/execute",
|
path="/blocks/{block_id}/execute",
|
||||||
endpoint=self.execute_graph_block,
|
endpoint=self.execute_graph_block, # type: ignore
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
)
|
)
|
||||||
router.add_api_route(
|
router.add_api_route(
|
||||||
|
@ -128,7 +135,7 @@ class AgentServer(AppService):
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_exception_handler(500, self.handle_internal_error)
|
app.add_exception_handler(500, self.handle_internal_error) # type: ignore
|
||||||
|
|
||||||
app.mount(
|
app.mount(
|
||||||
path="/frontend",
|
path="/frontend",
|
||||||
|
@ -140,7 +147,7 @@ class AgentServer(AppService):
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
async def websocket_endpoint(websocket: WebSocket): # type: ignore
|
async def websocket_endpoint(websocket: WebSocket): # type: ignore
|
||||||
await ws_impl(websocket, self.manager)
|
await self.websocket_router(websocket)
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
|
@ -153,22 +160,191 @@ class AgentServer(AppService):
|
||||||
return get_service_client(ExecutionScheduler)
|
return get_service_client(ExecutionScheduler)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle_internal_error(cls, request, exc):
|
def handle_internal_error(cls, request, exc): # type: ignore
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
"message": f"{request.url.path} call failure",
|
"message": f"{request.url.path} call failure", # type: ignore
|
||||||
"error": str(exc),
|
"error": str(exc), # type: ignore
|
||||||
},
|
},
|
||||||
status_code=500,
|
status_code=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
async def websocket_router(self, websocket: WebSocket):
|
||||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
await self.manager.connect(websocket)
|
||||||
return [v.to_dict() for v in block.get_blocks().values()]
|
try:
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
message = WsMessage.model_validate_json(data)
|
||||||
|
if message.method == Methods.SUBSCRIBE:
|
||||||
|
await autogpt_server.server.ws_api.handle_subscribe(
|
||||||
|
websocket, self.manager, message
|
||||||
|
)
|
||||||
|
|
||||||
|
elif message.method == Methods.UNSUBSCRIBE:
|
||||||
|
await autogpt_server.server.ws_api.handle_unsubscribe(
|
||||||
|
websocket, self.manager, message
|
||||||
|
)
|
||||||
|
elif message.method == Methods.EXECUTION_EVENT:
|
||||||
|
print("Execution event received")
|
||||||
|
elif message.method == Methods.GET_BLOCKS:
|
||||||
|
data = self.get_graph_blocks()
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.GET_BLOCKS,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
elif message.method == Methods.EXECUTE_BLOCK:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = self.execute_graph_block(
|
||||||
|
message.data["block_id"], message.data["data"]
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.EXECUTE_BLOCK,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
elif message.method == Methods.GET_GRAPHS:
|
||||||
|
data = await self.get_graphs()
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.GET_GRAPHS,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
print("Get graphs request received")
|
||||||
|
elif message.method == Methods.GET_GRAPH:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = await self.get_graph(message.data["graph_id"])
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.GET_GRAPH,
|
||||||
|
success=True,
|
||||||
|
data=data.model_dump(),
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
print("Get graph request received")
|
||||||
|
elif message.method == Methods.CREATE_GRAPH:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
graph = Graph.model_validate(message.data)
|
||||||
|
data = await self.create_new_graph(graph)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.CREATE_GRAPH,
|
||||||
|
success=True,
|
||||||
|
data=data.model_dump(),
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Create graph request received")
|
||||||
|
elif message.method == Methods.RUN_GRAPH:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = await self.execute_graph(
|
||||||
|
message.data["graph_id"], message.data["data"]
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.RUN_GRAPH,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Run graph request received")
|
||||||
|
elif message.method == Methods.GET_GRAPH_RUNS:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = await self.list_graph_runs(message.data["graph_id"])
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.GET_GRAPH_RUNS,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Get graph runs request received")
|
||||||
|
elif message.method == Methods.CREATE_SCHEDULED_RUN:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = await self.create_schedule(
|
||||||
|
message.data["graph_id"],
|
||||||
|
message.data["cron"],
|
||||||
|
message.data["data"],
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.CREATE_SCHEDULED_RUN,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Create scheduled run request received")
|
||||||
|
elif message.method == Methods.GET_SCHEDULED_RUNS:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = self.get_execution_schedules(message.data["graph_id"])
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.GET_SCHEDULED_RUNS,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
print("Get scheduled runs request received")
|
||||||
|
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = self.update_schedule(
|
||||||
|
message.data["schedule_id"], message.data
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.UPDATE_SCHEDULED_RUN,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Update scheduled run request received")
|
||||||
|
elif message.method == Methods.UPDATE_CONFIG:
|
||||||
|
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||||
|
data = self.update_configuration(message.data)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.UPDATE_CONFIG,
|
||||||
|
success=True,
|
||||||
|
data=data,
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Update config request received")
|
||||||
|
elif message.method == Methods.ERROR:
|
||||||
|
print("Error message received")
|
||||||
|
else:
|
||||||
|
print("Message type is not processed by the server")
|
||||||
|
await websocket.send_text(
|
||||||
|
WsMessage(
|
||||||
|
method=Methods.ERROR,
|
||||||
|
success=False,
|
||||||
|
error="Message type is not processed by the server",
|
||||||
|
).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
self.manager.disconnect(websocket)
|
||||||
|
print("Client Disconnected")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute_graph_block(cls, block_id: str, data: dict[str, Any]) -> list:
|
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||||
obj = block.get_block(block_id)
|
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute_graph_block(
|
||||||
|
cls, block_id: str, data: dict[str, Any]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
obj = block.get_block(block_id) # type: ignore
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
return [{name: data} for name, data in obj.execute(data)]
|
return [{name: data} for name, data in obj.execute(data)]
|
||||||
|
@ -252,9 +428,7 @@ class AgentServer(AppService):
|
||||||
@expose
|
@expose
|
||||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||||
execution_result = execution.ExecutionResult(**execution_result_dict)
|
execution_result = execution.ExecutionResult(**execution_result_dict)
|
||||||
self.run_and_wait(
|
self.run_and_wait(self.event_queue.put(execution_result))
|
||||||
self.event_queue.put(execution_result)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_configuration(
|
def update_configuration(
|
||||||
|
@ -275,12 +449,9 @@ class AgentServer(AppService):
|
||||||
setattr(settings.secrets, key, value) # type: ignore
|
setattr(settings.secrets, key, value) # type: ignore
|
||||||
updated_fields["secrets"].append(key)
|
updated_fields["secrets"].append(key)
|
||||||
settings.save()
|
settings.save()
|
||||||
return JSONResponse(
|
return {
|
||||||
content={
|
"message": "Settings updated successfully",
|
||||||
"message": "Settings updated successfully",
|
"updated_fields": updated_fields,
|
||||||
"updated_fields": updated_fields,
|
}
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
|
@ -84,7 +84,7 @@ async def test_send_execution_result(
|
||||||
|
|
||||||
mock_websocket.send_text.assert_called_once_with(
|
mock_websocket.send_text.assert_called_once_with(
|
||||||
WsMessage(
|
WsMessage(
|
||||||
method=Methods.UPDATE,
|
method=Methods.EXECUTION_EVENT,
|
||||||
channel="test_graph",
|
channel="test_graph",
|
||||||
data=result.model_dump(),
|
data=result.model_dump(),
|
||||||
).model_dump_json()
|
).model_dump_json()
|
||||||
|
|
|
@ -75,7 +75,8 @@ async def test_websocket_router_invalid_method(
|
||||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||||
) -> None:
|
) -> None:
|
||||||
mock_websocket.receive_text.side_effect = [
|
mock_websocket.receive_text.side_effect = [
|
||||||
WsMessage(method=Methods.UPDATE).model_dump_json(),
|
|
||||||
|
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
|
||||||
WebSocketDisconnect(),
|
WebSocketDisconnect(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue