fix(backend): Avoid long synchronous call to block FastAPI event-loop (#8429)

pull/8456/head^2
Zamil Majdy 2024-10-27 23:54:38 +07:00 committed by GitHub
parent 1e620fdb13
commit 8938209d0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 37 additions and 31 deletions

View File

@ -6,7 +6,7 @@ if TYPE_CHECKING:
from redis import Redis
from backend.executor.database import DatabaseManager
from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from .types import (
@ -21,8 +21,9 @@ from .types import (
class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis"):
self.locks = RedisKeyedMutex(redis)
@thread_cached_property
@property
@thread_cached
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client

View File

@ -1,8 +1,6 @@
from typing import Callable, TypeVar, ParamSpec
import threading
from functools import wraps
from typing import Callable, ParamSpec, TypeVar
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")
@ -10,7 +8,6 @@ R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
thread_local = threading.local()
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
@ -21,7 +18,3 @@ def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
return cache[key]
return wrapper
def thread_cached_property(func: Callable[[T], R]) -> property:
return property(thread_cached(func))

View File

@ -4,7 +4,7 @@ from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import BlockInput
from backend.data.schedule import (
@ -37,7 +37,8 @@ class ExecutionScheduler(AppService):
def get_port(cls) -> int:
return Config().execution_scheduler_port
@thread_cached_property
@property
@thread_cached
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)

View File

@ -29,7 +29,7 @@ class LoginResponse(BaseModel):
@router.get("/{provider}/login")
async def login(
def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
@ -60,7 +60,7 @@ class CredentialsMetaResponse(BaseModel):
@router.post("/{provider}/callback")
async def callback(
def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
@ -115,7 +115,7 @@ async def callback(
@router.get("/{provider}/credentials")
async def list_credentials(
def list_credentials(
provider: Annotated[str, Path(title="The provider to list credentials for")],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
@ -133,7 +133,7 @@ async def list_credentials(
@router.get("/{provider}/credentials/{cred_id}")
async def get_credential(
def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
@ -149,7 +149,7 @@ async def get_credential(
@router.post("/{provider}/credentials", status_code=201)
async def create_api_key_credentials(
def create_api_key_credentials(
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
api_key: Annotated[str, Body(title="The API key to store")],
@ -184,7 +184,7 @@ class CredentialsDeletionResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],

View File

@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
from collections import defaultdict
@ -7,7 +8,7 @@ from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@ -307,11 +308,13 @@ class AgentServer(AppService):
return wrapper
@thread_cached_property
@property
@thread_cached
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached_property
@property
@thread_cached
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
@ -516,7 +519,7 @@ class AgentServer(AppService):
user_id=user_id,
)
async def execute_graph(
def execute_graph(
self,
graph_id: str,
node_input: dict[Any, Any],
@ -539,7 +542,9 @@ class AgentServer(AppService):
404, detail=f"Agent execution #{graph_exec_id} not found"
)
self.execution_manager_client.cancel_execution(graph_exec_id)
await asyncio.to_thread(
lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
)
# Retrieve & return canceled graph execution in its final state
return await execution_db.get_execution_results(graph_exec_id)
@ -614,10 +619,16 @@ class AgentServer(AppService):
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution_scheduler = self.execution_scheduler_client
return {
"id": execution_scheduler.add_execution_schedule(
graph_id, graph.version, cron, input_data, user_id=user_id
"id": await asyncio.to_thread(
lambda: self.execution_scheduler_client.add_execution_schedule(
graph_id=graph_id,
graph_version=graph.version,
cron=cron,
input_data=input_data,
user_id=user_id,
)
)
}

View File

@ -252,7 +252,7 @@ async def block_autogen_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.execute_graph(
response = server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

View File

@ -156,7 +156,7 @@ async def reddit_marketing_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.execute_graph(
response = server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

View File

@ -78,7 +78,7 @@ async def sample_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.execute_graph(
response = server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)

View File

@ -22,7 +22,7 @@ async def execute_graph(
num_execs: int = 4,
) -> str:
# --- Test adding new executions --- #
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
response = agent_server.execute_graph(test_graph.id, input_data, test_user.id)
graph_exec_id = response["id"]
# Execution queue should be empty