feat(server): Add OAuth flow endpoints for integrations (#7872)
- feat(server): Initial draft of OAuth init and exchange endpoints - Add `supabase` dependency - Add Supabase credentials to `Secrets` - Add `get_supabase` utility to `.server.utils` - Add `.server.integrations` API segment with initial implementations for OAuth init and exchange endpoints - Move integration OAuth handlers to `autogpt_server.integrations.oauth` - Change constructor of `SupabaseIntegrationCredentialsStore` to take a Supabase client - Fix type issues in `GoogleOAuthHandler`pull/8027/head
parent
e17ea22a0a
commit
a60ed21404
|
@ -1,13 +1,21 @@
|
|||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
from supabase import Client, create_client
|
||||
from supabase import Client
|
||||
|
||||
from .types import Credentials, OAuth2Credentials, UserMetadata, UserMetadataRaw
|
||||
from .types import (
|
||||
Credentials,
|
||||
OAuth2Credentials,
|
||||
OAuthState,
|
||||
UserMetadata,
|
||||
UserMetadataRaw,
|
||||
)
|
||||
|
||||
|
||||
class SupabaseIntegrationCredentialsStore:
|
||||
def __init__(self, url: str, key: str):
|
||||
self.supabase: Client = create_client(url, key)
|
||||
def __init__(self, supabase: Client):
|
||||
self.supabase = supabase
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
if self.get_creds_by_id(user_id, credentials.id):
|
||||
|
@ -73,6 +81,52 @@ class SupabaseIntegrationCredentialsStore:
|
|||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
async def store_state_token(self, user_id: str, provider: str) -> str:
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
|
||||
state = OAuthState(
|
||||
token=token, provider=provider, expires_at=int(expires_at.timestamp())
|
||||
)
|
||||
|
||||
user_metadata = self._get_user_metadata(user_id)
|
||||
oauth_states = user_metadata.get("integration_oauth_states", [])
|
||||
oauth_states.append(state.model_dump())
|
||||
user_metadata["integration_oauth_states"] = oauth_states
|
||||
|
||||
self.supabase.auth.admin.update_user_by_id(
|
||||
user_id, {"user_metadata": user_metadata}
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
|
||||
user_metadata = self._get_user_metadata(user_id)
|
||||
oauth_states = user_metadata.get("integration_oauth_states", [])
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_state = next(
|
||||
(
|
||||
state
|
||||
for state in oauth_states
|
||||
if state["token"] == token
|
||||
and state["provider"] == provider
|
||||
and state["expires_at"] > now.timestamp()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if valid_state:
|
||||
# Remove the used state
|
||||
oauth_states.remove(valid_state)
|
||||
user_metadata["integration_oauth_states"] = oauth_states
|
||||
self.supabase.auth.admin.update_user_by_id(
|
||||
user_id, {"user_metadata": user_metadata}
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _set_user_integration_creds(
|
||||
self, user_id: str, credentials: list[Credentials]
|
||||
) -> None:
|
||||
|
|
|
@ -19,9 +19,11 @@ class _BaseCredentials(BaseModel):
|
|||
class OAuth2Credentials(_BaseCredentials):
|
||||
type: Literal["oauth2"] = "oauth2"
|
||||
access_token: SecretStr
|
||||
access_token_expires_at: Optional[int] # seconds
|
||||
access_token_expires_at: Optional[int]
|
||||
"""Unix timestamp (seconds) indicating when the access token expires (if at all)"""
|
||||
refresh_token: Optional[SecretStr]
|
||||
refresh_token_expires_at: Optional[int] # seconds
|
||||
refresh_token_expires_at: Optional[int]
|
||||
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
|
||||
scopes: list[str]
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
@ -29,7 +31,8 @@ class OAuth2Credentials(_BaseCredentials):
|
|||
class APIKeyCredentials(_BaseCredentials):
|
||||
type: Literal["api_key"] = "api_key"
|
||||
api_key: SecretStr
|
||||
expires_at: Optional[int] # seconds
|
||||
expires_at: Optional[int]
|
||||
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
|
@ -38,9 +41,18 @@ Credentials = Annotated[
|
|||
]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
token: str
|
||||
provider: str
|
||||
expires_at: int
|
||||
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
integration_credentials: list[Credentials] = Field(default_factory=list)
|
||||
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UserMetadataRaw(TypedDict, total=False):
|
||||
integration_credentials: list[dict]
|
||||
integration_oauth_states: list[dict]
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
from .base import BaseOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
|
||||
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
]
|
||||
}
|
||||
|
||||
__all__ = ["HANDLERS_BY_NAME"]
|
|
@ -0,0 +1,48 @@
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
|
||||
|
||||
|
||||
class BaseOAuthHandler(ABC):
|
||||
PROVIDER_NAME: ClassVar[str]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
"""Constructs a login URL that the user can be redirected to"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
|
||||
"""Exchanges the acquired authorization code from login for a set of tokens"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
"""Implements the token refresh mechanism"""
|
||||
...
|
||||
|
||||
def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
if credentials.provider != self.PROVIDER_NAME:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} can not refresh tokens "
|
||||
f"for other provider '{credentials.provider}'"
|
||||
)
|
||||
return self._refresh_tokens(credentials)
|
||||
|
||||
def get_access_token(self, credentials: OAuth2Credentials) -> str:
|
||||
"""Returns a valid access token, refreshing it first if needed"""
|
||||
if self.needs_refresh(credentials):
|
||||
credentials = self.refresh_tokens(credentials)
|
||||
return credentials.access_token.get_secret_value()
|
||||
|
||||
def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Indicates whether the given tokens need to be refreshed"""
|
||||
return (
|
||||
credentials.access_token_expires_at is not None
|
||||
and credentials.access_token_expires_at < int(time.time()) + 300
|
||||
)
|
|
@ -0,0 +1,99 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class GitHubOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Based on the documentation at:
|
||||
- [Authorizing OAuth apps - GitHub Docs](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps)
|
||||
- [Refreshing user access tokens - GitHub Docs](https://docs.github.com/en/apps/creating-github-apps/authenticating-with-a-github-app/refreshing-user-access-tokens)
|
||||
|
||||
Notes:
|
||||
- By default, token expiration is disabled on GitHub Apps. This means the access
|
||||
token doesn't expire and no refresh token is returned by the authorization flow.
|
||||
- When token expiration gets enabled, any existing tokens will remain non-expiring.
|
||||
- When token expiration gets disabled, token refreshes will return a non-expiring
|
||||
access token *with no refresh token*.
|
||||
""" # noqa
|
||||
|
||||
PROVIDER_NAME = "github"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.auth_base_url = "https://github.com/login/oauth/authorize"
|
||||
self.token_url = "https://github.com/login/oauth/access_token"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
}
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
|
||||
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
|
||||
|
||||
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
return credentials
|
||||
|
||||
return self._request_tokens(
|
||||
{
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
)
|
||||
|
||||
def _request_tokens(
|
||||
self,
|
||||
params: dict[str, str],
|
||||
current_credentials: Optional[OAuth2Credentials] = None,
|
||||
) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
**params,
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = requests.post(self.token_url, data=request_body, headers=headers)
|
||||
response.raise_for_status()
|
||||
token_data: dict = response.json()
|
||||
|
||||
now = int(time.time())
|
||||
new_credentials = OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=current_credentials.title if current_credentials else "GitHub",
|
||||
access_token=token_data["access_token"],
|
||||
# Token refresh responses have an empty `scope` property (see docs),
|
||||
# so we have to get the scope from the existing credentials object.
|
||||
scopes=(
|
||||
token_data.get("scope", "").split(",")
|
||||
or (current_credentials.scopes if current_credentials else [])
|
||||
),
|
||||
# Refresh token and expiration intervals are only given if token expiration
|
||||
# is enabled in the GitHub App's settings.
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
access_token_expires_at=(
|
||||
now + expires_in
|
||||
if (expires_in := token_data.get("expires_in", None))
|
||||
else None
|
||||
),
|
||||
refresh_token_expires_at=(
|
||||
now + expires_in
|
||||
if (expires_in := token_data.get("refresh_token_expires_in", None))
|
||||
else None
|
||||
),
|
||||
)
|
||||
if current_credentials:
|
||||
new_credentials.id = current_credentials.id
|
||||
return new_credentials
|
|
@ -0,0 +1,96 @@
|
|||
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class GoogleOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
""" # noqa
|
||||
|
||||
PROVIDER_NAME = "google"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.token_uri = "https://oauth2.googleapis.com/token"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
flow = self._setup_oauth_flow(scopes)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
state=state,
|
||||
prompt="consent",
|
||||
)
|
||||
return authorization_url
|
||||
|
||||
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
|
||||
flow = self._setup_oauth_flow(None)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
google_creds = flow.credentials
|
||||
# Google's OAuth library is poorly typed so we need some of these:
|
||||
assert google_creds.token
|
||||
assert google_creds.refresh_token
|
||||
assert google_creds.expiry
|
||||
assert google_creds.scopes
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title="Google",
|
||||
access_token=SecretStr(google_creds.token),
|
||||
refresh_token=SecretStr(google_creds.refresh_token),
|
||||
access_token_expires_at=int(google_creds.expiry.timestamp()),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=google_creds.scopes,
|
||||
)
|
||||
|
||||
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
# Google credentials should ALWAYS have a refresh token
|
||||
assert credentials.refresh_token
|
||||
|
||||
google_creds = Credentials(
|
||||
token=credentials.access_token.get_secret_value(),
|
||||
refresh_token=credentials.refresh_token.get_secret_value(),
|
||||
token_uri=self.token_uri,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
# Google's OAuth library is poorly typed so we need some of these:
|
||||
assert google_creds.refresh_token
|
||||
assert google_creds.scopes
|
||||
|
||||
google_creds.refresh(Request())
|
||||
assert google_creds.expiry
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
access_token=SecretStr(google_creds.token),
|
||||
refresh_token=SecretStr(google_creds.refresh_token),
|
||||
access_token_expires_at=int(google_creds.expiry.timestamp()),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=google_creds.scopes,
|
||||
)
|
||||
|
||||
def _setup_oauth_flow(self, scopes: list[str] | None) -> Flow:
|
||||
return Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": self.token_uri,
|
||||
}
|
||||
},
|
||||
scopes=scopes,
|
||||
)
|
|
@ -0,0 +1,76 @@
|
|||
from base64 import b64encode
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
|
||||
class NotionOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Based on the documentation at https://developers.notion.com/docs/authorization
|
||||
|
||||
Notes:
|
||||
- Notion uses non-expiring access tokens and therefore doesn't have a refresh flow
|
||||
- Notion doesn't use scopes
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = "notion"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
|
||||
self.token_url = "https://api.notion.com/v1/oauth/token"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"owner": "user",
|
||||
"state": state,
|
||||
}
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
auth_str = b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode()
|
||||
headers = {
|
||||
"Authorization": f"Basic {auth_str}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
response = requests.post(self.token_url, json=request_body, headers=headers)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=token_data.get("workspace_name", "Notion"),
|
||||
access_token=token_data["access_token"],
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None, # Notion tokens don't expire
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
metadata={
|
||||
"owner": token_data["owner"],
|
||||
"bot_id": token_data["bot_id"],
|
||||
"workspace_id": token_data["workspace_id"],
|
||||
"workspace_name": token_data.get("workspace_name"),
|
||||
"workspace_icon": token_data.get("workspace_icon"),
|
||||
},
|
||||
)
|
||||
|
||||
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
# Notion doesn't support token refresh
|
||||
return credentials
|
||||
|
||||
def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
|
||||
# Notion access tokens don't expire
|
||||
return False
|
|
@ -0,0 +1,105 @@
|
|||
import logging
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from supabase import Client
|
||||
|
||||
from autogpt_server.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
from .utils import get_supabase, get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
integrations_api_router = APIRouter()
|
||||
|
||||
|
||||
def get_store(supabase: Client = Depends(get_supabase)):
|
||||
return SupabaseIntegrationCredentialsStore(supabase)
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
login_url: str
|
||||
|
||||
|
||||
@integrations_api_router.get("/{provider}/login")
|
||||
async def login(
|
||||
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
request: Request,
|
||||
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
|
||||
scopes: Annotated[
|
||||
str, Query(title="Comma-separated list of authorization scopes")
|
||||
] = "",
|
||||
) -> LoginResponse:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
# Generate and store a secure random state token
|
||||
state = await store.store_state_token(user_id, provider)
|
||||
|
||||
requested_scopes = scopes.split(",") if scopes else []
|
||||
login_url = handler.get_login_url(requested_scopes, state)
|
||||
|
||||
return LoginResponse(login_url=login_url)
|
||||
|
||||
|
||||
class CredentialsMetaResponse(BaseModel):
|
||||
credentials_id: str
|
||||
credentials_type: Literal["oauth2", "api_key"]
|
||||
|
||||
|
||||
@integrations_api_router.post("/{provider}/callback")
|
||||
async def callback(
|
||||
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
|
||||
code: Annotated[str, Body(title="Authorization code acquired by user login")],
|
||||
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
|
||||
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
request: Request,
|
||||
) -> CredentialsMetaResponse:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
# Verify the state token
|
||||
if not await store.verify_state_token(user_id, state_token, provider):
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired state token")
|
||||
|
||||
try:
|
||||
credentials = handler.exchange_code_for_tokens(code)
|
||||
except Exception as e:
|
||||
logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
store.add_creds(user_id, credentials)
|
||||
return CredentialsMetaResponse(
|
||||
credentials_id=credentials.id,
|
||||
credentials_type=credentials.type,
|
||||
)
|
||||
|
||||
|
||||
# -------- UTILITIES --------- #
|
||||
|
||||
|
||||
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Unknown provider '{provider_name}'"
|
||||
)
|
||||
|
||||
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
|
||||
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail=f"Integration with provider '{provider_name}' is not configured",
|
||||
)
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_name]
|
||||
return handler_class(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=str(req.url_for("callback", provider=provider_name)),
|
||||
)
|
|
@ -19,11 +19,12 @@ from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
|||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.model import CreateGraph, SetGraphActiveVersion
|
||||
from autogpt_server.util.auth import get_user_id
|
||||
from autogpt_server.util.lock import KeyedMutex
|
||||
from autogpt_server.util.service import AppService, expose, get_service_client
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
from .utils import get_user_id
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
|
@ -70,127 +71,132 @@ class AgentServer(AppService):
|
|||
)
|
||||
|
||||
# Define the API routes
|
||||
router = APIRouter(prefix="/api")
|
||||
router.dependencies.append(Depends(auth_middleware))
|
||||
api_router = APIRouter(prefix="/api")
|
||||
api_router.dependencies.append(Depends(auth_middleware))
|
||||
|
||||
router.add_api_route(
|
||||
# Import & Attach sub-routers
|
||||
from .integrations import integrations_api_router
|
||||
|
||||
api_router.include_router(integrations_api_router, prefix="/integrations")
|
||||
|
||||
api_router.add_api_route(
|
||||
path="/auth/user",
|
||||
endpoint=self.get_or_create_user_route,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/blocks",
|
||||
endpoint=self.get_graph_blocks,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/blocks/{block_id}/execute",
|
||||
endpoint=self.execute_graph_block,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs",
|
||||
endpoint=self.get_graphs,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/templates",
|
||||
endpoint=self.get_templates,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs",
|
||||
endpoint=self.create_new_graph,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/templates",
|
||||
endpoint=self.create_new_template,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}",
|
||||
endpoint=self.get_graph,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}",
|
||||
endpoint=self.get_template,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}",
|
||||
endpoint=self.update_graph,
|
||||
methods=["PUT"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}",
|
||||
endpoint=self.update_graph,
|
||||
methods=["PUT"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
endpoint=self.get_graph_all_versions,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/templates/{graph_id}/versions",
|
||||
endpoint=self.get_graph_all_versions,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
endpoint=self.get_graph,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
endpoint=self.set_graph_active_version,
|
||||
methods=["PUT"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/input_schema",
|
||||
endpoint=self.get_graph_input_schema,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/execute",
|
||||
endpoint=self.execute_graph,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
endpoint=self.list_graph_runs,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
endpoint=self.get_graph_run_node_execution_results,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
endpoint=self.stop_graph_run,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
endpoint=self.create_schedule,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
endpoint=self.get_execution_schedules,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/graphs/schedules/{schedule_id}",
|
||||
endpoint=self.update_schedule,
|
||||
methods=["PUT"],
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
api_router.add_api_route(
|
||||
path="/settings",
|
||||
endpoint=self.update_configuration,
|
||||
methods=["POST"],
|
||||
|
@ -198,7 +204,7 @@ class AgentServer(AppService):
|
|||
|
||||
app.add_exception_handler(500, self.handle_internal_http_error)
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(api_router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
|
||||
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
from autogpt_libs.auth import auth_middleware
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from fastapi import Depends, HTTPException
|
||||
from supabase import Client, create_client
|
||||
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
||||
|
@ -13,3 +17,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
|||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
return user_id
|
||||
|
||||
|
||||
def get_supabase() -> Client:
|
||||
return create_client(settings.secrets.supabase_url, settings.secrets.supabase_key)
|
|
@ -93,6 +93,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
"""Secrets for the server."""
|
||||
|
||||
supabase_url: str = Field(default="", description="Supabase URL")
|
||||
supabase_key: str = Field(default="", description="Supabase key")
|
||||
|
||||
# OAuth server credentials for integrations
|
||||
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
|
||||
github_client_secret: str = Field(
|
||||
default="", description="GitHub OAuth client secret"
|
||||
)
|
||||
google_client_id: str = Field(default="", description="Google OAuth client ID")
|
||||
google_client_secret: str = Field(
|
||||
default="", description="Google OAuth client secret"
|
||||
)
|
||||
notion_client_id: str = Field(default="", description="Notion OAuth client ID")
|
||||
notion_client_secret: str = Field(
|
||||
default="", description="Notion OAuth client secret"
|
||||
)
|
||||
|
||||
openai_api_key: str = Field(default="", description="OpenAI API key")
|
||||
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
||||
groq_api_key: str = Field(default="", description="Groq API key")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "agpt"
|
||||
|
@ -25,7 +25,7 @@ requests = "*"
|
|||
sentry-sdk = "^1.40.4"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
|
@ -386,7 +386,7 @@ watchdog = "4.0.0"
|
|||
webdriver-manager = "^4.0.1"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
|
@ -3429,6 +3429,8 @@ files = [
|
|||
{file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"},
|
||||
{file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"},
|
||||
{file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"},
|
||||
{file = "orjson-3.10.6-cp313-none-win32.whl", hash = "sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"},
|
||||
{file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"},
|
||||
{file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"},
|
||||
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"},
|
||||
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"},
|
||||
|
@ -6452,4 +6454,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "126731188e8fdc7df0bd2dc92cd069fcd2b90edd5e1065cebd8f4adcedf982b5"
|
||||
content-hash = "0ecd19c5cdf414368aa81b83ae76ba6db34b1bfc8f32a482d1222d6b839792da"
|
||||
|
|
|
@ -11,6 +11,7 @@ readme = "README.md"
|
|||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
agpt = { path = "../../autogpt", develop = true }
|
||||
aio-pika = "^9.4.3"
|
||||
anthropic = "^0.25.1"
|
||||
apscheduler = "^3.10.4"
|
||||
autogpt-forge = { path = "../../forge", develop = true }
|
||||
|
@ -39,15 +40,14 @@ pyro5 = "^5.15"
|
|||
pytest = "^8.2.1"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
python-dotenv = "^1.0.1"
|
||||
redis = "^5.0.8"
|
||||
sentry-sdk = "1.45.0"
|
||||
supabase = "^2.7.2"
|
||||
tenacity = "^8.3.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.30.1" }
|
||||
websockets = "^12.0"
|
||||
youtube-transcript-api = "^0.6.2"
|
||||
|
||||
aio-pika = "^9.4.3"
|
||||
redis = "^5.0.8"
|
||||
sentry-sdk = "1.45.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
poethepoet = "^0.26.1"
|
||||
httpx = "^0.27.0"
|
||||
|
|
Loading…
Reference in New Issue