fix(backend): Avoid long synchronous call to block FastAPI event-loop (#8429)
parent
1e620fdb13
commit
8938209d0d
|
@ -6,7 +6,7 @@ if TYPE_CHECKING:
|
|||
from redis import Redis
|
||||
from backend.executor.database import DatabaseManager
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached_property
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.synchronize import RedisKeyedMutex
|
||||
|
||||
from .types import (
|
||||
|
@ -21,8 +21,9 @@ from .types import (
|
|||
class SupabaseIntegrationCredentialsStore:
|
||||
def __init__(self, redis: "Redis"):
|
||||
self.locks = RedisKeyedMutex(redis)
|
||||
|
||||
@thread_cached_property
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def db_manager(self) -> "DatabaseManager":
|
||||
from backend.executor.database import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from typing import Callable, TypeVar, ParamSpec
|
||||
import threading
|
||||
from functools import wraps
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
@ -10,7 +8,6 @@ R = TypeVar("R")
|
|||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
thread_local = threading.local()
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
|
@ -21,7 +18,3 @@ def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
|||
return cache[key]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def thread_cached_property(func: Callable[[T], R]) -> property:
|
||||
return property(thread_cached(func))
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached_property
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.schedule import (
|
||||
|
@ -37,7 +37,8 @@ class ExecutionScheduler(AppService):
|
|||
def get_port(cls) -> int:
|
||||
return Config().execution_scheduler_port
|
||||
|
||||
@thread_cached_property
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ class LoginResponse(BaseModel):
|
|||
|
||||
|
||||
@router.get("/{provider}/login")
|
||||
async def login(
|
||||
def login(
|
||||
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
request: Request,
|
||||
|
@ -60,7 +60,7 @@ class CredentialsMetaResponse(BaseModel):
|
|||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
async def callback(
|
||||
def callback(
|
||||
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
|
||||
code: Annotated[str, Body(title="Authorization code acquired by user login")],
|
||||
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
|
||||
|
@ -115,7 +115,7 @@ async def callback(
|
|||
|
||||
|
||||
@router.get("/{provider}/credentials")
|
||||
async def list_credentials(
|
||||
def list_credentials(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
|
@ -133,7 +133,7 @@ async def list_credentials(
|
|||
|
||||
|
||||
@router.get("/{provider}/credentials/{cred_id}")
|
||||
async def get_credential(
|
||||
def get_credential(
|
||||
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
|
@ -149,7 +149,7 @@ async def get_credential(
|
|||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201)
|
||||
async def create_api_key_credentials(
|
||||
def create_api_key_credentials(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
provider: Annotated[str, Path(title="The provider to create credentials for")],
|
||||
api_key: Annotated[str, Body(title="The API key to store")],
|
||||
|
@ -184,7 +184,7 @@ class CredentialsDeletionResponse(BaseModel):
|
|||
|
||||
|
||||
@router.delete("/{provider}/credentials/{cred_id}")
|
||||
async def delete_credentials(
|
||||
def delete_credentials(
|
||||
request: Request,
|
||||
provider: Annotated[str, Path(title="The provider to delete credentials for")],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
@ -7,7 +8,7 @@ from typing import Annotated, Any, Dict
|
|||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.utils.cache import thread_cached_property
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
@ -307,11 +308,13 @@ class AgentServer(AppService):
|
|||
|
||||
return wrapper
|
||||
|
||||
@thread_cached_property
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
@thread_cached_property
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_scheduler_client(self) -> ExecutionScheduler:
|
||||
return get_service_client(ExecutionScheduler)
|
||||
|
||||
|
@ -516,7 +519,7 @@ class AgentServer(AppService):
|
|||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def execute_graph(
|
||||
def execute_graph(
|
||||
self,
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
|
@ -539,7 +542,9 @@ class AgentServer(AppService):
|
|||
404, detail=f"Agent execution #{graph_exec_id} not found"
|
||||
)
|
||||
|
||||
self.execution_manager_client.cancel_execution(graph_exec_id)
|
||||
await asyncio.to_thread(
|
||||
lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
|
||||
)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
@ -614,10 +619,16 @@ class AgentServer(AppService):
|
|||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
|
||||
return {
|
||||
"id": execution_scheduler.add_execution_schedule(
|
||||
graph_id, graph.version, cron, input_data, user_id=user_id
|
||||
"id": await asyncio.to_thread(
|
||||
lambda: self.execution_scheduler_client.add_execution_schedule(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=cron,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -252,7 +252,7 @@ async def block_autogen_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = await server.agent_server.execute_graph(
|
||||
response = server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -156,7 +156,7 @@ async def reddit_marketing_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await server.agent_server.execute_graph(
|
||||
response = server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -78,7 +78,7 @@ async def sample_agent():
|
|||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.execute_graph(
|
||||
response = server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -22,7 +22,7 @@ async def execute_graph(
|
|||
num_execs: int = 4,
|
||||
) -> str:
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
|
||||
response = agent_server.execute_graph(test_graph.id, input_data, test_user.id)
|
||||
graph_exec_id = response["id"]
|
||||
|
||||
# Execution queue should be empty
|
||||
|
|
Loading…
Reference in New Issue