diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index b4f8f8aeb..ad7cc0a93 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -10,7 +10,7 @@ from prisma.types import AgentGraphInclude from pydantic import BaseModel from pydantic_core import PydanticUndefinedType -from backend.blocks.basic import AgentInputBlock, AgentOutputBlock +from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, BlockType from backend.data.block import BlockInput, get_block, get_blocks from backend.data.db import BaseDbModel, transaction from backend.data.execution import ExecutionStatus @@ -209,16 +209,15 @@ class Graph(GraphMeta): if block is None: raise ValueError(f"Invalid block {node.block_id} for node #{node.id}") - if not for_run: - continue # Skip input completion validation, unless when executing. - provided_inputs = set( [sanitize(name) for name in node.input_default] + [sanitize(link.sink_name) for link in node.input_links] ) for name in block.input_schema.get_required_fields(): - if name not in provided_inputs and not isinstance( - block, AgentInputBlock + if name not in provided_inputs and ( + for_run # Skip input completion validation, unless when executing. + or block.block_type == BlockType.INPUT + or block.block_type == BlockType.OUTPUT ): raise ValueError( f"Node {block.name} #{node.id} required input missing: `{name}`" diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 07d55a058..27c6679cc 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -70,19 +70,17 @@ async def health(): return {"status": "healthy"} -app = starlette.middleware.cors.CORSMiddleware( - app=app, - allow_origins=settings.config.backend_cors_allow_origins, - allow_credentials=True, - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers -) - - class AgentServer(backend.util.service.AppProcess): def run(self): + server_app = starlette.middleware.cors.CORSMiddleware( + app=app, + allow_origins=settings.config.backend_cors_allow_origins, + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers + ) uvicorn.run( - app, + server_app, host=backend.util.settings.Config().agent_api_host, port=backend.util.settings.Config().agent_api_port, ) diff --git a/autogpt_platform/backend/backend/server/ws_api.py b/autogpt_platform/backend/backend/server/ws_api.py index 2800182c2..d2fdf6d73 100644 --- a/autogpt_platform/backend/backend/server/ws_api.py +++ b/autogpt_platform/backend/backend/server/ws_api.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager import uvicorn from autogpt_libs.auth import parse_jwt_token from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect -from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.cors import CORSMiddleware from backend.data import redis from backend.data.queue import AsyncRedisExecutionEventBus @@ -31,15 +31,6 @@ docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None app = FastAPI(lifespan=lifespan, docs_url=docs_url) _connection_manager = None -logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}") -app.add_middleware( - CORSMiddleware, - allow_origins=settings.config.backend_cors_allow_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - def get_connection_manager(): global _connection_manager @@ -176,8 +167,16 @@ async def websocket_router( class WebsocketServer(AppProcess): def run(self): + logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}") + server_app = CORSMiddleware( + app=app, + allow_origins=settings.config.backend_cors_allow_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) uvicorn.run( - app, + server_app, host=Config().websocket_server_host, port=Config().websocket_server_port, )