fix(backend): Fix DatabaseManager usage by calling it on-demand (#8404)
parent 7f318685af
commit 17e79ad88d
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
     from redis import Redis
     from backend.executor.database import DatabaseManager

+from autogpt_libs.utils.cache import thread_cached_property
 from autogpt_libs.utils.synchronize import RedisKeyedMutex

 from .types import (
@@ -18,9 +19,14 @@ from .types import (


 class SupabaseIntegrationCredentialsStore:
-    def __init__(self, redis: "Redis", db: "DatabaseManager"):
-        self.db_manager: DatabaseManager = db
+    def __init__(self, redis: "Redis"):
         self.locks = RedisKeyedMutex(redis)

+    @thread_cached_property
+    def db_manager(self) -> "DatabaseManager":
+        from backend.executor.database import DatabaseManager
+        from backend.util.service import get_service_client
+        return get_service_client(DatabaseManager)
+
     def add_creds(self, user_id: str, credentials: Credentials) -> None:
         with self.locked_user_metadata(user_id):
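The hunk above is the heart of the change: instead of receiving a DatabaseManager at construction time, the store resolves it the first time db_manager is accessed and caches it per thread. Below is a minimal, self-contained sketch of that pattern; thread_cached_property here is a simplified stand-in for the autogpt_libs helper, and the returned object is a placeholder rather than the real service client.

# --- illustrative sketch, not part of the diff ---
import threading
from functools import wraps


def thread_cached_property(func):
    """Simplified stand-in: compute the value once per (thread, instance)."""
    local = threading.local()

    @property
    @wraps(func)
    def wrapper(self):
        cache = getattr(local, "cache", None)
        if cache is None:
            cache = local.cache = {}
        if id(self) not in cache:
            cache[id(self)] = func(self)
        return cache[id(self)]

    return wrapper


class LazyStore:
    def __init__(self, redis_conn):
        # Only cheap dependencies are taken at construction time.
        self.redis = redis_conn

    @thread_cached_property
    def db_manager(self):
        # Placeholder for get_service_client(DatabaseManager): in the real code
        # the client is built on first access, once per thread.
        return object()


store = LazyStore(redis_conn=None)
assert store.db_manager is store.db_manager  # second access reuses the cached client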
@@ -27,11 +27,15 @@ R = TypeVar("R")
 class DatabaseManager(AppService):

     def __init__(self):
-        super().__init__(port=Config().database_api_port)
+        super().__init__()
         self.use_db = True
         self.use_redis = True
         self.event_queue = RedisEventQueue()

+    @classmethod
+    def get_port(cls) -> int:
+        return Config().database_api_port
+
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
         self.event_queue.put(ExecutionResult(**execution_result_dict))
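A step that repeats throughout this diff: the port moves out of __init__ and into a get_port() classmethod, so both the service (when it binds its daemon) and any caller (when it builds a client address) read the same value without instantiating anything. A toy sketch of that single-source-of-truth idea, with a made-up class and port:

# --- illustrative sketch, not part of the diff ---
class ExampleService:
    @classmethod
    def get_port(cls) -> int:
        # In the real services this delegates to Config(), e.g. database_api_port.
        return 8005

    def serve(self) -> str:
        # The service binds to the same port that clients will compute.
        return f"listening on 0.0.0.0:{self.get_port()}"


def client_address(service_cls) -> str:
    # No instance needed: the class alone is enough to locate the service.
    return f"localhost:{service_cls.get_port()}"


print(ExampleService().serve())        # listening on 0.0.0.0:8005
print(client_address(ExampleService))  # localhost:8005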
@@ -16,6 +16,8 @@ from redis.lock import Lock as RedisLock
 if TYPE_CHECKING:
+    from backend.executor import DatabaseManager

+from autogpt_libs.utils.cache import thread_cached

 from backend.data import redis
 from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
 from backend.data.execution import (
@@ -31,7 +33,6 @@ from backend.data.graph import Graph, Link, Node
 from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util import json
-from backend.util.cache import thread_cached
 from backend.util.decorator import error_logged, time_measured
 from backend.util.logging import configure_logging
 from backend.util.process import set_service_name
@@ -419,7 +420,7 @@ class Executor:
         redis.connect()
         cls.pid = os.getpid()
         cls.db_client = get_db_client()
-        cls.creds_manager = IntegrationCredentialsManager(db_manager=cls.db_client)
+        cls.creds_manager = IntegrationCredentialsManager()

         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
@@ -659,20 +660,24 @@ class Executor:
 class ExecutionManager(AppService):

     def __init__(self):
-        super().__init__(port=settings.config.execution_manager_port)
+        super().__init__()
         self.use_redis = True
         self.use_supabase = True
         self.pool_size = settings.config.num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

+    @classmethod
+    def get_port(cls) -> int:
+        return settings.config.execution_manager_port
+
     def run_service(self):
         from autogpt_libs.supabase_integration_credentials_store import (
             SupabaseIntegrationCredentialsStore,
         )

         self.credentials_store = SupabaseIntegrationCredentialsStore(
-            redis=redis.get_redis(), db=self.db_client
+            redis=redis.get_redis()
         )
         self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
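The use_redis / use_supabase assignments seen here (and the use_db / use_redis defaults added to AppService further down) read as opt-in flags that the base class is presumably expected to act on when the process starts. A hedged, toy illustration of that convention, not the actual AppService lifecycle:

# --- illustrative sketch, not part of the diff ---
class ToyBaseService:
    use_db = False
    use_redis = False

    def start(self):
        # Assumed behaviour: the base class opens only the resources a subclass opted into.
        if self.use_db:
            print("connecting database")
        if self.use_redis:
            print("connecting redis")
        self.run_service()

    def run_service(self):
        raise NotImplementedError


class ToyWorker(ToyBaseService):
    def __init__(self):
        self.use_redis = True  # opt in, as ExecutionManager does above

    def run_service(self):
        print("running")


ToyWorker().start()  # connecting redis / running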
@@ -863,7 +868,7 @@ class ExecutionManager(AppService):
 def get_db_client() -> "DatabaseManager":
     from backend.executor import DatabaseManager

-    return get_service_client(DatabaseManager, settings.config.database_api_port)
+    return get_service_client(DatabaseManager)


 @contextmanager
@@ -4,6 +4,7 @@ from datetime import datetime

 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger
+from autogpt_libs.utils.cache import thread_cached_property

 from backend.data.block import BlockInput
 from backend.data.schedule import (
@@ -14,7 +15,6 @@ from backend.data.schedule import (
     update_schedule,
 )
 from backend.executor.manager import ExecutionManager
-from backend.util.cache import thread_cached_property
 from backend.util.service import AppService, expose, get_service_client
 from backend.util.settings import Config

@@ -28,14 +28,18 @@ def log(msg, **kwargs):
 class ExecutionScheduler(AppService):

     def __init__(self, refresh_interval=10):
-        super().__init__(port=Config().execution_scheduler_port)
+        super().__init__()
         self.use_db = True
         self.last_check = datetime.min
         self.refresh_interval = refresh_interval

+    @classmethod
+    def get_port(cls) -> int:
+        return Config().execution_scheduler_port
+
     @thread_cached_property
     def execution_client(self) -> ExecutionManager:
-        return get_service_client(ExecutionManager, Config().execution_manager_port)
+        return get_service_client(ExecutionManager)

     def run_service(self):
         scheduler = BackgroundScheduler()
@@ -10,7 +10,6 @@ from autogpt_libs.utils.synchronize import RedisKeyedMutex
 from redis.lock import Lock as RedisLock

 from backend.data import redis
-from backend.executor.database import DatabaseManager
 from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
 from backend.util.settings import Settings

@@ -50,12 +49,10 @@ class IntegrationCredentialsManager:
     cause so much latency that it's worth implementing.
     """

-    def __init__(self, db_manager: DatabaseManager):
+    def __init__(self):
         redis_conn = redis.get_redis()
         self._locks = RedisKeyedMutex(redis_conn)
-        self.store = SupabaseIntegrationCredentialsStore(
-            redis=redis_conn, db=db_manager
-        )
+        self.store = SupabaseIntegrationCredentialsStore(redis=redis_conn)

     def create(self, user_id: str, credentials: Credentials) -> None:
        return self.store.add_creds(user_id, credentials)
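With the db_manager argument gone, creds_manager.py no longer needs a runtime import of backend.executor.database (removed in the previous hunk); the store defers that import until db_manager is first used. A toy illustration of the deferred-import pattern, using hypothetical module names:

# --- illustrative sketch, not part of the diff ---
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only import: evaluated by type checkers, never executed at runtime,
    # so it cannot participate in an import cycle. "heavy_module" is hypothetical.
    from heavy_module import HeavyService


def get_heavy_client() -> "HeavyService":
    # Runtime import deferred until the first call instead of module import time.
    from heavy_module import HeavyService

    return HeavyService()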
@@ -10,7 +10,6 @@ from autogpt_libs.supabase_integration_credentials_store.types import (
 from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
 from pydantic import BaseModel, Field, SecretStr

-from backend.executor.manager import get_db_client
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
 from backend.util.settings import Settings
@@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
 settings = Settings()
 router = APIRouter()

-creds_manager = IntegrationCredentialsManager(db_manager=get_db_client())
+creds_manager = IntegrationCredentialsManager()


 class LoginResponse(BaseModel):
@@ -7,6 +7,7 @@ from typing import Annotated, Any, Dict

 import uvicorn
 from autogpt_libs.auth.middleware import auth_middleware
+from autogpt_libs.utils.cache import thread_cached_property
 from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -19,10 +20,7 @@ from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import get_block_costs, get_user_credit_model
 from backend.data.user import get_or_create_user
 from backend.executor import ExecutionManager, ExecutionScheduler
-from backend.executor.manager import get_db_client
-from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.server.model import CreateGraph, SetGraphActiveVersion
-from backend.util.cache import thread_cached_property
 from backend.util.service import AppService, get_service_client
 from backend.util.settings import AppEnvironment, Config, Settings

@@ -37,9 +35,13 @@ class AgentServer(AppService):
     _user_credit_model = get_user_credit_model()

     def __init__(self):
-        super().__init__(port=Config().agent_server_port)
+        super().__init__()
         self.use_redis = True

+    @classmethod
+    def get_port(cls) -> int:
+        return Config().agent_server_port
+
     @asynccontextmanager
     async def lifespan(self, _: FastAPI):
         await db.connect()
@@ -98,7 +100,6 @@ class AgentServer(AppService):
             tags=["integrations"],
             dependencies=[Depends(auth_middleware)],
         )
-        self.integration_creds_manager = IntegrationCredentialsManager(get_db_client())

         api_router.include_router(
             backend.server.routers.analytics.router,
@@ -308,11 +309,11 @@ class AgentServer(AppService):

     @thread_cached_property
     def execution_manager_client(self) -> ExecutionManager:
-        return get_service_client(ExecutionManager, Config().execution_manager_port)
+        return get_service_client(ExecutionManager)

     @thread_cached_property
     def execution_scheduler_client(self) -> ExecutionScheduler:
-        return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
+        return get_service_client(ExecutionScheduler)

     @classmethod
     def handle_internal_http_error(cls, request: Request, exc: Exception):
@@ -5,6 +5,7 @@ import os
 import threading
 import time
 import typing
+from abc import ABC, abstractmethod
 from enum import Enum
 from types import NoneType, UnionType
 from typing import (
@@ -99,16 +100,24 @@ def _make_custom_deserializer(model: Type[BaseModel]):
     return custom_dict_to_class


-class AppService(AppProcess):
+class AppService(AppProcess, ABC):
     shared_event_loop: asyncio.AbstractEventLoop
     use_db: bool = False
     use_redis: bool = False
     use_supabase: bool = False

-    def __init__(self, port):
-        self.port = port
+    def __init__(self):
         self.uri = None

+    @classmethod
+    @abstractmethod
+    def get_port(cls) -> int:
+        pass
+
+    @classmethod
+    def get_host(cls) -> str:
+        return os.environ.get(f"{cls.service_name.upper()}_HOST", Config().pyro_host)
+
     def run_service(self) -> None:
         while True:
             time.sleep(10)
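Declaring AppService as an ABC with get_port() as an abstract classmethod makes the new convention enforceable: a service that forgets to declare its port cannot even be instantiated. A small sketch of that mechanism with toy classes, not the real hierarchy:

# --- illustrative sketch, not part of the diff ---
from abc import ABC, abstractmethod


class ToyAppService(ABC):
    @classmethod
    @abstractmethod
    def get_port(cls) -> int: ...


class GoodService(ToyAppService):
    @classmethod
    def get_port(cls) -> int:
        return 9000


class ForgetfulService(ToyAppService):
    pass


GoodService()  # fine: the abstract classmethod is implemented
try:
    ForgetfulService()
except TypeError as e:
    print(e)  # can't instantiate abstract class ForgetfulService ...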
@@ -157,8 +166,7 @@ class AppService(AppProcess):

     @conn_retry("Pyro", "Starting Pyro Service")
     def __start_pyro(self):
-        host = Config().pyro_host
-        daemon = Pyro5.api.Daemon(host=host, port=self.port)
+        daemon = Pyro5.api.Daemon(host=self.get_host(), port=self.get_port())
         self.uri = daemon.register(self, objectId=self.service_name)
         logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
         daemon.requestLoop()
@@ -167,16 +175,20 @@ class AppService(AppProcess):
         self.shared_event_loop.run_forever()


+# --------- UTILITIES --------- #
+
+
 AS = TypeVar("AS", bound=AppService)


-def get_service_client(service_type: Type[AS], port: int) -> AS:
+def get_service_client(service_type: Type[AS]) -> AS:
     service_name = service_type.service_name

     class DynamicClient:
         @conn_retry("Pyro", f"Connecting to [{service_name}]")
         def __init__(self):
-            host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
+            host = service_type.get_host()
+            port = service_type.get_port()
             uri = f"PYRO:{service_type.service_name}@{host}:{port}"
             logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
             self.proxy = Pyro5.api.Proxy(uri)
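get_host() keeps the per-service environment override that the old inline lookup had, but the class is now the single place that knows its own name: {SERVICE_NAME}_HOST wins, otherwise the configured default. A toy sketch of that resolution order (the fallback value here is made up, standing in for Config().pyro_host):

# --- illustrative sketch, not part of the diff ---
import os


class ToyService:
    service_name = "toyservice"
    default_host = "localhost"  # stand-in for Config().pyro_host

    @classmethod
    def get_host(cls) -> str:
        return os.environ.get(f"{cls.service_name.upper()}_HOST", cls.default_host)


print(ToyService.get_host())              # localhost
os.environ["TOYSERVICE_HOST"] = "10.0.0.5"
print(ToyService.get_host())              # 10.0.0.5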
@@ -191,8 +203,6 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
     return cast(AS, DynamicClient())


-# --------- UTILITIES --------- #
-
 builtin_types = [*vars(builtins).values(), NoneType, Enum]

@@ -5,7 +5,6 @@ from backend.executor import ExecutionScheduler
 from backend.server.model import CreateGraph
 from backend.usecases.sample import create_test_graph, create_test_user
 from backend.util.service import get_service_client
-from backend.util.settings import Config
 from backend.util.test import SpinTestServer

@@ -19,10 +18,7 @@ async def test_agent_schedule(server: SpinTestServer):
         user_id=test_user.id,
     )

-    scheduler = get_service_client(
-        ExecutionScheduler, Config().execution_scheduler_port
-    )
-
+    scheduler = get_service_client(ExecutionScheduler)
     schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
     assert len(schedules) == 0

@@ -7,7 +7,11 @@ TEST_SERVICE_PORT = 8765

 class ServiceTest(AppService):
     def __init__(self):
-        super().__init__(port=TEST_SERVICE_PORT)
+        super().__init__()
+
+    @classmethod
+    def get_port(cls) -> int:
+        return TEST_SERVICE_PORT

     @expose
     def add(self, a: int, b: int) -> int:
@@ -28,7 +32,7 @@ class ServiceTest(AppService):
 @pytest.mark.asyncio(scope="session")
 async def test_service_creation(server):
     with ServiceTest():
-        client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
+        client = get_service_client(ServiceTest)
         assert client.add(5, 3) == 8
         assert client.subtract(10, 4) == 6
         assert client.fun_with_async(5, 3) == 8