feat(server): Add OAuth flow endpoints for integrations (#7872)

- feat(server): Initial draft of OAuth init and exchange endpoints
  - Add `supabase` dependency
  - Add Supabase credentials to `Secrets`
  - Add `get_supabase` utility to `.server.utils`
  - Add `.server.integrations` API segment with initial implementations for OAuth init and exchange endpoints
- Move integration OAuth handlers to `autogpt_server.integrations.oauth`
- Change constructor of `SupabaseIntegrationCredentialsStore` to take a Supabase client
- Fix type issues in `GoogleOAuthHandler`
pull/8027/head
Reinier van der Leer 2024-09-09 17:21:56 +02:00 committed by GitHub
parent e17ea22a0a
commit a60ed21404
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 582 additions and 44 deletions

View File

@ -1,13 +1,21 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast from typing import cast
from supabase import Client, create_client from supabase import Client
from .types import Credentials, OAuth2Credentials, UserMetadata, UserMetadataRaw from .types import (
Credentials,
OAuth2Credentials,
OAuthState,
UserMetadata,
UserMetadataRaw,
)
class SupabaseIntegrationCredentialsStore: class SupabaseIntegrationCredentialsStore:
def __init__(self, url: str, key: str): def __init__(self, supabase: Client):
self.supabase: Client = create_client(url, key) self.supabase = supabase
def add_creds(self, user_id: str, credentials: Credentials) -> None: def add_creds(self, user_id: str, credentials: Credentials) -> None:
if self.get_creds_by_id(user_id, credentials.id): if self.get_creds_by_id(user_id, credentials.id):
@ -73,6 +81,52 @@ class SupabaseIntegrationCredentialsStore:
] ]
self._set_user_integration_creds(user_id, filtered_credentials) self._set_user_integration_creds(user_id, filtered_credentials)
async def store_state_token(self, user_id: str, provider: str) -> str:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
state = OAuthState(
token=token, provider=provider, expires_at=int(expires_at.timestamp())
)
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
return token
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)
if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
return True
return False
def _set_user_integration_creds( def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials] self, user_id: str, credentials: list[Credentials]
) -> None: ) -> None:

View File

@ -19,9 +19,11 @@ class _BaseCredentials(BaseModel):
class OAuth2Credentials(_BaseCredentials): class OAuth2Credentials(_BaseCredentials):
type: Literal["oauth2"] = "oauth2" type: Literal["oauth2"] = "oauth2"
access_token: SecretStr access_token: SecretStr
access_token_expires_at: Optional[int] # seconds access_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the access token expires (if at all)"""
refresh_token: Optional[SecretStr] refresh_token: Optional[SecretStr]
refresh_token_expires_at: Optional[int] # seconds refresh_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
scopes: list[str] scopes: list[str]
metadata: dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
@ -29,7 +31,8 @@ class OAuth2Credentials(_BaseCredentials):
class APIKeyCredentials(_BaseCredentials): class APIKeyCredentials(_BaseCredentials):
type: Literal["api_key"] = "api_key" type: Literal["api_key"] = "api_key"
api_key: SecretStr api_key: SecretStr
expires_at: Optional[int] # seconds expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
Credentials = Annotated[ Credentials = Annotated[
@ -38,9 +41,18 @@ Credentials = Annotated[
] ]
class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
class UserMetadata(BaseModel): class UserMetadata(BaseModel):
integration_credentials: list[Credentials] = Field(default_factory=list) integration_credentials: list[Credentials] = Field(default_factory=list)
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
class UserMetadataRaw(TypedDict, total=False): class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict] integration_credentials: list[dict]
integration_oauth_states: list[dict]

View File

@ -0,0 +1,15 @@
from .base import BaseOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
]
}
__all__ = ["HANDLERS_BY_NAME"]

View File

@ -0,0 +1,48 @@
import time
from abc import ABC, abstractmethod
from typing import ClassVar
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
class BaseOAuthHandler(ABC):
PROVIDER_NAME: ClassVar[str]
@abstractmethod
def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...
@abstractmethod
def get_login_url(self, scopes: list[str], state: str) -> str:
"""Constructs a login URL that the user can be redirected to"""
...
@abstractmethod
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
"""Exchanges the acquired authorization code from login for a set of tokens"""
...
@abstractmethod
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
"""Implements the token refresh mechanism"""
...
def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
if credentials.provider != self.PROVIDER_NAME:
raise ValueError(
f"{self.__class__.__name__} can not refresh tokens "
f"for other provider '{credentials.provider}'"
)
return self._refresh_tokens(credentials)
def get_access_token(self, credentials: OAuth2Credentials) -> str:
"""Returns a valid access token, refreshing it first if needed"""
if self.needs_refresh(credentials):
credentials = self.refresh_tokens(credentials)
return credentials.access_token.get_secret_value()
def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
"""Indicates whether the given tokens need to be refreshed"""
return (
credentials.access_token_expires_at is not None
and credentials.access_token_expires_at < int(time.time()) + 300
)

View File

@ -0,0 +1,99 @@
import time
from typing import Optional
from urllib.parse import urlencode
import requests
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from .base import BaseOAuthHandler
class GitHubOAuthHandler(BaseOAuthHandler):
"""
Based on the documentation at:
- [Authorizing OAuth apps - GitHub Docs](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps)
- [Refreshing user access tokens - GitHub Docs](https://docs.github.com/en/apps/creating-github-apps/authenticating-with-a-github-app/refreshing-user-access-tokens)
Notes:
- By default, token expiration is disabled on GitHub Apps. This means the access
token doesn't expire and no refresh token is returned by the authorization flow.
- When token expiration gets enabled, any existing tokens will remain non-expiring.
- When token expiration gets disabled, token refreshes will return a non-expiring
access token *with no refresh token*.
""" # noqa
PROVIDER_NAME = "github"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://github.com/login/oauth/authorize"
self.token_url = "https://github.com/login/oauth/access_token"
def get_login_url(self, scopes: list[str], state: str) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
}
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
if not credentials.refresh_token:
return credentials
return self._request_tokens(
{
"refresh_token": credentials.refresh_token.get_secret_value(),
"grant_type": "refresh_token",
}
)
def _request_tokens(
self,
params: dict[str, str],
current_credentials: Optional[OAuth2Credentials] = None,
) -> OAuth2Credentials:
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
**params,
}
headers = {"Accept": "application/json"}
response = requests.post(self.token_url, data=request_body, headers=headers)
response.raise_for_status()
token_data: dict = response.json()
now = int(time.time())
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=current_credentials.title if current_credentials else "GitHub",
access_token=token_data["access_token"],
# Token refresh responses have an empty `scope` property (see docs),
# so we have to get the scope from the existing credentials object.
scopes=(
token_data.get("scope", "").split(",")
or (current_credentials.scopes if current_credentials else [])
),
# Refresh token and expiration intervals are only given if token expiration
# is enabled in the GitHub App's settings.
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=(
now + expires_in
if (expires_in := token_data.get("expires_in", None))
else None
),
refresh_token_expires_at=(
now + expires_in
if (expires_in := token_data.get("refresh_token_expires_in", None))
else None
),
)
if current_credentials:
new_credentials.id = current_credentials.id
return new_credentials

View File

@ -0,0 +1,96 @@
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr
from .base import BaseOAuthHandler
class GoogleOAuthHandler(BaseOAuthHandler):
"""
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
""" # noqa
PROVIDER_NAME = "google"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.token_uri = "https://oauth2.googleapis.com/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
flow = self._setup_oauth_flow(scopes)
flow.redirect_uri = self.redirect_uri
authorization_url, _ = flow.authorization_url(
access_type="offline",
include_granted_scopes="true",
state=state,
prompt="consent",
)
return authorization_url
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
flow = self._setup_oauth_flow(None)
flow.redirect_uri = self.redirect_uri
flow.fetch_token(code=code)
google_creds = flow.credentials
# Google's OAuth library is poorly typed so we need some of these:
assert google_creds.token
assert google_creds.refresh_token
assert google_creds.expiry
assert google_creds.scopes
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title="Google",
access_token=SecretStr(google_creds.token),
refresh_token=SecretStr(google_creds.refresh_token),
access_token_expires_at=int(google_creds.expiry.timestamp()),
refresh_token_expires_at=None,
scopes=google_creds.scopes,
)
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
# Google credentials should ALWAYS have a refresh token
assert credentials.refresh_token
google_creds = Credentials(
token=credentials.access_token.get_secret_value(),
refresh_token=credentials.refresh_token.get_secret_value(),
token_uri=self.token_uri,
client_id=self.client_id,
client_secret=self.client_secret,
scopes=credentials.scopes,
)
# Google's OAuth library is poorly typed so we need some of these:
assert google_creds.refresh_token
assert google_creds.scopes
google_creds.refresh(Request())
assert google_creds.expiry
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
access_token=SecretStr(google_creds.token),
refresh_token=SecretStr(google_creds.refresh_token),
access_token_expires_at=int(google_creds.expiry.timestamp()),
refresh_token_expires_at=None,
scopes=google_creds.scopes,
)
def _setup_oauth_flow(self, scopes: list[str] | None) -> Flow:
return Flow.from_client_config(
{
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": self.token_uri,
}
},
scopes=scopes,
)

View File

@ -0,0 +1,76 @@
from base64 import b64encode
from urllib.parse import urlencode
import requests
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from .base import BaseOAuthHandler
class NotionOAuthHandler(BaseOAuthHandler):
"""
Based on the documentation at https://developers.notion.com/docs/authorization
Notes:
- Notion uses non-expiring access tokens and therefore doesn't have a refresh flow
- Notion doesn't use scopes
"""
PROVIDER_NAME = "notion"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
self.token_url = "https://api.notion.com/v1/oauth/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"owner": "user",
"state": state,
}
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
request_body = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
auth_str = b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode()
headers = {
"Authorization": f"Basic {auth_str}",
"Accept": "application/json",
}
response = requests.post(self.token_url, json=request_body, headers=headers)
response.raise_for_status()
token_data = response.json()
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=token_data.get("workspace_name", "Notion"),
access_token=token_data["access_token"],
refresh_token=None,
access_token_expires_at=None, # Notion tokens don't expire
refresh_token_expires_at=None,
scopes=[],
metadata={
"owner": token_data["owner"],
"bot_id": token_data["bot_id"],
"workspace_id": token_data["workspace_id"],
"workspace_name": token_data.get("workspace_name"),
"workspace_icon": token_data.get("workspace_icon"),
},
)
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
# Notion doesn't support token refresh
return credentials
def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
# Notion access tokens don't expire
return False

View File

@ -0,0 +1,105 @@
import logging
from typing import Annotated, Literal
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel
from supabase import Client
from autogpt_server.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from autogpt_server.util.settings import Settings
from .utils import get_supabase, get_user_id
logger = logging.getLogger(__name__)
settings = Settings()
integrations_api_router = APIRouter()
def get_store(supabase: Client = Depends(get_supabase)):
return SupabaseIntegrationCredentialsStore(supabase)
class LoginResponse(BaseModel):
login_url: str
@integrations_api_router.get("/{provider}/login")
async def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
# Generate and store a secure random state token
state = await store.store_state_token(user_id, provider)
requested_scopes = scopes.split(",") if scopes else []
login_url = handler.get_login_url(requested_scopes, state)
return LoginResponse(login_url=login_url)
class CredentialsMetaResponse(BaseModel):
credentials_id: str
credentials_type: Literal["oauth2", "api_key"]
@integrations_api_router.post("/{provider}/callback")
async def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
) -> CredentialsMetaResponse:
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await store.verify_state_token(user_id, state_token, provider):
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
credentials = handler.exchange_code_for_tokens(code)
except Exception as e:
logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(status_code=400, detail=str(e))
store.add_creds(user_id, credentials)
return CredentialsMetaResponse(
credentials_id=credentials.id,
credentials_type=credentials.type,
)
# -------- UTILITIES --------- #
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
)
handler_class = HANDLERS_BY_NAME[provider_name]
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=str(req.url_for("callback", provider=provider_name)),
)

View File

@ -19,11 +19,12 @@ from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from autogpt_server.data.user import get_or_create_user from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.model import CreateGraph, SetGraphActiveVersion from autogpt_server.server.model import CreateGraph, SetGraphActiveVersion
from autogpt_server.util.auth import get_user_id
from autogpt_server.util.lock import KeyedMutex from autogpt_server.util.lock import KeyedMutex
from autogpt_server.util.service import AppService, expose, get_service_client from autogpt_server.util.service import AppService, expose, get_service_client
from autogpt_server.util.settings import Settings from autogpt_server.util.settings import Settings
from .utils import get_user_id
settings = Settings() settings = Settings()
@ -70,127 +71,132 @@ class AgentServer(AppService):
) )
# Define the API routes # Define the API routes
router = APIRouter(prefix="/api") api_router = APIRouter(prefix="/api")
router.dependencies.append(Depends(auth_middleware)) api_router.dependencies.append(Depends(auth_middleware))
router.add_api_route( # Import & Attach sub-routers
from .integrations import integrations_api_router
api_router.include_router(integrations_api_router, prefix="/integrations")
api_router.add_api_route(
path="/auth/user", path="/auth/user",
endpoint=self.get_or_create_user_route, endpoint=self.get_or_create_user_route,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/blocks", path="/blocks",
endpoint=self.get_graph_blocks, endpoint=self.get_graph_blocks,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/blocks/{block_id}/execute", path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block, endpoint=self.execute_graph_block,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs", path="/graphs",
endpoint=self.get_graphs, endpoint=self.get_graphs,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/templates", path="/templates",
endpoint=self.get_templates, endpoint=self.get_templates,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs", path="/graphs",
endpoint=self.create_new_graph, endpoint=self.create_new_graph,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/templates", path="/templates",
endpoint=self.create_new_template, endpoint=self.create_new_template,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}", path="/graphs/{graph_id}",
endpoint=self.get_graph, endpoint=self.get_graph,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/templates/{graph_id}", path="/templates/{graph_id}",
endpoint=self.get_template, endpoint=self.get_template,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}", path="/graphs/{graph_id}",
endpoint=self.update_graph, endpoint=self.update_graph,
methods=["PUT"], methods=["PUT"],
) )
router.add_api_route( api_router.add_api_route(
path="/templates/{graph_id}", path="/templates/{graph_id}",
endpoint=self.update_graph, endpoint=self.update_graph,
methods=["PUT"], methods=["PUT"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/versions", path="/graphs/{graph_id}/versions",
endpoint=self.get_graph_all_versions, endpoint=self.get_graph_all_versions,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/templates/{graph_id}/versions", path="/templates/{graph_id}/versions",
endpoint=self.get_graph_all_versions, endpoint=self.get_graph_all_versions,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/versions/{version}", path="/graphs/{graph_id}/versions/{version}",
endpoint=self.get_graph, endpoint=self.get_graph,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/versions/active", path="/graphs/{graph_id}/versions/active",
endpoint=self.set_graph_active_version, endpoint=self.set_graph_active_version,
methods=["PUT"], methods=["PUT"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/input_schema", path="/graphs/{graph_id}/input_schema",
endpoint=self.get_graph_input_schema, endpoint=self.get_graph_input_schema,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/execute", path="/graphs/{graph_id}/execute",
endpoint=self.execute_graph, endpoint=self.execute_graph,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/executions", path="/graphs/{graph_id}/executions",
endpoint=self.list_graph_runs, endpoint=self.list_graph_runs,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}", path="/graphs/{graph_id}/executions/{graph_exec_id}",
endpoint=self.get_graph_run_node_execution_results, endpoint=self.get_graph_run_node_execution_results,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop", path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
endpoint=self.stop_graph_run, endpoint=self.stop_graph_run,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/schedules", path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule, endpoint=self.create_schedule,
methods=["POST"], methods=["POST"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/{graph_id}/schedules", path="/graphs/{graph_id}/schedules",
endpoint=self.get_execution_schedules, endpoint=self.get_execution_schedules,
methods=["GET"], methods=["GET"],
) )
router.add_api_route( api_router.add_api_route(
path="/graphs/schedules/{schedule_id}", path="/graphs/schedules/{schedule_id}",
endpoint=self.update_schedule, endpoint=self.update_schedule,
methods=["PUT"], methods=["PUT"],
) )
router.add_api_route( api_router.add_api_route(
path="/settings", path="/settings",
endpoint=self.update_configuration, endpoint=self.update_configuration,
methods=["POST"], methods=["POST"],
@ -198,7 +204,7 @@ class AgentServer(AppService):
app.add_exception_handler(500, self.handle_internal_http_error) app.add_exception_handler(500, self.handle_internal_http_error)
app.include_router(router) app.include_router(api_router)
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None) uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)

View File

@ -1,7 +1,11 @@
from autogpt_libs.auth import auth_middleware from autogpt_libs.auth.middleware import auth_middleware
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from supabase import Client, create_client
from autogpt_server.data.user import DEFAULT_USER_ID from autogpt_server.data.user import DEFAULT_USER_ID
from autogpt_server.util.settings import Settings
settings = Settings()
def get_user_id(payload: dict = Depends(auth_middleware)) -> str: def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
@ -13,3 +17,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token") raise HTTPException(status_code=401, detail="User ID not found in token")
return user_id return user_id
def get_supabase() -> Client:
return create_client(settings.secrets.supabase_url, settings.secrets.supabase_key)

View File

@ -93,6 +93,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
"""Secrets for the server.""" """Secrets for the server."""
supabase_url: str = Field(default="", description="Supabase URL")
supabase_key: str = Field(default="", description="Supabase key")
# OAuth server credentials for integrations
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
github_client_secret: str = Field(
default="", description="GitHub OAuth client secret"
)
google_client_id: str = Field(default="", description="Google OAuth client ID")
google_client_secret: str = Field(
default="", description="Google OAuth client secret"
)
notion_client_id: str = Field(default="", description="Notion OAuth client ID")
notion_client_secret: str = Field(
default="", description="Notion OAuth client secret"
)
openai_api_key: str = Field(default="", description="OpenAI API key") openai_api_key: str = Field(default="", description="OpenAI API key")
anthropic_api_key: str = Field(default="", description="Anthropic API key") anthropic_api_key: str = Field(default="", description="Anthropic API key")
groq_api_key: str = Field(default="", description="Groq API key") groq_api_key: str = Field(default="", description="Groq API key")

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "agpt" name = "agpt"
@ -25,7 +25,7 @@ requests = "*"
sentry-sdk = "^1.40.4" sentry-sdk = "^1.40.4"
[package.extras] [package.extras]
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"] benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/AutoGPT/benchmark"]
[package.source] [package.source]
type = "directory" type = "directory"
@ -386,7 +386,7 @@ watchdog = "4.0.0"
webdriver-manager = "^4.0.1" webdriver-manager = "^4.0.1"
[package.extras] [package.extras]
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"] benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/AutoGPT/benchmark"]
[package.source] [package.source]
type = "directory" type = "directory"
@ -3429,6 +3429,8 @@ files = [
{file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"},
{file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"},
{file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"},
{file = "orjson-3.10.6-cp313-none-win32.whl", hash = "sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"},
{file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"},
{file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"},
@ -6452,4 +6454,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "126731188e8fdc7df0bd2dc92cd069fcd2b90edd5e1065cebd8f4adcedf982b5" content-hash = "0ecd19c5cdf414368aa81b83ae76ba6db34b1bfc8f32a482d1222d6b839792da"

View File

@ -11,6 +11,7 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.10" python = "^3.10"
agpt = { path = "../../autogpt", develop = true } agpt = { path = "../../autogpt", develop = true }
aio-pika = "^9.4.3"
anthropic = "^0.25.1" anthropic = "^0.25.1"
apscheduler = "^3.10.4" apscheduler = "^3.10.4"
autogpt-forge = { path = "../../forge", develop = true } autogpt-forge = { path = "../../forge", develop = true }
@ -39,15 +40,14 @@ pyro5 = "^5.15"
pytest = "^8.2.1" pytest = "^8.2.1"
pytest-asyncio = "^0.23.7" pytest-asyncio = "^0.23.7"
python-dotenv = "^1.0.1" python-dotenv = "^1.0.1"
redis = "^5.0.8"
sentry-sdk = "1.45.0"
supabase = "^2.7.2"
tenacity = "^8.3.0" tenacity = "^8.3.0"
uvicorn = { extras = ["standard"], version = "^0.30.1" } uvicorn = { extras = ["standard"], version = "^0.30.1" }
websockets = "^12.0" websockets = "^12.0"
youtube-transcript-api = "^0.6.2" youtube-transcript-api = "^0.6.2"
aio-pika = "^9.4.3"
redis = "^5.0.8"
sentry-sdk = "1.45.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
poethepoet = "^0.26.1" poethepoet = "^0.26.1"
httpx = "^0.27.0" httpx = "^0.27.0"