fix(backend): Fix DatabaseManager usage by calling it on-demand (#8404)

pull/8405/head
Zamil Majdy 2024-10-23 06:09:23 +03:00 committed by GitHub
parent 7f318685af
commit 17e79ad88d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 67 additions and 41 deletions

View File

@ -6,6 +6,7 @@ if TYPE_CHECKING:
from redis import Redis
from backend.executor.database import DatabaseManager
from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from .types import (
@ -18,9 +19,14 @@ from .types import (
class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis", db: "DatabaseManager"):
self.db_manager: DatabaseManager = db
def __init__(self, redis: "Redis"):
self.locks = RedisKeyedMutex(redis)
@thread_cached_property
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):

View File

@ -27,11 +27,15 @@ R = TypeVar("R")
class DatabaseManager(AppService):
def __init__(self):
super().__init__(port=Config().database_api_port)
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()
@classmethod
def get_port(cls) -> int:
return Config().database_api_port
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))

View File

@ -16,6 +16,8 @@ from redis.lock import Lock as RedisLock
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from autogpt_libs.utils.cache import thread_cached
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import (
@ -31,7 +33,6 @@ from backend.data.graph import Graph, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.cache import thread_cached
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
@ -419,7 +420,7 @@ class Executor:
redis.connect()
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager(db_manager=cls.db_client)
cls.creds_manager = IntegrationCredentialsManager()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
@ -659,20 +660,24 @@ class Executor:
class ExecutionManager(AppService):
def __init__(self):
super().__init__(port=settings.config.execution_manager_port)
super().__init__()
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run_service(self):
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
self.credentials_store = SupabaseIntegrationCredentialsStore(
redis=redis.get_redis(), db=self.db_client
redis=redis.get_redis()
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
@ -863,7 +868,7 @@ class ExecutionManager(AppService):
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager, settings.config.database_api_port)
return get_service_client(DatabaseManager)
@contextmanager

View File

@ -4,6 +4,7 @@ from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached_property
from backend.data.block import BlockInput
from backend.data.schedule import (
@ -14,7 +15,6 @@ from backend.data.schedule import (
update_schedule,
)
from backend.executor.manager import ExecutionManager
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@ -28,14 +28,18 @@ def log(msg, **kwargs):
class ExecutionScheduler(AppService):
def __init__(self, refresh_interval=10):
super().__init__(port=Config().execution_scheduler_port)
super().__init__()
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval
@classmethod
def get_port(cls) -> int:
return Config().execution_scheduler_port
@thread_cached_property
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)
def run_service(self):
scheduler = BackgroundScheduler()

View File

@ -10,7 +10,6 @@ from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.executor.database import DatabaseManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
@ -50,12 +49,10 @@ class IntegrationCredentialsManager:
cause so much latency that it's worth implementing.
"""
def __init__(self, db_manager: DatabaseManager):
def __init__(self):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(
redis=redis_conn, db=db_manager
)
self.store = SupabaseIntegrationCredentialsStore(redis=redis_conn)
def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)

View File

@ -10,7 +10,6 @@ from autogpt_libs.supabase_integration_credentials_store.types import (
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
creds_manager = IntegrationCredentialsManager(db_manager=get_db_client())
creds_manager = IntegrationCredentialsManager()
class LoginResponse(BaseModel):

View File

@ -7,6 +7,7 @@ from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached_property
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@ -19,10 +20,7 @@ from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
@ -37,9 +35,13 @@ class AgentServer(AppService):
_user_credit_model = get_user_credit_model()
def __init__(self):
super().__init__(port=Config().agent_server_port)
super().__init__()
self.use_redis = True
@classmethod
def get_port(cls) -> int:
return Config().agent_server_port
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
@ -98,7 +100,6 @@ class AgentServer(AppService):
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
self.integration_creds_manager = IntegrationCredentialsManager(get_db_client())
api_router.include_router(
backend.server.routers.analytics.router,
@ -308,11 +309,11 @@ class AgentServer(AppService):
@thread_cached_property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)
@thread_cached_property
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
return get_service_client(ExecutionScheduler)
@classmethod
def handle_internal_http_error(cls, request: Request, exc: Exception):

View File

@ -5,6 +5,7 @@ import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from types import NoneType, UnionType
from typing import (
@ -99,16 +100,24 @@ def _make_custom_deserializer(model: Type[BaseModel]):
return custom_dict_to_class
class AppService(AppProcess):
class AppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
use_supabase: bool = False
def __init__(self, port):
self.port = port
def __init__(self):
self.uri = None
@classmethod
@abstractmethod
def get_port(cls) -> int:
pass
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", Config().pyro_host)
def run_service(self) -> None:
while True:
time.sleep(10)
@ -157,8 +166,7 @@ class AppService(AppProcess):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
host = Config().pyro_host
daemon = Pyro5.api.Daemon(host=host, port=self.port)
daemon = Pyro5.api.Daemon(host=self.get_host(), port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
@ -167,16 +175,20 @@ class AppService(AppProcess):
self.shared_event_loop.run_forever()
# --------- UTILITIES --------- #
AS = TypeVar("AS", bound=AppService)
def get_service_client(service_type: Type[AS], port: int) -> AS:
def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient:
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
host = service_type.get_host()
port = service_type.get_port()
uri = f"PYRO:{service_type.service_name}@{host}:{port}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
@ -191,8 +203,6 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
return cast(AS, DynamicClient())
# --------- UTILITIES --------- #
builtin_types = [*vars(builtins).values(), NoneType, Enum]

View File

@ -5,7 +5,6 @@ from backend.executor import ExecutionScheduler
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.test import SpinTestServer
@ -19,10 +18,7 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(
ExecutionScheduler, Config().execution_scheduler_port
)
scheduler = get_service_client(ExecutionScheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0

View File

@ -7,7 +7,11 @@ TEST_SERVICE_PORT = 8765
class ServiceTest(AppService):
def __init__(self):
super().__init__(port=TEST_SERVICE_PORT)
super().__init__()
@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT
@expose
def add(self, a: int, b: int) -> int:
@ -28,7 +32,7 @@ class ServiceTest(AppService):
@pytest.mark.asyncio(scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
client = get_service_client(ServiceTest)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8