forked from Significant-Gravitas/AutoGPT
Compare commits
46 Commits
master
...
ntindle/op
Author | SHA1 | Date |
---|---|---|
|
445dff9633 | |
|
56612f16cf | |
|
0d2bb46786 | |
|
c61317e448 | |
|
3c30783b14 | |
|
56b33327ab | |
|
c36c239dd5 | |
|
e53f1eaf80 | |
|
ade7d1a194 | |
|
23672e3ed2 | |
|
0029e255c1 | |
|
04915f2db0 | |
|
9d79bfadea | |
|
c1c6bb29df | |
|
5f50c4863d | |
|
7bf470f4df | |
|
2fd8c8d261 | |
|
6b31356264 | |
|
a88c865437 | |
|
287aa819bb | |
|
db21c6d4bc | |
|
59dd75d016 | |
|
38761f6706 | |
|
513e4eae4b | |
|
fec9d348a0 | |
|
75634e6155 | |
|
6cf77c264a | |
|
2b5c94d508 | |
|
1c6b33d9fb | |
|
d4692f33e2 | |
|
d7a9563d49 | |
|
2ea61f8b65 | |
|
2a5f3d167d | |
|
c0a5a01311 | |
|
0aee309f72 | |
|
4c07f6c633 | |
|
c39f27bcd4 | |
|
35dcc6a2a1 | |
|
bef5637f29 | |
|
e933502cbd | |
|
5720225a75 | |
|
cb3808cb78 | |
|
b6b97f10b8 | |
|
0a905c6d66 | |
|
6b3f5b413f | |
|
8d79a62f61 |
|
@ -15,6 +15,9 @@ REDIS_PORT=6379
|
||||||
REDIS_PASSWORD=password
|
REDIS_PASSWORD=password
|
||||||
|
|
||||||
ENABLE_CREDIT=false
|
ENABLE_CREDIT=false
|
||||||
|
STRIPE_API_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
# What environment things should be logged under: local dev or prod
|
# What environment things should be logged under: local dev or prod
|
||||||
APP_ENV=local
|
APP_ENV=local
|
||||||
# What environment to behave as: "local" or "cloud"
|
# What environment to behave as: "local" or "cloud"
|
||||||
|
@ -36,7 +39,7 @@ SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||||
## to use the platform's webhook-related functionality.
|
## to use the platform's webhook-related functionality.
|
||||||
## If you are developing locally, you can use something like ngrok to get a publc URL
|
## If you are developing locally, you can use something like ngrok to get a publc URL
|
||||||
## and tunnel it to your locally running backend.
|
## and tunnel it to your locally running backend.
|
||||||
PLATFORM_BASE_URL=https://your-public-url-here
|
PLATFORM_BASE_URL=http://localhost:3000
|
||||||
|
|
||||||
## == INTEGRATION CREDENTIALS == ##
|
## == INTEGRATION CREDENTIALS == ##
|
||||||
# Each set of server side credentials is required for the corresponding 3rd party
|
# Each set of server side credentials is required for the corresponding 3rd party
|
||||||
|
@ -72,6 +75,12 @@ GOOGLE_CLIENT_SECRET=
|
||||||
TWITTER_CLIENT_ID=
|
TWITTER_CLIENT_ID=
|
||||||
TWITTER_CLIENT_SECRET=
|
TWITTER_CLIENT_SECRET=
|
||||||
|
|
||||||
|
# Linear App
|
||||||
|
# Make a new workspace for your OAuth APP -- trust me
|
||||||
|
# https://linear.app/settings/api/applications/new
|
||||||
|
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
||||||
|
LINEAR_CLIENT_ID=
|
||||||
|
LINEAR_CLIENT_SECRET=
|
||||||
|
|
||||||
## ===== OPTIONAL API KEYS ===== ##
|
## ===== OPTIONAL API KEYS ===== ##
|
||||||
|
|
||||||
|
@ -82,10 +91,12 @@ GROQ_API_KEY=
|
||||||
OPEN_ROUTER_API_KEY=
|
OPEN_ROUTER_API_KEY=
|
||||||
|
|
||||||
# Reddit
|
# Reddit
|
||||||
|
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||||
|
# Choose "script" for the type
|
||||||
|
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||||
REDDIT_CLIENT_ID=
|
REDDIT_CLIENT_ID=
|
||||||
REDDIT_CLIENT_SECRET=
|
REDDIT_CLIENT_SECRET=
|
||||||
REDDIT_USERNAME=
|
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||||
REDDIT_PASSWORD=
|
|
||||||
|
|
||||||
# Discord
|
# Discord
|
||||||
DISCORD_BOT_TOKEN=
|
DISCORD_BOT_TOKEN=
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
|
import enum
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
from backend.util.text import TextFormatter
|
from backend.util.text import TextFormatter
|
||||||
|
from backend.util.type import convert
|
||||||
|
|
||||||
formatter = TextFormatter()
|
formatter = TextFormatter()
|
||||||
|
|
||||||
|
@ -590,3 +592,47 @@ class CreateListBlock(Block):
|
||||||
yield "list", input_data.values
|
yield "list", input_data.values
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", f"Failed to create list: {str(e)}"
|
yield "error", f"Failed to create list: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
class TypeOptions(enum.Enum):
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
LIST = "list"
|
||||||
|
DICTIONARY = "dictionary"
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalTypeConverterBlock(Block):
|
||||||
|
class Input(BlockSchema):
|
||||||
|
value: Any = SchemaField(
|
||||||
|
description="The value to convert to a universal type."
|
||||||
|
)
|
||||||
|
type: TypeOptions = SchemaField(description="The type to convert the value to.")
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
value: Any = SchemaField(description="The converted value.")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="95d1b990-ce13-4d88-9737-ba5c2070c97b",
|
||||||
|
description="This block is used to convert a value to a universal type.",
|
||||||
|
categories={BlockCategory.BASIC},
|
||||||
|
input_schema=UniversalTypeConverterBlock.Input,
|
||||||
|
output_schema=UniversalTypeConverterBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
converted_value = convert(
|
||||||
|
input_data.value,
|
||||||
|
{
|
||||||
|
TypeOptions.STRING: str,
|
||||||
|
TypeOptions.NUMBER: float,
|
||||||
|
TypeOptions.BOOLEAN: bool,
|
||||||
|
TypeOptions.LIST: list,
|
||||||
|
TypeOptions.DICTIONARY: dict,
|
||||||
|
}[input_data.type],
|
||||||
|
)
|
||||||
|
yield "value", converted_value
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Failed to convert value: {str(e)}"
|
||||||
|
|
|
@ -1,22 +1,53 @@
|
||||||
import smtplib
|
import smtplib
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="smtp",
|
||||||
|
username=SecretStr("mock-smtp-username"),
|
||||||
|
password=SecretStr("mock-smtp-password"),
|
||||||
|
title="Mock SMTP credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
SMTPCredentials = UserPasswordCredentials
|
||||||
|
SMTPCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.SMTP],
|
||||||
|
Literal["user_password"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class EmailCredentials(BaseModel):
|
def SMTPCredentialsField() -> SMTPCredentialsInput:
|
||||||
|
return CredentialsField(
|
||||||
|
description="The SMTP integration requires a username and password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SMTPConfig(BaseModel):
|
||||||
smtp_server: str = SchemaField(
|
smtp_server: str = SchemaField(
|
||||||
default="smtp.gmail.com", description="SMTP server address"
|
default="smtp.example.com", description="SMTP server address"
|
||||||
)
|
)
|
||||||
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
||||||
smtp_username: BlockSecret = SecretField(key="smtp_username")
|
|
||||||
smtp_password: BlockSecret = SecretField(key="smtp_password")
|
|
||||||
|
|
||||||
model_config = ConfigDict(title="Email Credentials")
|
model_config = ConfigDict(title="SMTP Config")
|
||||||
|
|
||||||
|
|
||||||
class SendEmailBlock(Block):
|
class SendEmailBlock(Block):
|
||||||
|
@ -30,10 +61,11 @@ class SendEmailBlock(Block):
|
||||||
body: str = SchemaField(
|
body: str = SchemaField(
|
||||||
description="Body of the email", placeholder="Enter the email body"
|
description="Body of the email", placeholder="Enter the email body"
|
||||||
)
|
)
|
||||||
creds: EmailCredentials = SchemaField(
|
config: SMTPConfig = SchemaField(
|
||||||
description="SMTP credentials",
|
description="SMTP Config",
|
||||||
default=EmailCredentials(),
|
default=SMTPConfig(),
|
||||||
)
|
)
|
||||||
|
credentials: SMTPCredentialsInput = SMTPCredentialsField()
|
||||||
|
|
||||||
class Output(BlockSchema):
|
class Output(BlockSchema):
|
||||||
status: str = SchemaField(description="Status of the email sending operation")
|
status: str = SchemaField(description="Status of the email sending operation")
|
||||||
|
@ -43,7 +75,6 @@ class SendEmailBlock(Block):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
disabled=True,
|
|
||||||
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
||||||
description="This block sends an email using the provided SMTP credentials.",
|
description="This block sends an email using the provided SMTP credentials.",
|
||||||
categories={BlockCategory.OUTPUT},
|
categories={BlockCategory.OUTPUT},
|
||||||
|
@ -53,25 +84,29 @@ class SendEmailBlock(Block):
|
||||||
"to_email": "recipient@example.com",
|
"to_email": "recipient@example.com",
|
||||||
"subject": "Test Email",
|
"subject": "Test Email",
|
||||||
"body": "This is a test email.",
|
"body": "This is a test email.",
|
||||||
"creds": {
|
"config": {
|
||||||
"smtp_server": "smtp.gmail.com",
|
"smtp_server": "smtp.gmail.com",
|
||||||
"smtp_port": 25,
|
"smtp_port": 25,
|
||||||
"smtp_username": "your-email@gmail.com",
|
|
||||||
"smtp_password": "your-gmail-password",
|
|
||||||
},
|
},
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("status", "Email sent successfully")],
|
test_output=[("status", "Email sent successfully")],
|
||||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def send_email(
|
def send_email(
|
||||||
creds: EmailCredentials, to_email: str, subject: str, body: str
|
config: SMTPConfig,
|
||||||
|
to_email: str,
|
||||||
|
subject: str,
|
||||||
|
body: str,
|
||||||
|
credentials: SMTPCredentials,
|
||||||
) -> str:
|
) -> str:
|
||||||
smtp_server = creds.smtp_server
|
smtp_server = config.smtp_server
|
||||||
smtp_port = creds.smtp_port
|
smtp_port = config.smtp_port
|
||||||
smtp_username = creds.smtp_username.get_secret_value()
|
smtp_username = credentials.username.get_secret_value()
|
||||||
smtp_password = creds.smtp_password.get_secret_value()
|
smtp_password = credentials.password.get_secret_value()
|
||||||
|
|
||||||
msg = MIMEMultipart()
|
msg = MIMEMultipart()
|
||||||
msg["From"] = smtp_username
|
msg["From"] = smtp_username
|
||||||
|
@ -86,10 +121,13 @@ class SendEmailBlock(Block):
|
||||||
|
|
||||||
return "Email sent successfully"
|
return "Email sent successfully"
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
yield "status", self.send_email(
|
yield "status", self.send_email(
|
||||||
input_data.creds,
|
config=input_data.config,
|
||||||
input_data.to_email,
|
to_email=input_data.to_email,
|
||||||
input_data.subject,
|
subject=input_data.subject,
|
||||||
input_data.body,
|
body=input_data.body,
|
||||||
|
credentials=credentials,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ def _convert_to_api_url(url: str) -> str:
|
||||||
|
|
||||||
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
"Authorization": credentials.bearer(),
|
"Authorization": credentials.auth_header(),
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,272 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from backend.blocks.linear._auth import LinearCredentials
|
||||||
|
from backend.blocks.linear.models import (
|
||||||
|
CreateCommentResponse,
|
||||||
|
CreateIssueResponse,
|
||||||
|
Issue,
|
||||||
|
Project,
|
||||||
|
)
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAPIException(Exception):
|
||||||
|
def __init__(self, message: str, status_code: int):
|
||||||
|
super().__init__(message)
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
|
||||||
|
class LinearClient:
|
||||||
|
"""Client for the Linear API
|
||||||
|
|
||||||
|
If you're looking for the schema: https://studio.apollographql.com/public/Linear-API/variant/current/schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
API_URL = "https://api.linear.app/graphql"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
credentials: LinearCredentials | None = None,
|
||||||
|
custom_requests: Optional[Requests] = None,
|
||||||
|
):
|
||||||
|
if custom_requests:
|
||||||
|
self._requests = custom_requests
|
||||||
|
else:
|
||||||
|
|
||||||
|
headers: Dict[str, str] = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
if credentials:
|
||||||
|
headers["Authorization"] = credentials.bearer()
|
||||||
|
|
||||||
|
self._requests = Requests(
|
||||||
|
extra_headers=headers,
|
||||||
|
trusted_origins=["https://api.linear.app"],
|
||||||
|
raise_for_status=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _execute_graphql_request(
|
||||||
|
self, query: str, variables: dict | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Executes a GraphQL request against the Linear API and returns the response data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The GraphQL query string.
|
||||||
|
variables (optional): Any GraphQL query variables
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parsed JSON response data, or raises a LinearAPIException on error.
|
||||||
|
"""
|
||||||
|
payload: Dict[str, Any] = {"query": query}
|
||||||
|
if variables:
|
||||||
|
payload["variables"] = variables
|
||||||
|
|
||||||
|
response = self._requests.post(self.API_URL, json=payload)
|
||||||
|
|
||||||
|
if not response.ok:
|
||||||
|
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
error_message = error_data.get("errors", [{}])[0].get("message", "")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_message = response.text
|
||||||
|
|
||||||
|
raise LinearAPIException(
|
||||||
|
f"Linear API request failed ({response.status_code}): {error_message}",
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
if "errors" in response_data:
|
||||||
|
|
||||||
|
error_messages = [
|
||||||
|
error.get("message", "") for error in response_data["errors"]
|
||||||
|
]
|
||||||
|
raise LinearAPIException(
|
||||||
|
f"Linear API returned errors: {', '.join(error_messages)}",
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_data["data"]
|
||||||
|
|
||||||
|
def query(self, query: str, variables: Optional[dict] = None) -> dict:
|
||||||
|
"""Executes a GraphQL query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The GraphQL query string.
|
||||||
|
variables: Query variables, if any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response data.
|
||||||
|
"""
|
||||||
|
return self._execute_graphql_request(query, variables)
|
||||||
|
|
||||||
|
def mutate(self, mutation: str, variables: Optional[dict] = None) -> dict:
|
||||||
|
"""Executes a GraphQL mutation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mutation: The GraphQL mutation string.
|
||||||
|
variables: Query variables, if any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response data.
|
||||||
|
"""
|
||||||
|
return self._execute_graphql_request(mutation, variables)
|
||||||
|
|
||||||
|
def try_create_comment(self, issue_id: str, comment: str) -> CreateCommentResponse:
|
||||||
|
try:
|
||||||
|
mutation = """
|
||||||
|
mutation CommentCreate($input: CommentCreateInput!) {
|
||||||
|
commentCreate(input: $input) {
|
||||||
|
success
|
||||||
|
comment {
|
||||||
|
id
|
||||||
|
body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
variables = {
|
||||||
|
"input": {
|
||||||
|
"body": comment,
|
||||||
|
"issueId": issue_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
added_comment = self.mutate(mutation, variables)
|
||||||
|
# Select the commentCreate field from the mutation response
|
||||||
|
return CreateCommentResponse(**added_comment["commentCreate"])
|
||||||
|
except LinearAPIException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def try_get_team_by_name(self, team_name: str) -> str:
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
query GetTeamId($searchTerm: String!) {
|
||||||
|
teams(filter: {
|
||||||
|
or: [
|
||||||
|
{ name: { eqIgnoreCase: $searchTerm } },
|
||||||
|
{ key: { eqIgnoreCase: $searchTerm } }
|
||||||
|
]
|
||||||
|
}) {
|
||||||
|
nodes {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
variables: dict[str, Any] = {
|
||||||
|
"searchTerm": team_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
team_id = self.query(query, variables)
|
||||||
|
return team_id["teams"]["nodes"][0]["id"]
|
||||||
|
except LinearAPIException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def try_create_issue(
|
||||||
|
self,
|
||||||
|
team_id: str,
|
||||||
|
title: str,
|
||||||
|
description: str | None = None,
|
||||||
|
priority: int | None = None,
|
||||||
|
project_id: str | None = None,
|
||||||
|
) -> CreateIssueResponse:
|
||||||
|
try:
|
||||||
|
mutation = """
|
||||||
|
mutation IssueCreate($input: IssueCreateInput!) {
|
||||||
|
issueCreate(input: $input) {
|
||||||
|
issue {
|
||||||
|
title
|
||||||
|
description
|
||||||
|
id
|
||||||
|
identifier
|
||||||
|
priority
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
variables: dict[str, Any] = {
|
||||||
|
"input": {
|
||||||
|
"teamId": team_id,
|
||||||
|
"title": title,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
variables["input"]["projectId"] = project_id
|
||||||
|
|
||||||
|
if description:
|
||||||
|
variables["input"]["description"] = description
|
||||||
|
|
||||||
|
if priority:
|
||||||
|
variables["input"]["priority"] = priority
|
||||||
|
|
||||||
|
added_issue = self.mutate(mutation, variables)
|
||||||
|
return CreateIssueResponse(**added_issue["issueCreate"])
|
||||||
|
except LinearAPIException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def try_search_projects(self, term: str) -> list[Project]:
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
query SearchProjects($term: String!, $includeComments: Boolean!) {
|
||||||
|
searchProjects(term: $term, includeComments: $includeComments) {
|
||||||
|
nodes {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
description
|
||||||
|
priority
|
||||||
|
progress
|
||||||
|
content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
variables: dict[str, Any] = {
|
||||||
|
"term": term,
|
||||||
|
"includeComments": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
projects = self.query(query, variables)
|
||||||
|
return [
|
||||||
|
Project(**project) for project in projects["searchProjects"]["nodes"]
|
||||||
|
]
|
||||||
|
except LinearAPIException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def try_search_issues(self, term: str) -> list[Issue]:
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||||
|
searchIssues(term: $term, includeComments: $includeComments) {
|
||||||
|
nodes {
|
||||||
|
id
|
||||||
|
identifier
|
||||||
|
title
|
||||||
|
description
|
||||||
|
priority
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
variables: dict[str, Any] = {
|
||||||
|
"term": term,
|
||||||
|
"includeComments": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
issues = self.query(query, variables)
|
||||||
|
return [Issue(**issue) for issue in issues["searchIssues"]["nodes"]]
|
||||||
|
except LinearAPIException as e:
|
||||||
|
raise e
|
|
@ -0,0 +1,101 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import (
|
||||||
|
APIKeyCredentials,
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
OAuth2Credentials,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
|
secrets = Secrets()
|
||||||
|
LINEAR_OAUTH_IS_CONFIGURED = bool(
|
||||||
|
secrets.linear_client_id and secrets.linear_client_secret
|
||||||
|
)
|
||||||
|
|
||||||
|
LinearCredentials = OAuth2Credentials | APIKeyCredentials
|
||||||
|
# LinearCredentialsInput = CredentialsMetaInput[
|
||||||
|
# Literal[ProviderName.LINEAR],
|
||||||
|
# Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
|
||||||
|
# ]
|
||||||
|
LinearCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.LINEAR], Literal["oauth2"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# (required) Comma separated list of scopes:
|
||||||
|
|
||||||
|
# read - (Default) Read access for the user's account. This scope will always be present.
|
||||||
|
|
||||||
|
# write - Write access for the user's account. If your application only needs to create comments, use a more targeted scope
|
||||||
|
|
||||||
|
# issues:create - Allows creating new issues and their attachments
|
||||||
|
|
||||||
|
# comments:create - Allows creating new issue comments
|
||||||
|
|
||||||
|
# timeSchedule:write - Allows creating and modifying time schedules
|
||||||
|
|
||||||
|
|
||||||
|
# admin - Full access to admin level endpoints. You should never ask for this permission unless it's absolutely needed
|
||||||
|
class LinearScope(str, Enum):
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
ISSUES_CREATE = "issues:create"
|
||||||
|
COMMENTS_CREATE = "comments:create"
|
||||||
|
TIME_SCHEDULE_WRITE = "timeSchedule:write"
|
||||||
|
ADMIN = "admin"
|
||||||
|
|
||||||
|
|
||||||
|
def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
|
||||||
|
"""
|
||||||
|
Creates a Linear credentials input on a block.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
|
||||||
|
""" # noqa
|
||||||
|
return CredentialsField(
|
||||||
|
required_scopes=set([LinearScope.READ.value]).union(
|
||||||
|
set([scope.value for scope in scopes])
|
||||||
|
),
|
||||||
|
description="The Linear integration can be used with OAuth, "
|
||||||
|
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="linear",
|
||||||
|
title="Mock Linear API key",
|
||||||
|
username="mock-linear-username",
|
||||||
|
access_token=SecretStr("mock-linear-access-token"),
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token=SecretStr("mock-linear-refresh-token"),
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=["mock-linear-scopes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_API_KEY = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="linear",
|
||||||
|
title="Mock Linear API key",
|
||||||
|
api_key=SecretStr("mock-linear-api-key"),
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT_OAUTH = {
|
||||||
|
"provider": TEST_CREDENTIALS_OAUTH.provider,
|
||||||
|
"id": TEST_CREDENTIALS_OAUTH.id,
|
||||||
|
"type": TEST_CREDENTIALS_OAUTH.type,
|
||||||
|
"title": TEST_CREDENTIALS_OAUTH.type,
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT_API_KEY = {
|
||||||
|
"provider": TEST_CREDENTIALS_API_KEY.provider,
|
||||||
|
"id": TEST_CREDENTIALS_API_KEY.id,
|
||||||
|
"type": TEST_CREDENTIALS_API_KEY.type,
|
||||||
|
"title": TEST_CREDENTIALS_API_KEY.type,
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||||
|
from backend.blocks.linear._auth import (
|
||||||
|
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
TEST_CREDENTIALS_OAUTH,
|
||||||
|
LinearCredentials,
|
||||||
|
LinearCredentialsField,
|
||||||
|
LinearCredentialsInput,
|
||||||
|
LinearScope,
|
||||||
|
)
|
||||||
|
from backend.blocks.linear.models import CreateCommentResponse
|
||||||
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCreateCommentBlock(Block):
|
||||||
|
"""Block for creating comments on Linear issues"""
|
||||||
|
|
||||||
|
class Input(BlockSchema):
|
||||||
|
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||||
|
scopes=[LinearScope.COMMENTS_CREATE],
|
||||||
|
)
|
||||||
|
issue_id: str = SchemaField(description="ID of the issue to comment on")
|
||||||
|
comment: str = SchemaField(description="Comment text to add to the issue")
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
comment_id: str = SchemaField(description="ID of the created comment")
|
||||||
|
comment_body: str = SchemaField(
|
||||||
|
description="Text content of the created comment"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if comment creation failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8f7d3a2e-9b5c-4c6a-8f1d-7c8b3e4a5d6c",
|
||||||
|
description="Creates a new comment on a Linear issue",
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||||
|
test_input={
|
||||||
|
"issue_id": "TEST-123",
|
||||||
|
"comment": "Test comment",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
|
test_output=[("comment_id", "abc123"), ("comment_body", "Test comment")],
|
||||||
|
test_mock={
|
||||||
|
"create_comment": lambda *args, **kwargs: (
|
||||||
|
"abc123",
|
||||||
|
"Test comment",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_comment(
|
||||||
|
credentials: LinearCredentials, issue_id: str, comment: str
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
client = LinearClient(credentials=credentials)
|
||||||
|
response: CreateCommentResponse = client.try_create_comment(
|
||||||
|
issue_id=issue_id, comment=comment
|
||||||
|
)
|
||||||
|
return response.comment.id, response.comment.body
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
"""Execute the comment creation"""
|
||||||
|
try:
|
||||||
|
comment_id, comment_body = self.create_comment(
|
||||||
|
credentials=credentials,
|
||||||
|
issue_id=input_data.issue_id,
|
||||||
|
comment=input_data.comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "comment_id", comment_id
|
||||||
|
yield "comment_body", comment_body
|
||||||
|
|
||||||
|
except LinearAPIException as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Unexpected error: {str(e)}"
|
|
@ -0,0 +1,186 @@
|
||||||
|
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||||
|
from backend.blocks.linear._auth import (
|
||||||
|
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
TEST_CREDENTIALS_OAUTH,
|
||||||
|
LinearCredentials,
|
||||||
|
LinearCredentialsField,
|
||||||
|
LinearCredentialsInput,
|
||||||
|
LinearScope,
|
||||||
|
)
|
||||||
|
from backend.blocks.linear.models import CreateIssueResponse, Issue
|
||||||
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCreateIssueBlock(Block):
|
||||||
|
"""Block for creating issues on Linear"""
|
||||||
|
|
||||||
|
class Input(BlockSchema):
|
||||||
|
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||||
|
scopes=[LinearScope.ISSUES_CREATE],
|
||||||
|
)
|
||||||
|
title: str = SchemaField(description="Title of the issue")
|
||||||
|
description: str | None = SchemaField(description="Description of the issue")
|
||||||
|
team_name: str = SchemaField(
|
||||||
|
description="Name of the team to create the issue on"
|
||||||
|
)
|
||||||
|
priority: int | None = SchemaField(
|
||||||
|
description="Priority of the issue",
|
||||||
|
default=None,
|
||||||
|
minimum=0,
|
||||||
|
maximum=4,
|
||||||
|
)
|
||||||
|
project_name: str | None = SchemaField(
|
||||||
|
description="Name of the project to create the issue on",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
issue_id: str = SchemaField(description="ID of the created issue")
|
||||||
|
issue_title: str = SchemaField(description="Title of the created issue")
|
||||||
|
error: str = SchemaField(description="Error message if issue creation failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="f9c68f55-dcca-40a8-8771-abf9601680aa",
|
||||||
|
description="Creates a new issue on Linear",
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||||
|
test_input={
|
||||||
|
"title": "Test issue",
|
||||||
|
"description": "Test description",
|
||||||
|
"team_name": "Test team",
|
||||||
|
"project_name": "Test project",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
|
test_output=[("issue_id", "abc123"), ("issue_title", "Test issue")],
|
||||||
|
test_mock={
|
||||||
|
"create_issue": lambda *args, **kwargs: (
|
||||||
|
"abc123",
|
||||||
|
"Test issue",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_issue(
|
||||||
|
credentials: LinearCredentials,
|
||||||
|
team_name: str,
|
||||||
|
title: str,
|
||||||
|
description: str | None = None,
|
||||||
|
priority: int | None = None,
|
||||||
|
project_name: str | None = None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
client = LinearClient(credentials=credentials)
|
||||||
|
team_id = client.try_get_team_by_name(team_name=team_name)
|
||||||
|
project_id: str | None = None
|
||||||
|
if project_name:
|
||||||
|
projects = client.try_search_projects(term=project_name)
|
||||||
|
if projects:
|
||||||
|
project_id = projects[0].id
|
||||||
|
else:
|
||||||
|
raise LinearAPIException("Project not found", status_code=404)
|
||||||
|
response: CreateIssueResponse = client.try_create_issue(
|
||||||
|
team_id=team_id,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
priority=priority,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
return response.issue.identifier, response.issue.title
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
"""Execute the issue creation"""
|
||||||
|
try:
|
||||||
|
issue_id, issue_title = self.create_issue(
|
||||||
|
credentials=credentials,
|
||||||
|
team_name=input_data.team_name,
|
||||||
|
title=input_data.title,
|
||||||
|
description=input_data.description,
|
||||||
|
priority=input_data.priority,
|
||||||
|
project_name=input_data.project_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "issue_id", issue_id
|
||||||
|
yield "issue_title", issue_title
|
||||||
|
|
||||||
|
except LinearAPIException as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearSearchIssuesBlock(Block):
|
||||||
|
"""Block for searching issues on Linear"""
|
||||||
|
|
||||||
|
class Input(BlockSchema):
|
||||||
|
term: str = SchemaField(description="Term to search for issues")
|
||||||
|
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||||
|
scopes=[LinearScope.READ],
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
issues: list[Issue] = SchemaField(description="List of issues")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="b5a2a0e6-26b4-4c5b-8a42-bc79e9cb65c2",
|
||||||
|
description="Searches for issues on Linear",
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"term": "Test issue",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"issues",
|
||||||
|
[
|
||||||
|
Issue(
|
||||||
|
id="abc123",
|
||||||
|
identifier="abc123",
|
||||||
|
title="Test issue",
|
||||||
|
description="Test description",
|
||||||
|
priority=1,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"search_issues": lambda *args, **kwargs: [
|
||||||
|
Issue(
|
||||||
|
id="abc123",
|
||||||
|
identifier="abc123",
|
||||||
|
title="Test issue",
|
||||||
|
description="Test description",
|
||||||
|
priority=1,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def search_issues(
|
||||||
|
credentials: LinearCredentials,
|
||||||
|
term: str,
|
||||||
|
) -> list[Issue]:
|
||||||
|
client = LinearClient(credentials=credentials)
|
||||||
|
response: list[Issue] = client.try_search_issues(term=term)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
"""Execute the issue search"""
|
||||||
|
try:
|
||||||
|
issues = self.search_issues(credentials=credentials, term=input_data.term)
|
||||||
|
yield "issues", issues
|
||||||
|
except LinearAPIException as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Unexpected error: {str(e)}"
|
|
@ -0,0 +1,41 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Comment(BaseModel):
|
||||||
|
id: str
|
||||||
|
body: str
|
||||||
|
|
||||||
|
|
||||||
|
class CreateCommentInput(BaseModel):
|
||||||
|
body: str
|
||||||
|
issueId: str
|
||||||
|
|
||||||
|
|
||||||
|
class CreateCommentResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
comment: Comment
|
||||||
|
|
||||||
|
|
||||||
|
class CreateCommentResponseWrapper(BaseModel):
|
||||||
|
commentCreate: CreateCommentResponse
|
||||||
|
|
||||||
|
|
||||||
|
class Issue(BaseModel):
|
||||||
|
id: str
|
||||||
|
identifier: str
|
||||||
|
title: str
|
||||||
|
description: str | None
|
||||||
|
priority: int
|
||||||
|
|
||||||
|
|
||||||
|
class CreateIssueResponse(BaseModel):
|
||||||
|
issue: Issue
|
||||||
|
|
||||||
|
|
||||||
|
class Project(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
priority: int
|
||||||
|
progress: int
|
||||||
|
content: str
|
|
@ -0,0 +1,93 @@
|
||||||
|
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||||
|
from backend.blocks.linear._auth import (
|
||||||
|
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
TEST_CREDENTIALS_OAUTH,
|
||||||
|
LinearCredentials,
|
||||||
|
LinearCredentialsField,
|
||||||
|
LinearCredentialsInput,
|
||||||
|
LinearScope,
|
||||||
|
)
|
||||||
|
from backend.blocks.linear.models import Project
|
||||||
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class LinearSearchProjectsBlock(Block):
|
||||||
|
"""Block for searching projects on Linear"""
|
||||||
|
|
||||||
|
class Input(BlockSchema):
|
||||||
|
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||||
|
scopes=[LinearScope.READ],
|
||||||
|
)
|
||||||
|
term: str = SchemaField(description="Term to search for projects")
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
projects: list[Project] = SchemaField(description="List of projects")
|
||||||
|
error: str = SchemaField(description="Error message if issue creation failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="446a1d35-9d8f-4ac5-83ea-7684ec50e6af",
|
||||||
|
description="Searches for projects on Linear",
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||||
|
test_input={
|
||||||
|
"term": "Test project",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"projects",
|
||||||
|
[
|
||||||
|
Project(
|
||||||
|
id="abc123",
|
||||||
|
name="Test project",
|
||||||
|
description="Test description",
|
||||||
|
priority=1,
|
||||||
|
progress=1,
|
||||||
|
content="Test content",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"search_projects": lambda *args, **kwargs: [
|
||||||
|
Project(
|
||||||
|
id="abc123",
|
||||||
|
name="Test project",
|
||||||
|
description="Test description",
|
||||||
|
priority=1,
|
||||||
|
progress=1,
|
||||||
|
content="Test content",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def search_projects(
|
||||||
|
credentials: LinearCredentials,
|
||||||
|
term: str,
|
||||||
|
) -> list[Project]:
|
||||||
|
client = LinearClient(credentials=credentials)
|
||||||
|
response: list[Project] = client.try_search_projects(term=term)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
"""Execute the project search"""
|
||||||
|
try:
|
||||||
|
projects = self.search_projects(
|
||||||
|
credentials=credentials,
|
||||||
|
term=input_data.term,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "projects", projects
|
||||||
|
|
||||||
|
except LinearAPIException as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Unexpected error: {str(e)}"
|
|
@ -1,22 +1,48 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Iterator
|
from typing import Iterator, Literal
|
||||||
|
|
||||||
import praw
|
import praw
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
RedditCredentials = UserPasswordCredentials
|
||||||
|
RedditCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.REDDIT],
|
||||||
|
Literal["user_password"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RedditCredentials(BaseModel):
|
def RedditCredentialsField() -> RedditCredentialsInput:
|
||||||
client_id: BlockSecret = SecretField(key="reddit_client_id")
|
"""Creates a Reddit credentials input on a block."""
|
||||||
client_secret: BlockSecret = SecretField(key="reddit_client_secret")
|
return CredentialsField(
|
||||||
username: BlockSecret = SecretField(key="reddit_username")
|
description="The Reddit integration requires a username and password.",
|
||||||
password: BlockSecret = SecretField(key="reddit_password")
|
)
|
||||||
user_agent: str = "AutoGPT:1.0 (by /u/autogpt)"
|
|
||||||
|
|
||||||
model_config = ConfigDict(title="Reddit Credentials")
|
|
||||||
|
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="reddit",
|
||||||
|
username=SecretStr("mock-reddit-username"),
|
||||||
|
password=SecretStr("mock-reddit-password"),
|
||||||
|
title="Mock Reddit credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RedditPost(BaseModel):
|
class RedditPost(BaseModel):
|
||||||
|
@ -31,13 +57,16 @@ class RedditComment(BaseModel):
|
||||||
comment: str
|
comment: str
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||||
client = praw.Reddit(
|
client = praw.Reddit(
|
||||||
client_id=creds.client_id.get_secret_value(),
|
client_id=settings.secrets.reddit_client_id,
|
||||||
client_secret=creds.client_secret.get_secret_value(),
|
client_secret=settings.secrets.reddit_client_secret,
|
||||||
username=creds.username.get_secret_value(),
|
username=creds.username.get_secret_value(),
|
||||||
password=creds.password.get_secret_value(),
|
password=creds.password.get_secret_value(),
|
||||||
user_agent=creds.user_agent,
|
user_agent=settings.config.reddit_user_agent,
|
||||||
)
|
)
|
||||||
me = client.user.me()
|
me = client.user.me()
|
||||||
if not me:
|
if not me:
|
||||||
|
@ -48,11 +77,11 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||||
|
|
||||||
class GetRedditPostsBlock(Block):
|
class GetRedditPostsBlock(Block):
|
||||||
class Input(BlockSchema):
|
class Input(BlockSchema):
|
||||||
subreddit: str = SchemaField(description="Subreddit name")
|
subreddit: str = SchemaField(
|
||||||
creds: RedditCredentials = SchemaField(
|
description="Subreddit name, excluding the /r/ prefix",
|
||||||
description="Reddit credentials",
|
default="writingprompts",
|
||||||
default=RedditCredentials(),
|
|
||||||
)
|
)
|
||||||
|
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||||
last_minutes: int | None = SchemaField(
|
last_minutes: int | None = SchemaField(
|
||||||
description="Post time to stop minutes ago while fetching posts",
|
description="Post time to stop minutes ago while fetching posts",
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -70,20 +99,18 @@ class GetRedditPostsBlock(Block):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
disabled=True,
|
|
||||||
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
||||||
description="This block fetches Reddit posts from a defined subreddit name.",
|
description="This block fetches Reddit posts from a defined subreddit name.",
|
||||||
categories={BlockCategory.SOCIAL},
|
categories={BlockCategory.SOCIAL},
|
||||||
|
disabled=(
|
||||||
|
not settings.secrets.reddit_client_id
|
||||||
|
or not settings.secrets.reddit_client_secret
|
||||||
|
),
|
||||||
input_schema=GetRedditPostsBlock.Input,
|
input_schema=GetRedditPostsBlock.Input,
|
||||||
output_schema=GetRedditPostsBlock.Output,
|
output_schema=GetRedditPostsBlock.Output,
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_input={
|
test_input={
|
||||||
"creds": {
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
"client_id": "client_id",
|
|
||||||
"client_secret": "client_secret",
|
|
||||||
"username": "username",
|
|
||||||
"password": "password",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
},
|
|
||||||
"subreddit": "subreddit",
|
"subreddit": "subreddit",
|
||||||
"last_post": "id3",
|
"last_post": "id3",
|
||||||
"post_limit": 2,
|
"post_limit": 2,
|
||||||
|
@ -103,7 +130,7 @@ class GetRedditPostsBlock(Block):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"get_posts": lambda _: [
|
"get_posts": lambda input_data, credentials: [
|
||||||
MockObject(id="id1", title="title1", selftext="body1"),
|
MockObject(id="id1", title="title1", selftext="body1"),
|
||||||
MockObject(id="id2", title="title2", selftext="body2"),
|
MockObject(id="id2", title="title2", selftext="body2"),
|
||||||
MockObject(id="id3", title="title2", selftext="body2"),
|
MockObject(id="id3", title="title2", selftext="body2"),
|
||||||
|
@ -112,14 +139,18 @@ class GetRedditPostsBlock(Block):
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_posts(input_data: Input) -> Iterator[praw.reddit.Submission]:
|
def get_posts(
|
||||||
client = get_praw(input_data.creds)
|
input_data: Input, *, credentials: RedditCredentials
|
||||||
|
) -> Iterator[praw.reddit.Submission]:
|
||||||
|
client = get_praw(credentials)
|
||||||
subreddit = client.subreddit(input_data.subreddit)
|
subreddit = client.subreddit(input_data.subreddit)
|
||||||
return subreddit.new(limit=input_data.post_limit or 10)
|
return subreddit.new(limit=input_data.post_limit or 10)
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
current_time = datetime.now(tz=timezone.utc)
|
current_time = datetime.now(tz=timezone.utc)
|
||||||
for post in self.get_posts(input_data):
|
for post in self.get_posts(input_data=input_data, credentials=credentials):
|
||||||
if input_data.last_minutes:
|
if input_data.last_minutes:
|
||||||
post_datetime = datetime.fromtimestamp(
|
post_datetime = datetime.fromtimestamp(
|
||||||
post.created_utc, tz=timezone.utc
|
post.created_utc, tz=timezone.utc
|
||||||
|
@ -141,9 +172,7 @@ class GetRedditPostsBlock(Block):
|
||||||
|
|
||||||
class PostRedditCommentBlock(Block):
|
class PostRedditCommentBlock(Block):
|
||||||
class Input(BlockSchema):
|
class Input(BlockSchema):
|
||||||
creds: RedditCredentials = SchemaField(
|
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||||
description="Reddit credentials", default=RedditCredentials()
|
|
||||||
)
|
|
||||||
data: RedditComment = SchemaField(description="Reddit comment")
|
data: RedditComment = SchemaField(description="Reddit comment")
|
||||||
|
|
||||||
class Output(BlockSchema):
|
class Output(BlockSchema):
|
||||||
|
@ -156,7 +185,15 @@ class PostRedditCommentBlock(Block):
|
||||||
categories={BlockCategory.SOCIAL},
|
categories={BlockCategory.SOCIAL},
|
||||||
input_schema=PostRedditCommentBlock.Input,
|
input_schema=PostRedditCommentBlock.Input,
|
||||||
output_schema=PostRedditCommentBlock.Output,
|
output_schema=PostRedditCommentBlock.Output,
|
||||||
test_input={"data": {"post_id": "id", "comment": "comment"}},
|
disabled=(
|
||||||
|
not settings.secrets.reddit_client_id
|
||||||
|
or not settings.secrets.reddit_client_secret
|
||||||
|
),
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_input={
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
"data": {"post_id": "id", "comment": "comment"},
|
||||||
|
},
|
||||||
test_output=[("comment_id", "dummy_comment_id")],
|
test_output=[("comment_id", "dummy_comment_id")],
|
||||||
test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
|
test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
|
||||||
)
|
)
|
||||||
|
@ -170,5 +207,7 @@ class PostRedditCommentBlock(Block):
|
||||||
raise ValueError("Failed to post comment.")
|
raise ValueError("Failed to post comment.")
|
||||||
return new_comment.id
|
return new_comment.id
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
yield "comment_id", self.reply_post(input_data.creds, input_data.data)
|
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
yield "comment_id", self.reply_post(credentials, input_data.data)
|
||||||
|
|
|
@ -64,6 +64,8 @@ class BlockCategory(Enum):
|
||||||
SAFETY = (
|
SAFETY = (
|
||||||
"Block that provides AI safety mechanisms such as detecting harmful content"
|
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||||
)
|
)
|
||||||
|
PRODUCTIVITY = "Block that helps with productivity"
|
||||||
|
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||||
|
|
||||||
def dict(self) -> dict[str, str]:
|
def dict(self) -> dict[str, str]:
|
||||||
return {"category": self.name, "description": self.value}
|
return {"category": self.name, "description": self.value}
|
||||||
|
|
|
@ -1,40 +1,40 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import stripe
|
||||||
from prisma import Json
|
from prisma import Json
|
||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.errors import UniqueViolationError
|
from prisma.errors import UniqueViolationError
|
||||||
from prisma.models import CreditTransaction
|
from prisma.models import CreditTransaction, User
|
||||||
|
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
|
||||||
|
|
||||||
|
from backend.data import db
|
||||||
from backend.data.block import Block, BlockInput, get_block
|
from backend.data.block import Block, BlockInput, get_block
|
||||||
from backend.data.block_cost_config import BLOCK_COSTS
|
from backend.data.block_cost_config import BLOCK_COSTS
|
||||||
from backend.data.cost import BlockCost, BlockCostType
|
from backend.data.cost import BlockCost, BlockCostType
|
||||||
from backend.util.settings import Config
|
from backend.data.execution import NodeExecutionEntry
|
||||||
|
from backend.data.user import get_user_by_id
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
config = Config()
|
settings = Settings()
|
||||||
|
stripe.api_key = settings.secrets.stripe_api_key
|
||||||
|
|
||||||
|
|
||||||
class UserCreditBase(ABC):
|
class UserCreditBase(ABC):
|
||||||
def __init__(self, num_user_credits_refill: int):
|
|
||||||
self.num_user_credits_refill = num_user_credits_refill
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_or_refill_credit(self, user_id: str) -> int:
|
async def get_credits(self, user_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
Get the current credit for the user and refill if no transaction has been made in the current cycle.
|
Get the current credits for the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: The current credit for the user.
|
int: The current credits for the user.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def spend_credits(
|
async def spend_credits(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
entry: NodeExecutionEntry,
|
||||||
user_credit: int,
|
|
||||||
block_id: str,
|
|
||||||
input_data: BlockInput,
|
|
||||||
data_size: float,
|
data_size: float,
|
||||||
run_time: float,
|
run_time: float,
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -42,10 +42,7 @@ class UserCreditBase(ABC):
|
||||||
Spend the credits for the user based on the block usage.
|
Spend the credits for the user based on the block usage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
entry (NodeExecutionEntry): The node execution identifiers & data.
|
||||||
user_credit (int): The current credit for the user.
|
|
||||||
block_id (str): The block ID.
|
|
||||||
input_data (BlockInput): The input data for the block.
|
|
||||||
data_size (float): The size of the data being processed.
|
data_size (float): The size of the data being processed.
|
||||||
run_time (float): The time taken to run the block.
|
run_time (float): The time taken to run the block.
|
||||||
|
|
||||||
|
@ -57,7 +54,7 @@ class UserCreditBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def top_up_credits(self, user_id: str, amount: int):
|
async def top_up_credits(self, user_id: str, amount: int):
|
||||||
"""
|
"""
|
||||||
Top up the credits for the user.
|
Top up the credits for the user immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
user_id (str): The user ID.
|
||||||
|
@ -65,51 +62,137 @@ class UserCreditBase(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def top_up_intent(self, user_id: str, amount: int) -> str:
|
||||||
|
"""
|
||||||
|
Create a payment intent to top up the credits for the user.
|
||||||
|
|
||||||
class UserCredit(UserCreditBase):
|
Args:
|
||||||
async def get_or_refill_credit(self, user_id: str) -> int:
|
user_id (str): The user ID.
|
||||||
cur_time = self.time_now()
|
amount (int): The amount of credits to top up.
|
||||||
cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
nxt_month = (
|
Returns:
|
||||||
cur_month.replace(month=cur_month.month + 1)
|
str: The redirect url to the payment page.
|
||||||
if cur_month.month < 12
|
"""
|
||||||
else cur_month.replace(year=cur_month.year + 1, month=1)
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def fulfill_checkout(
|
||||||
|
self, *, session_id: str | None = None, user_id: str | None = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Fulfill the Stripe checkout session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id (str | None): The checkout session ID. Will try to fulfill most recent if None.
|
||||||
|
user_id (str | None): The user ID must be provided if session_id is None.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def time_now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# ====== Transaction Helper Methods ====== #
|
||||||
|
# Any modifications to the transaction table should only be done through these methods #
|
||||||
|
|
||||||
|
async def _get_credits(self, user_id: str) -> tuple[int, datetime]:
|
||||||
|
"""
|
||||||
|
Returns the current balance of the user & the latest balance snapshot time.
|
||||||
|
"""
|
||||||
|
top_time = self.time_now()
|
||||||
|
snapshot = await CreditTransaction.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"userId": user_id,
|
||||||
|
"createdAt": {"lte": top_time},
|
||||||
|
"isActive": True,
|
||||||
|
"runningBalance": {"not": None}, # type: ignore
|
||||||
|
},
|
||||||
|
order={"createdAt": "desc"},
|
||||||
)
|
)
|
||||||
|
if snapshot:
|
||||||
|
return snapshot.runningBalance or 0, snapshot.createdAt
|
||||||
|
|
||||||
user_credit = await CreditTransaction.prisma().group_by(
|
# No snapshot: Manually calculate balance using current month's transactions.
|
||||||
|
low_time = top_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
transactions = await CreditTransaction.prisma().group_by(
|
||||||
by=["userId"],
|
by=["userId"],
|
||||||
sum={"amount": True},
|
sum={"amount": True},
|
||||||
where={
|
where={
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
"createdAt": {"gte": cur_month, "lt": nxt_month},
|
"createdAt": {"gte": low_time, "lte": top_time},
|
||||||
"isActive": True,
|
"isActive": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
transaction_balance = (
|
||||||
|
transactions[0].get("_sum", {}).get("amount", 0) if transactions else 0
|
||||||
|
)
|
||||||
|
return transaction_balance, datetime.min
|
||||||
|
|
||||||
if user_credit:
|
async def _enable_transaction(
|
||||||
credit_sum = user_credit[0].get("_sum") or {}
|
self, transaction_key: str, user_id: str, metadata: Json
|
||||||
return credit_sum.get("amount", 0)
|
):
|
||||||
|
|
||||||
key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"
|
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||||
|
where={"transactionKey": transaction_key, "userId": user_id}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
if transaction.isActive:
|
||||||
await CreditTransaction.prisma().create(
|
return
|
||||||
|
|
||||||
|
async with db.locked_transaction(f"usr_trx_{user_id}"):
|
||||||
|
user_balance, _ = await self._get_credits(user_id)
|
||||||
|
|
||||||
|
await CreditTransaction.prisma().update(
|
||||||
|
where={
|
||||||
|
"creditTransactionIdentifier": {
|
||||||
|
"transactionKey": transaction_key,
|
||||||
|
"userId": user_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
data={
|
data={
|
||||||
"amount": self.num_user_credits_refill,
|
"isActive": True,
|
||||||
"type": CreditTransactionType.TOP_UP,
|
"runningBalance": user_balance + transaction.amount,
|
||||||
"userId": user_id,
|
|
||||||
"transactionKey": key,
|
|
||||||
"createdAt": self.time_now(),
|
"createdAt": self.time_now(),
|
||||||
}
|
"metadata": metadata,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except UniqueViolationError:
|
|
||||||
pass # Already refilled this month
|
|
||||||
|
|
||||||
return self.num_user_credits_refill
|
async def _add_transaction(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
amount: int,
|
||||||
|
transaction_type: CreditTransactionType,
|
||||||
|
is_active: bool = True,
|
||||||
|
transaction_key: str | None = None,
|
||||||
|
metadata: Json = Json({}),
|
||||||
|
):
|
||||||
|
async with db.locked_transaction(f"usr_trx_{user_id}"):
|
||||||
|
# Get latest balance snapshot
|
||||||
|
user_balance, _ = await self._get_credits(user_id)
|
||||||
|
if amount < 0 and user_balance < abs(amount):
|
||||||
|
raise ValueError(
|
||||||
|
f"Insufficient balance for user {user_id}, balance: {user_balance}, amount: {amount}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
# Create the transaction
|
||||||
def time_now():
|
transaction_data: CreditTransactionCreateInput = {
|
||||||
return datetime.now(timezone.utc)
|
"userId": user_id,
|
||||||
|
"amount": amount,
|
||||||
|
"runningBalance": user_balance + amount,
|
||||||
|
"type": transaction_type,
|
||||||
|
"metadata": metadata,
|
||||||
|
"isActive": is_active,
|
||||||
|
"createdAt": self.time_now(),
|
||||||
|
}
|
||||||
|
if transaction_key:
|
||||||
|
transaction_data["transactionKey"] = transaction_key
|
||||||
|
await CreditTransaction.prisma().create(data=transaction_data)
|
||||||
|
|
||||||
|
return user_balance + amount
|
||||||
|
|
||||||
|
|
||||||
|
class UserCredit(UserCreditBase):
|
||||||
|
|
||||||
def _block_usage_cost(
|
def _block_usage_cost(
|
||||||
self,
|
self,
|
||||||
|
@ -148,8 +231,8 @@ class UserCredit(UserCreditBase):
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Filter rules:
|
Filter rules:
|
||||||
- If costFilter is an object, then check if costFilter is the subset of inputValues
|
- If cost_filter is an object, then check if cost_filter is the subset of input_data
|
||||||
- Otherwise, check if costFilter is equal to inputValues.
|
- Otherwise, check if cost_filter is equal to input_data.
|
||||||
- Undefined, null, and empty string are considered as equal.
|
- Undefined, null, and empty string are considered as equal.
|
||||||
"""
|
"""
|
||||||
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
|
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
|
||||||
|
@ -163,57 +246,169 @@ class UserCredit(UserCreditBase):
|
||||||
|
|
||||||
async def spend_credits(
|
async def spend_credits(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
entry: NodeExecutionEntry,
|
||||||
user_credit: int,
|
|
||||||
block_id: str,
|
|
||||||
input_data: BlockInput,
|
|
||||||
data_size: float,
|
data_size: float,
|
||||||
run_time: float,
|
run_time: float,
|
||||||
validate_balance: bool = True,
|
|
||||||
) -> int:
|
) -> int:
|
||||||
block = get_block(block_id)
|
block = get_block(entry.block_id)
|
||||||
if not block:
|
if not block:
|
||||||
raise ValueError(f"Block not found: {block_id}")
|
raise ValueError(f"Block not found: {entry.block_id}")
|
||||||
|
|
||||||
cost, matching_filter = self._block_usage_cost(
|
cost, matching_filter = self._block_usage_cost(
|
||||||
block=block, input_data=input_data, data_size=data_size, run_time=run_time
|
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
|
||||||
)
|
)
|
||||||
if cost <= 0:
|
if cost == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if validate_balance and user_credit < cost:
|
await self._add_transaction(
|
||||||
raise ValueError(f"Insufficient credit: {user_credit} < {cost}")
|
user_id=entry.user_id,
|
||||||
|
amount=-cost,
|
||||||
await CreditTransaction.prisma().create(
|
transaction_type=CreditTransactionType.USAGE,
|
||||||
data={
|
metadata=Json(
|
||||||
"userId": user_id,
|
{
|
||||||
"amount": -cost,
|
"graph_exec_id": entry.graph_exec_id,
|
||||||
"type": CreditTransactionType.USAGE,
|
"graph_id": entry.graph_id,
|
||||||
"blockId": block.id,
|
"node_id": entry.node_id,
|
||||||
"metadata": Json(
|
"node_exec_id": entry.node_exec_id,
|
||||||
{
|
"block_id": entry.block_id,
|
||||||
"block": block.name,
|
"block": block.name,
|
||||||
"input": matching_filter,
|
"input": matching_filter,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"createdAt": self.time_now(),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
async def top_up_credits(self, user_id: str, amount: int):
|
async def top_up_credits(self, user_id: str, amount: int):
|
||||||
await CreditTransaction.prisma().create(
|
if amount < 0:
|
||||||
data={
|
raise ValueError(f"Top up amount must not be negative: {amount}")
|
||||||
"userId": user_id,
|
|
||||||
"amount": amount,
|
await self._add_transaction(
|
||||||
"type": CreditTransactionType.TOP_UP,
|
user_id=user_id,
|
||||||
"createdAt": self.time_now(),
|
amount=amount,
|
||||||
}
|
transaction_type=CreditTransactionType.TOP_UP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def top_up_intent(self, user_id: str, amount: int) -> str:
|
||||||
|
# Create checkout session
|
||||||
|
# https://docs.stripe.com/checkout/quickstart?client=react
|
||||||
|
# unit_amount param is always in the smallest currency unit (so cents for usd)
|
||||||
|
# which is equal to amount of credits
|
||||||
|
checkout_session = stripe.checkout.Session.create(
|
||||||
|
customer=await get_stripe_customer_id(user_id),
|
||||||
|
line_items=[
|
||||||
|
{
|
||||||
|
"price_data": {
|
||||||
|
"currency": "usd",
|
||||||
|
"product_data": {
|
||||||
|
"name": "AutoGPT Platform Credits",
|
||||||
|
},
|
||||||
|
"unit_amount": amount,
|
||||||
|
},
|
||||||
|
"quantity": 1,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
mode="payment",
|
||||||
|
success_url=settings.config.platform_base_url
|
||||||
|
+ "/store/credits?topup=success",
|
||||||
|
cancel_url=settings.config.platform_base_url
|
||||||
|
+ "/store/credits?topup=cancel",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create pending transaction
|
||||||
|
await self._add_transaction(
|
||||||
|
user_id=user_id,
|
||||||
|
amount=amount,
|
||||||
|
transaction_type=CreditTransactionType.TOP_UP,
|
||||||
|
transaction_key=checkout_session.id,
|
||||||
|
is_active=False,
|
||||||
|
metadata=Json({"checkout_session": checkout_session}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return checkout_session.url or ""
|
||||||
|
|
||||||
|
# https://docs.stripe.com/checkout/fulfillment
|
||||||
|
async def fulfill_checkout(
|
||||||
|
self, *, session_id: str | None = None, user_id: str | None = None
|
||||||
|
):
|
||||||
|
if (not session_id and not user_id) or (session_id and user_id):
|
||||||
|
raise ValueError("Either session_id or user_id must be provided")
|
||||||
|
|
||||||
|
# Retrieve CreditTransaction
|
||||||
|
find_filter: CreditTransactionWhereInput = {
|
||||||
|
"type": CreditTransactionType.TOP_UP,
|
||||||
|
"isActive": False,
|
||||||
|
}
|
||||||
|
if session_id:
|
||||||
|
find_filter["transactionKey"] = session_id
|
||||||
|
if user_id:
|
||||||
|
find_filter["userId"] = user_id
|
||||||
|
|
||||||
|
# Find the most recent inactive top-up transaction
|
||||||
|
credit_transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||||
|
where=find_filter,
|
||||||
|
order={"createdAt": "desc"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# This can be called multiple times for one id, so ignore if already fulfilled
|
||||||
|
if not credit_transaction:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Retrieve the Checkout Session from the API
|
||||||
|
checkout_session = stripe.checkout.Session.retrieve(
|
||||||
|
credit_transaction.transactionKey
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the Checkout Session's payment_status property
|
||||||
|
# to determine if fulfillment should be performed
|
||||||
|
if checkout_session.payment_status in ["paid", "no_payment_required"]:
|
||||||
|
await self._enable_transaction(
|
||||||
|
transaction_key=credit_transaction.transactionKey,
|
||||||
|
user_id=credit_transaction.userId,
|
||||||
|
metadata=Json({"checkout_session": checkout_session}),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_credits(self, user_id: str) -> int:
|
||||||
|
balance, _ = await self._get_credits(user_id)
|
||||||
|
return balance
|
||||||
|
|
||||||
|
|
||||||
|
class BetaUserCredit(UserCredit):
|
||||||
|
"""
|
||||||
|
This is a temporary class to handle the test user utilizing monthly credit refill.
|
||||||
|
TODO: Remove this class & its feature toggle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_user_credits_refill: int):
|
||||||
|
self.num_user_credits_refill = num_user_credits_refill
|
||||||
|
|
||||||
|
async def get_credits(self, user_id: str) -> int:
|
||||||
|
cur_time = self.time_now().date()
|
||||||
|
balance, snapshot_time = await self._get_credits(user_id)
|
||||||
|
if (snapshot_time.year, snapshot_time.month) == (cur_time.year, cur_time.month):
|
||||||
|
return balance
|
||||||
|
|
||||||
|
try:
|
||||||
|
await CreditTransaction.prisma().create(
|
||||||
|
data={
|
||||||
|
"transactionKey": f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
|
||||||
|
"userId": user_id,
|
||||||
|
"amount": self.num_user_credits_refill,
|
||||||
|
"runningBalance": self.num_user_credits_refill,
|
||||||
|
"type": CreditTransactionType.TOP_UP,
|
||||||
|
"metadata": Json({}),
|
||||||
|
"isActive": True,
|
||||||
|
"createdAt": self.time_now(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except UniqueViolationError:
|
||||||
|
pass # Already refilled this month
|
||||||
|
|
||||||
|
return self.num_user_credits_refill
|
||||||
|
|
||||||
|
|
||||||
class DisabledUserCredit(UserCreditBase):
|
class DisabledUserCredit(UserCreditBase):
|
||||||
async def get_or_refill_credit(self, *args, **kwargs) -> int:
|
async def get_credits(self, *args, **kwargs) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def spend_credits(self, *args, **kwargs) -> int:
|
async def spend_credits(self, *args, **kwargs) -> int:
|
||||||
|
@ -222,13 +417,37 @@ class DisabledUserCredit(UserCreditBase):
|
||||||
async def top_up_credits(self, *args, **kwargs):
|
async def top_up_credits(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def top_up_intent(self, *args, **kwargs) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def fulfill_checkout(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_user_credit_model() -> UserCreditBase:
|
def get_user_credit_model() -> UserCreditBase:
|
||||||
if config.enable_credit.lower() == "true":
|
if not settings.config.enable_credit:
|
||||||
return UserCredit(config.num_user_credits_refill)
|
return DisabledUserCredit()
|
||||||
else:
|
|
||||||
return DisabledUserCredit(0)
|
if settings.config.enable_beta_monthly_credit:
|
||||||
|
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||||
|
|
||||||
|
return UserCredit()
|
||||||
|
|
||||||
|
|
||||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_stripe_customer_id(user_id: str) -> str:
|
||||||
|
user = await get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
raise ValueError(f"User not found: {user_id}")
|
||||||
|
|
||||||
|
if user.stripeCustomerId:
|
||||||
|
return user.stripeCustomerId
|
||||||
|
|
||||||
|
customer = stripe.Customer.create(name=user.name or "", email=user.email)
|
||||||
|
await User.prisma().update(
|
||||||
|
where={"id": user_id}, data={"stripeCustomerId": customer.id}
|
||||||
|
)
|
||||||
|
return customer.id
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import zlib
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
@ -54,6 +55,14 @@ async def transaction():
|
||||||
yield tx
|
yield tx
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def locked_transaction(key: str):
|
||||||
|
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||||
|
async with transaction() as tx:
|
||||||
|
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
|
||||||
|
yield tx
|
||||||
|
|
||||||
|
|
||||||
class BaseDbModel(BaseModel):
|
class BaseDbModel(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
|
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from prisma.enums import AgentExecutionStatus
|
from prisma.enums import AgentExecutionStatus
|
||||||
|
from prisma.errors import PrismaError
|
||||||
from prisma.models import (
|
from prisma.models import (
|
||||||
AgentGraphExecution,
|
AgentGraphExecution,
|
||||||
AgentNodeExecution,
|
AgentNodeExecution,
|
||||||
|
@ -31,6 +32,7 @@ class NodeExecutionEntry(BaseModel):
|
||||||
graph_id: str
|
graph_id: str
|
||||||
node_exec_id: str
|
node_exec_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
|
block_id: str
|
||||||
data: BlockInput
|
data: BlockInput
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,6 +326,30 @@ async def update_execution_status(
|
||||||
return ExecutionResult.from_db(res)
|
return ExecutionResult.from_db(res)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_execution(
|
||||||
|
execution_id: str, user_id: str
|
||||||
|
) -> Optional[AgentNodeExecution]:
|
||||||
|
"""
|
||||||
|
Get an execution by ID. Returns None if not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: The ID of the execution to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The execution if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
execution = await AgentNodeExecution.prisma().find_unique(
|
||||||
|
where={
|
||||||
|
"id": execution_id,
|
||||||
|
"userId": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return execution
|
||||||
|
except PrismaError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
||||||
executions = await AgentNodeExecution.prisma().find_many(
|
executions = await AgentNodeExecution.prisma().find_many(
|
||||||
where={"agentGraphExecutionId": graph_exec_id},
|
where={"agentGraphExecutionId": graph_exec_id},
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
@ -199,27 +200,42 @@ class OAuth2Credentials(_BaseCredentials):
|
||||||
scopes: list[str]
|
scopes: list[str]
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
def bearer(self) -> str:
|
def auth_header(self) -> str:
|
||||||
return f"Bearer {self.access_token.get_secret_value()}"
|
return f"Bearer {self.access_token.get_secret_value()}"
|
||||||
|
|
||||||
|
|
||||||
class APIKeyCredentials(_BaseCredentials):
|
class APIKeyCredentials(_BaseCredentials):
|
||||||
type: Literal["api_key"] = "api_key"
|
type: Literal["api_key"] = "api_key"
|
||||||
api_key: SecretStr
|
api_key: SecretStr
|
||||||
expires_at: Optional[int]
|
expires_at: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Unix timestamp (seconds) indicating when the API key expires (if at all)",
|
||||||
|
)
|
||||||
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
||||||
|
|
||||||
def bearer(self) -> str:
|
def auth_header(self) -> str:
|
||||||
return f"Bearer {self.api_key.get_secret_value()}"
|
return f"Bearer {self.api_key.get_secret_value()}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserPasswordCredentials(_BaseCredentials):
|
||||||
|
type: Literal["user_password"] = "user_password"
|
||||||
|
username: SecretStr
|
||||||
|
password: SecretStr
|
||||||
|
|
||||||
|
def auth_header(self) -> str:
|
||||||
|
# Converting the string to bytes using encode()
|
||||||
|
# Base64 encoding it with base64.b64encode()
|
||||||
|
# Converting the resulting bytes back to a string with decode()
|
||||||
|
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||||
|
|
||||||
|
|
||||||
Credentials = Annotated[
|
Credentials = Annotated[
|
||||||
OAuth2Credentials | APIKeyCredentials,
|
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
CredentialsType = Literal["api_key", "oauth2"]
|
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||||
|
|
||||||
|
|
||||||
class OAuthState(BaseModel):
|
class OAuthState(BaseModel):
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, ca
|
||||||
from backend.data.credit import get_user_credit_model
|
from backend.data.credit import get_user_credit_model
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
ExecutionResult,
|
ExecutionResult,
|
||||||
|
NodeExecutionEntry,
|
||||||
RedisExecutionEventBus,
|
RedisExecutionEventBus,
|
||||||
create_graph_execution,
|
create_graph_execution,
|
||||||
get_execution_results,
|
get_execution_results,
|
||||||
|
@ -78,12 +79,8 @@ class DatabaseManager(AppService):
|
||||||
|
|
||||||
# Credits
|
# Credits
|
||||||
user_credit_model = get_user_credit_model()
|
user_credit_model = get_user_credit_model()
|
||||||
get_or_refill_credit = cast(
|
|
||||||
Callable[[Any, str], int],
|
|
||||||
exposed_run_and_wait(user_credit_model.get_or_refill_credit),
|
|
||||||
)
|
|
||||||
spend_credits = cast(
|
spend_credits = cast(
|
||||||
Callable[[Any, str, int, str, dict[str, str], float, float], int],
|
Callable[[Any, NodeExecutionEntry, float, float], int],
|
||||||
exposed_run_and_wait(user_credit_model.spend_credits),
|
exposed_run_and_wait(user_credit_model.spend_credits),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -183,9 +183,6 @@ def execute_node(
|
||||||
|
|
||||||
output_size = 0
|
output_size = 0
|
||||||
end_status = ExecutionStatus.COMPLETED
|
end_status = ExecutionStatus.COMPLETED
|
||||||
credit = db_client.get_or_refill_credit(user_id)
|
|
||||||
if credit < 0:
|
|
||||||
raise ValueError(f"Insufficient credit: {credit}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for output_name, output_data in node_block.execute(
|
for output_name, output_data in node_block.execute(
|
||||||
|
@ -241,7 +238,8 @@ def execute_node(
|
||||||
if res.end_time and res.start_time
|
if res.end_time and res.start_time
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
db_client.spend_credits(user_id, credit, node_block.id, input_data, s, t)
|
data.data = input_data
|
||||||
|
db_client.spend_credits(data, s, t)
|
||||||
|
|
||||||
# Update execution stats
|
# Update execution stats
|
||||||
if execution_stats is not None:
|
if execution_stats is not None:
|
||||||
|
@ -260,7 +258,7 @@ def _enqueue_next_nodes(
|
||||||
log_metadata: LogMetadata,
|
log_metadata: LogMetadata,
|
||||||
) -> list[NodeExecutionEntry]:
|
) -> list[NodeExecutionEntry]:
|
||||||
def add_enqueued_execution(
|
def add_enqueued_execution(
|
||||||
node_exec_id: str, node_id: str, data: BlockInput
|
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||||
) -> NodeExecutionEntry:
|
) -> NodeExecutionEntry:
|
||||||
exec_update = db_client.update_execution_status(
|
exec_update = db_client.update_execution_status(
|
||||||
node_exec_id, ExecutionStatus.QUEUED, data
|
node_exec_id, ExecutionStatus.QUEUED, data
|
||||||
|
@ -272,6 +270,7 @@ def _enqueue_next_nodes(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
|
block_id=block_id,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -325,7 +324,12 @@ def _enqueue_next_nodes(
|
||||||
# Input is complete, enqueue the execution.
|
# Input is complete, enqueue the execution.
|
||||||
log_metadata.info(f"Enqueued {suffix}")
|
log_metadata.info(f"Enqueued {suffix}")
|
||||||
enqueued_executions.append(
|
enqueued_executions.append(
|
||||||
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
|
add_enqueued_execution(
|
||||||
|
node_exec_id=next_node_exec_id,
|
||||||
|
node_id=next_node_id,
|
||||||
|
block_id=next_node.block_id,
|
||||||
|
data=next_node_input,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Next execution stops here if the link is not static.
|
# Next execution stops here if the link is not static.
|
||||||
|
@ -355,7 +359,12 @@ def _enqueue_next_nodes(
|
||||||
continue
|
continue
|
||||||
log_metadata.info(f"Enqueueing static-link execution {suffix}")
|
log_metadata.info(f"Enqueueing static-link execution {suffix}")
|
||||||
enqueued_executions.append(
|
enqueued_executions.append(
|
||||||
add_enqueued_execution(iexec.node_exec_id, next_node_id, idata)
|
add_enqueued_execution(
|
||||||
|
node_exec_id=iexec.node_exec_id,
|
||||||
|
node_id=next_node_id,
|
||||||
|
block_id=next_node.block_id,
|
||||||
|
data=idata,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return enqueued_executions
|
return enqueued_executions
|
||||||
|
|
||||||
|
@ -803,8 +812,8 @@ class ExecutionManager(AppService):
|
||||||
# Extract request input data, and assign it to the input pin.
|
# Extract request input data, and assign it to the input pin.
|
||||||
if block.block_type == BlockType.INPUT:
|
if block.block_type == BlockType.INPUT:
|
||||||
name = node.input_default.get("name")
|
name = node.input_default.get("name")
|
||||||
if name and name in data:
|
if name in data.get("node_input", {}):
|
||||||
input_data = {"value": data[name]}
|
input_data = {"value": data["node_input"][name]}
|
||||||
|
|
||||||
# Extract webhook payload, and assign it to the input pin
|
# Extract webhook payload, and assign it to the input pin
|
||||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||||
|
@ -840,6 +849,7 @@ class ExecutionManager(AppService):
|
||||||
graph_id=node_exec.graph_id,
|
graph_id=node_exec.graph_id,
|
||||||
node_exec_id=node_exec.node_exec_id,
|
node_exec_id=node_exec.node_exec_id,
|
||||||
node_id=node_exec.node_id,
|
node_id=node_exec.node_id,
|
||||||
|
block_id=node_exec.block_id,
|
||||||
data=node_exec.input_data,
|
data=node_exec.input_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,6 +23,15 @@ from backend.util.settings import Settings
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
|
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||||
|
ollama_credentials = APIKeyCredentials(
|
||||||
|
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||||
|
provider="ollama",
|
||||||
|
api_key=SecretStr("FAKE_API_KEY"),
|
||||||
|
title="Use Credits for Ollama",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
revid_credentials = APIKeyCredentials(
|
revid_credentials = APIKeyCredentials(
|
||||||
id="fdb7f412-f519-48d1-9b5f-d2f73d0e01fe",
|
id="fdb7f412-f519-48d1-9b5f-d2f73d0e01fe",
|
||||||
provider="revid",
|
provider="revid",
|
||||||
|
@ -124,6 +133,7 @@ nvidia_credentials = APIKeyCredentials(
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CREDENTIALS = [
|
DEFAULT_CREDENTIALS = [
|
||||||
|
ollama_credentials,
|
||||||
revid_credentials,
|
revid_credentials,
|
||||||
ideogram_credentials,
|
ideogram_credentials,
|
||||||
replicate_credentials,
|
replicate_credentials,
|
||||||
|
@ -169,6 +179,10 @@ class IntegrationCredentialsStore:
|
||||||
def get_all_creds(self, user_id: str) -> list[Credentials]:
|
def get_all_creds(self, user_id: str) -> list[Credentials]:
|
||||||
users_credentials = self._get_user_integrations(user_id).credentials
|
users_credentials = self._get_user_integrations(user_id).credentials
|
||||||
all_credentials = users_credentials
|
all_credentials = users_credentials
|
||||||
|
# These will always be added
|
||||||
|
all_credentials.append(ollama_credentials)
|
||||||
|
|
||||||
|
# These will only be added if the API key is set
|
||||||
if settings.secrets.revid_api_key:
|
if settings.secrets.revid_api_key:
|
||||||
all_credentials.append(revid_credentials)
|
all_credentials.append(revid_credentials)
|
||||||
if settings.secrets.ideogram_api_key:
|
if settings.secrets.ideogram_api_key:
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .github import GitHubOAuthHandler
|
from .github import GitHubOAuthHandler
|
||||||
from .google import GoogleOAuthHandler
|
from .google import GoogleOAuthHandler
|
||||||
|
from .linear import LinearOAuthHandler
|
||||||
from .notion import NotionOAuthHandler
|
from .notion import NotionOAuthHandler
|
||||||
from .twitter import TwitterOAuthHandler
|
from .twitter import TwitterOAuthHandler
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
|
||||||
GoogleOAuthHandler,
|
GoogleOAuthHandler,
|
||||||
NotionOAuthHandler,
|
NotionOAuthHandler,
|
||||||
TwitterOAuthHandler,
|
TwitterOAuthHandler,
|
||||||
|
LinearOAuthHandler,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.linear._api import LinearAPIException
|
||||||
|
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.request import requests
|
||||||
|
|
||||||
|
from .base import BaseOAuthHandler
|
||||||
|
|
||||||
|
|
||||||
|
class LinearOAuthHandler(BaseOAuthHandler):
|
||||||
|
"""
|
||||||
|
OAuth2 handler for Linear.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROVIDER_NAME = ProviderName.LINEAR
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
self.auth_base_url = "https://linear.app/oauth/authorize"
|
||||||
|
self.token_url = "https://api.linear.app/oauth/token" # Correct token URL
|
||||||
|
self.revoke_url = "https://api.linear.app/oauth/revoke"
|
||||||
|
|
||||||
|
def get_login_url(
|
||||||
|
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||||
|
) -> str:
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"response_type": "code", # Important: include "response_type"
|
||||||
|
"scope": ",".join(scopes), # Comma-separated, not space-separated
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||||
|
|
||||||
|
def exchange_code_for_tokens(
|
||||||
|
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||||
|
) -> OAuth2Credentials:
|
||||||
|
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
|
||||||
|
|
||||||
|
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||||
|
if not credentials.access_token:
|
||||||
|
raise ValueError("No access token to revoke")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.revoke_url, headers=headers)
|
||||||
|
if not response.ok:
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
error_message = error_data.get("error", "Unknown error")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_message = response.text
|
||||||
|
raise LinearAPIException(
|
||||||
|
f"Failed to revoke Linear tokens ({response.status_code}): {error_message}",
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True # Linear doesn't return JSON on successful revoke
|
||||||
|
|
||||||
|
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||||
|
if not credentials.refresh_token:
|
||||||
|
raise ValueError(
|
||||||
|
"No refresh token available."
|
||||||
|
) # Linear uses non-expiring tokens
|
||||||
|
|
||||||
|
return self._request_tokens(
|
||||||
|
{
|
||||||
|
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _request_tokens(
|
||||||
|
self,
|
||||||
|
params: dict[str, str],
|
||||||
|
current_credentials: Optional[OAuth2Credentials] = None,
|
||||||
|
) -> OAuth2Credentials:
|
||||||
|
request_body = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"grant_type": "authorization_code", # Ensure grant_type is correct
|
||||||
|
**params,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded"
|
||||||
|
} # Correct header for token request
|
||||||
|
response = requests.post(self.token_url, data=request_body, headers=headers)
|
||||||
|
|
||||||
|
if not response.ok:
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
error_message = error_data.get("error", "Unknown error")
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_message = response.text
|
||||||
|
raise LinearAPIException(
|
||||||
|
f"Failed to fetch Linear tokens ({response.status_code}): {error_message}",
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
# Note: Linear access tokens do not expire, so we set expires_at to None
|
||||||
|
new_credentials = OAuth2Credentials(
|
||||||
|
provider=self.PROVIDER_NAME,
|
||||||
|
title=current_credentials.title if current_credentials else None,
|
||||||
|
username=token_data.get("user", {}).get(
|
||||||
|
"name", "Unknown User"
|
||||||
|
), # extract name or set appropriate
|
||||||
|
access_token=token_data["access_token"],
|
||||||
|
scopes=token_data["scope"].split(
|
||||||
|
","
|
||||||
|
), # Linear returns comma-separated scopes
|
||||||
|
refresh_token=token_data.get(
|
||||||
|
"refresh_token"
|
||||||
|
), # Linear uses non-expiring tokens so this might be null
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
)
|
||||||
|
if current_credentials:
|
||||||
|
new_credentials.id = current_credentials.id
|
||||||
|
return new_credentials
|
||||||
|
|
||||||
|
def _request_username(self, access_token: str) -> Optional[str]:
|
||||||
|
|
||||||
|
# Use the LinearClient to fetch user details using GraphQL
|
||||||
|
from backend.blocks.linear._api import LinearClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
linear_client = LinearClient(
|
||||||
|
APIKeyCredentials(
|
||||||
|
api_key=SecretStr(access_token),
|
||||||
|
title="temp",
|
||||||
|
provider=self.PROVIDER_NAME,
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
) # Temporary credentials for this request
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query Viewer {
|
||||||
|
viewer {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = linear_client.query(query)
|
||||||
|
return response["viewer"]["name"]
|
||||||
|
|
||||||
|
except Exception as e: # Handle any errors
|
||||||
|
|
||||||
|
print(f"Error fetching username: {e}")
|
||||||
|
return None
|
|
@ -17,6 +17,7 @@ class ProviderName(str, Enum):
|
||||||
HUBSPOT = "hubspot"
|
HUBSPOT = "hubspot"
|
||||||
IDEOGRAM = "ideogram"
|
IDEOGRAM = "ideogram"
|
||||||
JINA = "jina"
|
JINA = "jina"
|
||||||
|
LINEAR = "linear"
|
||||||
MEDIUM = "medium"
|
MEDIUM = "medium"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
NVIDIA = "nvidia"
|
NVIDIA = "nvidia"
|
||||||
|
@ -25,9 +26,11 @@ class ProviderName(str, Enum):
|
||||||
OPENWEATHERMAP = "openweathermap"
|
OPENWEATHERMAP = "openweathermap"
|
||||||
OPEN_ROUTER = "open_router"
|
OPEN_ROUTER = "open_router"
|
||||||
PINECONE = "pinecone"
|
PINECONE = "pinecone"
|
||||||
|
REDDIT = "reddit"
|
||||||
REPLICATE = "replicate"
|
REPLICATE = "replicate"
|
||||||
REVID = "revid"
|
REVID = "revid"
|
||||||
SLANT3D = "slant3d"
|
SLANT3D = "slant3d"
|
||||||
|
SMTP = "smtp"
|
||||||
TWITTER = "twitter"
|
TWITTER = "twitter"
|
||||||
UNREAL_SPEECH = "unreal_speech"
|
UNREAL_SPEECH = "unreal_speech"
|
||||||
# --8<-- [end:ProviderName]
|
# --8<-- [end:ProviderName]
|
||||||
|
|
|
@ -168,7 +168,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
|
||||||
|
|
||||||
id = str(uuid4())
|
id = str(uuid4())
|
||||||
secret = secrets.token_hex(32)
|
secret = secrets.token_hex(32)
|
||||||
provider_name = self.PROVIDER_NAME
|
provider_name: ProviderName = self.PROVIDER_NAME
|
||||||
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
|
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
|
||||||
if register:
|
if register:
|
||||||
if not credentials:
|
if not credentials:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from backend.data import integrations
|
from backend.data import integrations
|
||||||
from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials
|
from backend.data.model import Credentials
|
||||||
|
|
||||||
from ._base import WT, BaseWebhooksManager
|
from ._base import WT, BaseWebhooksManager
|
||||||
|
|
||||||
|
@ -25,6 +25,6 @@ class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
|
||||||
async def _deregister_webhook(
|
async def _deregister_webhook(
|
||||||
self,
|
self,
|
||||||
webhook: integrations.Webhook,
|
webhook: integrations.Webhook,
|
||||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
credentials: Credentials,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -67,7 +67,7 @@ class GithubWebhooksManager(BaseWebhooksManager):
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
**self.GITHUB_API_DEFAULT_HEADERS,
|
**self.GITHUB_API_DEFAULT_HEADERS,
|
||||||
"Authorization": credentials.bearer(),
|
"Authorization": credentials.auth_header(),
|
||||||
}
|
}
|
||||||
|
|
||||||
repo, github_hook_id = webhook.resource, webhook.provider_webhook_id
|
repo, github_hook_id = webhook.resource, webhook.provider_webhook_id
|
||||||
|
@ -96,7 +96,7 @@ class GithubWebhooksManager(BaseWebhooksManager):
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
**self.GITHUB_API_DEFAULT_HEADERS,
|
**self.GITHUB_API_DEFAULT_HEADERS,
|
||||||
"Authorization": credentials.bearer(),
|
"Authorization": credentials.auth_header(),
|
||||||
}
|
}
|
||||||
webhook_data = {
|
webhook_data = {
|
||||||
"name": "web",
|
"name": "web",
|
||||||
|
@ -142,7 +142,7 @@ class GithubWebhooksManager(BaseWebhooksManager):
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
**self.GITHUB_API_DEFAULT_HEADERS,
|
**self.GITHUB_API_DEFAULT_HEADERS,
|
||||||
"Authorization": credentials.bearer(),
|
"Authorization": credentials.auth_header(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if webhook_type == self.WebhookType.REPO:
|
if webhook_type == self.WebhookType.REPO:
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from .routes.v1 import v1_router
|
||||||
|
|
||||||
|
external_app = FastAPI(
|
||||||
|
title="AutoGPT External API",
|
||||||
|
description="External API for AutoGPT integrations",
|
||||||
|
docs_url="/docs",
|
||||||
|
version="1.0",
|
||||||
|
)
|
||||||
|
external_app.include_router(v1_router, prefix="/v1")
|
|
@ -0,0 +1,37 @@
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from prisma.enums import APIKeyPermission
|
||||||
|
|
||||||
|
from backend.data.api_key import has_permission, validate_api_key
|
||||||
|
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||||
|
|
||||||
|
|
||||||
|
async def require_api_key(request: Request):
|
||||||
|
"""Base middleware for API key authentication"""
|
||||||
|
api_key = await api_key_header(request)
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Missing API key")
|
||||||
|
|
||||||
|
api_key_obj = await validate_api_key(api_key)
|
||||||
|
|
||||||
|
if not api_key_obj:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
|
||||||
|
request.state.api_key = api_key_obj
|
||||||
|
return api_key_obj
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(permission: APIKeyPermission):
|
||||||
|
"""Dependency function for checking specific permissions"""
|
||||||
|
|
||||||
|
async def check_permission(api_key=Depends(require_api_key)):
|
||||||
|
if not has_permission(api_key, permission):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"API key missing required permission: {permission}",
|
||||||
|
)
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
return check_permission
|
|
@ -0,0 +1,111 @@
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
from autogpt_libs.utils.cache import thread_cached
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from prisma.enums import APIKeyPermission
|
||||||
|
|
||||||
|
import backend.data.block
|
||||||
|
from backend.data import execution as execution_db
|
||||||
|
from backend.data import graph as graph_db
|
||||||
|
from backend.data.api_key import APIKey
|
||||||
|
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||||
|
from backend.executor import ExecutionManager
|
||||||
|
from backend.server.external.middleware import require_permission
|
||||||
|
from backend.util.service import get_service_client
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@thread_cached
|
||||||
|
def execution_manager_client() -> ExecutionManager:
|
||||||
|
return get_service_client(ExecutionManager)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
v1_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.get(
|
||||||
|
path="/blocks",
|
||||||
|
tags=["blocks"],
|
||||||
|
dependencies=[Depends(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||||
|
)
|
||||||
|
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||||
|
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||||
|
return [b.to_dict() for b in blocks]
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.post(
|
||||||
|
path="/blocks/{block_id}/execute",
|
||||||
|
tags=["blocks"],
|
||||||
|
dependencies=[Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
|
||||||
|
)
|
||||||
|
def execute_graph_block(
|
||||||
|
block_id: str,
|
||||||
|
data: BlockInput,
|
||||||
|
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||||
|
) -> CompletedBlockOutput:
|
||||||
|
obj = backend.data.block.get_block(block_id)
|
||||||
|
if not obj:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
|
|
||||||
|
output = defaultdict(list)
|
||||||
|
for name, data in obj.execute(data):
|
||||||
|
output[name].append(data)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.post(
|
||||||
|
path="/graphs/{graph_id}/execute",
|
||||||
|
tags=["graphs"],
|
||||||
|
)
|
||||||
|
def execute_graph(
|
||||||
|
graph_id: str,
|
||||||
|
node_input: dict[Any, Any],
|
||||||
|
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
graph_exec = execution_manager_client().add_execution(
|
||||||
|
graph_id, node_input, user_id=api_key.user_id
|
||||||
|
)
|
||||||
|
return {"id": graph_exec.graph_exec_id}
|
||||||
|
except Exception as e:
|
||||||
|
msg = e.__str__().encode().decode("unicode_escape")
|
||||||
|
raise HTTPException(status_code=400, detail=msg)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.get(
|
||||||
|
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||||
|
tags=["graphs"],
|
||||||
|
)
|
||||||
|
async def get_graph_execution_results(
|
||||||
|
graph_id: str,
|
||||||
|
graph_exec_id: str,
|
||||||
|
api_key: APIKey = Depends(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||||
|
) -> dict:
|
||||||
|
graph = await graph_db.get_graph(graph_id, user_id=api_key.user_id)
|
||||||
|
if not graph:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||||
|
|
||||||
|
results = await execution_db.get_execution_results(graph_exec_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"execution_id": graph_exec_id,
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_id": result.node_id,
|
||||||
|
"input": (
|
||||||
|
result.input_data.get("value")
|
||||||
|
if "value" in result.input_data
|
||||||
|
else result.input_data
|
||||||
|
),
|
||||||
|
"output": result.output_data.get(
|
||||||
|
"response", result.output_data.get("result", [])
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for result in results
|
||||||
|
],
|
||||||
|
}
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from typing import TYPE_CHECKING, Annotated, Literal
|
from typing import TYPE_CHECKING, Annotated, Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.graph import set_node_webhook
|
from backend.data.graph import set_node_webhook
|
||||||
from backend.data.integrations import (
|
from backend.data.integrations import (
|
||||||
|
@ -12,12 +12,7 @@ from backend.data.integrations import (
|
||||||
publish_webhook_event,
|
publish_webhook_event,
|
||||||
wait_for_webhook_event,
|
wait_for_webhook_event,
|
||||||
)
|
)
|
||||||
from backend.data.model import (
|
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||||
APIKeyCredentials,
|
|
||||||
Credentials,
|
|
||||||
CredentialsType,
|
|
||||||
OAuth2Credentials,
|
|
||||||
)
|
|
||||||
from backend.executor.manager import ExecutionManager
|
from backend.executor.manager import ExecutionManager
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||||
|
@ -110,6 +105,11 @@ def callback(
|
||||||
|
|
||||||
logger.debug(f"Received credentials with final scopes: {credentials.scopes}")
|
logger.debug(f"Received credentials with final scopes: {credentials.scopes}")
|
||||||
|
|
||||||
|
# Linear returns scopes as a single string with spaces, so we need to split them
|
||||||
|
# TODO: make a bypass of this part of the OAuth handler
|
||||||
|
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
|
||||||
|
credentials.scopes = credentials.scopes[0].split(" ")
|
||||||
|
|
||||||
# Check if the granted scopes are sufficient for the requested scopes
|
# Check if the granted scopes are sufficient for the requested scopes
|
||||||
if not set(scopes).issubset(set(credentials.scopes)):
|
if not set(scopes).issubset(set(credentials.scopes)):
|
||||||
# For now, we'll just log the warning and continue
|
# For now, we'll just log the warning and continue
|
||||||
|
@ -199,31 +199,21 @@ def get_credential(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/credentials", status_code=201)
|
@router.post("/{provider}/credentials", status_code=201)
|
||||||
def create_api_key_credentials(
|
def create_credentials(
|
||||||
user_id: Annotated[str, Depends(get_user_id)],
|
user_id: Annotated[str, Depends(get_user_id)],
|
||||||
provider: Annotated[
|
provider: Annotated[
|
||||||
ProviderName, Path(title="The provider to create credentials for")
|
ProviderName, Path(title="The provider to create credentials for")
|
||||||
],
|
],
|
||||||
api_key: Annotated[str, Body(title="The API key to store")],
|
credentials: Credentials,
|
||||||
title: Annotated[str, Body(title="Optional title for the credentials")],
|
) -> Credentials:
|
||||||
expires_at: Annotated[
|
credentials.provider = provider
|
||||||
int | None, Body(title="Unix timestamp when the key expires")
|
|
||||||
] = None,
|
|
||||||
) -> APIKeyCredentials:
|
|
||||||
new_credentials = APIKeyCredentials(
|
|
||||||
provider=provider,
|
|
||||||
api_key=SecretStr(api_key),
|
|
||||||
title=title,
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
creds_manager.create(user_id, new_credentials)
|
creds_manager.create(user_id, credentials)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
||||||
)
|
)
|
||||||
return new_credentials
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
class CredentialsDeletionResponse(BaseModel):
|
class CredentialsDeletionResponse(BaseModel):
|
||||||
|
|
|
@ -56,3 +56,8 @@ class SetGraphActiveVersion(pydantic.BaseModel):
|
||||||
|
|
||||||
class UpdatePermissionsRequest(pydantic.BaseModel):
|
class UpdatePermissionsRequest(pydantic.BaseModel):
|
||||||
permissions: List[APIKeyPermission]
|
permissions: List[APIKeyPermission]
|
||||||
|
|
||||||
|
|
||||||
|
class RequestTopUp(pydantic.BaseModel):
|
||||||
|
amount: int
|
||||||
|
"""Amount of credits to top up."""
|
||||||
|
|
|
@ -20,6 +20,7 @@ import backend.server.v2.library.routes
|
||||||
import backend.server.v2.store.routes
|
import backend.server.v2.store.routes
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
|
from backend.server.external.api import external_app
|
||||||
|
|
||||||
settings = backend.util.settings.Settings()
|
settings = backend.util.settings.Settings()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -94,6 +95,8 @@ app.include_router(
|
||||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.mount("/external-api", external_app)
|
||||||
|
|
||||||
|
|
||||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||||
async def health():
|
async def health():
|
||||||
|
|
|
@ -4,10 +4,11 @@ from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Sequence
|
from typing import TYPE_CHECKING, Annotated, Any, Sequence
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
import stripe
|
||||||
from autogpt_libs.auth.middleware import auth_middleware
|
from autogpt_libs.auth.middleware import auth_middleware
|
||||||
from autogpt_libs.feature_flag.client import feature_flag
|
from autogpt_libs.feature_flag.client import feature_flag
|
||||||
from autogpt_libs.utils.cache import thread_cached
|
from autogpt_libs.utils.cache import thread_cached
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from typing_extensions import Optional, TypedDict
|
from typing_extensions import Optional, TypedDict
|
||||||
|
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
|
@ -28,7 +29,11 @@ from backend.data.api_key import (
|
||||||
update_api_key_permissions,
|
update_api_key_permissions,
|
||||||
)
|
)
|
||||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||||
from backend.data.credit import get_block_costs, get_user_credit_model
|
from backend.data.credit import (
|
||||||
|
get_block_costs,
|
||||||
|
get_stripe_customer_id,
|
||||||
|
get_user_credit_model,
|
||||||
|
)
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.executor import ExecutionManager, ExecutionScheduler, scheduler
|
from backend.executor import ExecutionManager, ExecutionScheduler, scheduler
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
|
@ -40,6 +45,7 @@ from backend.server.model import (
|
||||||
CreateAPIKeyRequest,
|
CreateAPIKeyRequest,
|
||||||
CreateAPIKeyResponse,
|
CreateAPIKeyResponse,
|
||||||
CreateGraph,
|
CreateGraph,
|
||||||
|
RequestTopUp,
|
||||||
SetGraphActiveVersion,
|
SetGraphActiveVersion,
|
||||||
UpdatePermissionsRequest,
|
UpdatePermissionsRequest,
|
||||||
)
|
)
|
||||||
|
@ -134,7 +140,69 @@ async def get_user_credits(
|
||||||
user_id: Annotated[str, Depends(get_user_id)],
|
user_id: Annotated[str, Depends(get_user_id)],
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
# Credits can go negative, so ensure it's at least 0 for user to see.
|
# Credits can go negative, so ensure it's at least 0 for user to see.
|
||||||
return {"credits": max(await _user_credit_model.get_or_refill_credit(user_id), 0)}
|
return {"credits": max(await _user_credit_model.get_credits(user_id), 0)}
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.post(
|
||||||
|
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||||
|
)
|
||||||
|
async def request_top_up(
|
||||||
|
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
|
||||||
|
):
|
||||||
|
checkout_url = await _user_credit_model.top_up_intent(user_id, request.amount)
|
||||||
|
return {"checkout_url": checkout_url}
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.patch(
|
||||||
|
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||||
|
)
|
||||||
|
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||||
|
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.post(path="/credits/stripe_webhook", tags=["credits"])
|
||||||
|
async def stripe_webhook(request: Request):
|
||||||
|
# Get the raw request body
|
||||||
|
payload = await request.body()
|
||||||
|
# Get the signature header
|
||||||
|
sig_header = request.headers.get("stripe-signature")
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = stripe.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
# Invalid payload
|
||||||
|
raise HTTPException(status_code=400)
|
||||||
|
except stripe.SignatureVerificationError:
|
||||||
|
# Invalid signature
|
||||||
|
raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
if (
|
||||||
|
event["type"] == "checkout.session.completed"
|
||||||
|
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||||
|
):
|
||||||
|
await _user_credit_model.fulfill_checkout(
|
||||||
|
session_id=event["data"]["object"]["id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.get(path="/credits/manage", dependencies=[Depends(auth_middleware)])
|
||||||
|
async def manage_payment_method(
|
||||||
|
user_id: Annotated[str, Depends(get_user_id)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
session = stripe.billing_portal.Session.create(
|
||||||
|
customer=await get_stripe_customer_id(user_id),
|
||||||
|
return_url=settings.config.platform_base_url + "/store/credits",
|
||||||
|
)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Failed to create billing portal session"
|
||||||
|
)
|
||||||
|
return {"url": session.url}
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
|
@ -545,7 +613,6 @@ def get_execution_schedules(
|
||||||
tags=["api-keys"],
|
tags=["api-keys"],
|
||||||
dependencies=[Depends(auth_middleware)],
|
dependencies=[Depends(auth_middleware)],
|
||||||
)
|
)
|
||||||
@feature_flag("api-keys-enabled")
|
|
||||||
async def create_api_key(
|
async def create_api_key(
|
||||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Depends(get_user_id)]
|
request: CreateAPIKeyRequest, user_id: Annotated[str, Depends(get_user_id)]
|
||||||
) -> CreateAPIKeyResponse:
|
) -> CreateAPIKeyResponse:
|
||||||
|
@ -569,7 +636,6 @@ async def create_api_key(
|
||||||
tags=["api-keys"],
|
tags=["api-keys"],
|
||||||
dependencies=[Depends(auth_middleware)],
|
dependencies=[Depends(auth_middleware)],
|
||||||
)
|
)
|
||||||
@feature_flag("api-keys-enabled")
|
|
||||||
async def get_api_keys(
|
async def get_api_keys(
|
||||||
user_id: Annotated[str, Depends(get_user_id)]
|
user_id: Annotated[str, Depends(get_user_id)]
|
||||||
) -> list[APIKeyWithoutHash]:
|
) -> list[APIKeyWithoutHash]:
|
||||||
|
@ -587,7 +653,6 @@ async def get_api_keys(
|
||||||
tags=["api-keys"],
|
tags=["api-keys"],
|
||||||
dependencies=[Depends(auth_middleware)],
|
dependencies=[Depends(auth_middleware)],
|
||||||
)
|
)
|
||||||
@feature_flag("api-keys-enabled")
|
|
||||||
async def get_api_key(
|
async def get_api_key(
|
||||||
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||||
) -> APIKeyWithoutHash:
|
) -> APIKeyWithoutHash:
|
||||||
|
|
|
@ -81,10 +81,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||||
default=True,
|
default=True,
|
||||||
description="If authentication is enabled or not",
|
description="If authentication is enabled or not",
|
||||||
)
|
)
|
||||||
enable_credit: str = Field(
|
enable_credit: bool = Field(
|
||||||
default="false",
|
default=False,
|
||||||
description="If user credit system is enabled or not",
|
description="If user credit system is enabled or not",
|
||||||
)
|
)
|
||||||
|
enable_beta_monthly_credit: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="If beta monthly credits accounting is enabled or not",
|
||||||
|
)
|
||||||
num_user_credits_refill: int = Field(
|
num_user_credits_refill: int = Field(
|
||||||
default=1500,
|
default=1500,
|
||||||
description="Number of credits to refill for each user",
|
description="Number of credits to refill for each user",
|
||||||
|
@ -153,6 +157,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||||
description="The name of the Google Cloud Storage bucket for media files",
|
description="The name of the Google Cloud Storage bucket for media files",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
reddit_user_agent: str = Field(
|
||||||
|
default="AutoGPT:1.0 (by /u/autogpt)",
|
||||||
|
description="The user agent for the Reddit API",
|
||||||
|
)
|
||||||
|
|
||||||
scheduler_db_pool_size: int = Field(
|
scheduler_db_pool_size: int = Field(
|
||||||
default=3,
|
default=3,
|
||||||
description="The pool size for the scheduler database connection pool",
|
description="The pool size for the scheduler database connection pool",
|
||||||
|
@ -276,8 +285,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||||
|
|
||||||
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
||||||
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
||||||
reddit_username: str = Field(default="", description="Reddit username")
|
|
||||||
reddit_password: str = Field(default="", description="Reddit password")
|
|
||||||
|
|
||||||
openweathermap_api_key: str = Field(
|
openweathermap_api_key: str = Field(
|
||||||
default="", description="OpenWeatherMap API key"
|
default="", description="OpenWeatherMap API key"
|
||||||
|
@ -309,6 +316,12 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||||
|
|
||||||
|
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||||
|
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||||
|
|
||||||
|
stripe_api_key: str = Field(default="", description="Stripe API Key")
|
||||||
|
stripe_webhook_secret: str = Field(default="", description="Stripe Webhook Secret")
|
||||||
|
|
||||||
# Add more secret fields as needed
|
# Add more secret fields as needed
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "CreditTransaction" ADD COLUMN "runningBalance" INTEGER;
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the column `blockId` on the `CreditTransaction` table. All the data in the column will be moved to metadata->block_id.
|
||||||
|
|
||||||
|
*/
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- DropForeignKey blockId
|
||||||
|
ALTER TABLE "CreditTransaction" DROP CONSTRAINT "CreditTransaction_blockId_fkey";
|
||||||
|
|
||||||
|
-- Update migrate blockId into metadata->"block_id"
|
||||||
|
UPDATE "CreditTransaction"
|
||||||
|
SET "metadata" = jsonb_set(
|
||||||
|
COALESCE("metadata"::jsonb, '{}'),
|
||||||
|
'{block_id}',
|
||||||
|
to_jsonb("blockId")
|
||||||
|
)
|
||||||
|
WHERE "blockId" IS NOT NULL;
|
||||||
|
|
||||||
|
-- AlterTable drop blockId
|
||||||
|
ALTER TABLE "CreditTransaction" DROP COLUMN "blockId";
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
|
||||||
|
/*
|
||||||
|
These indices dropped below were part of the cleanup during the schema change applied above.
|
||||||
|
These indexes were not useful and will not impact anything upon their removal.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListingReview_storeListingVersionId_idx";
|
||||||
|
|
||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListingSubmission_Status_idx";
|
|
@ -3688,6 +3688,22 @@ docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
||||||
release = ["twine"]
|
release = ["twine"]
|
||||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "stripe"
|
||||||
|
version = "11.4.1"
|
||||||
|
description = "Python bindings for the Stripe API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "stripe-11.4.1-py2.py3-none-any.whl", hash = "sha256:8aa47a241de0355c383c916c4ef7273ab666f096a44ee7081e357db4a36f0cce"},
|
||||||
|
{file = "stripe-11.4.1.tar.gz", hash = "sha256:7ddd251b622d490fe57d78487855dc9f4d95b1bb113607e81fd377037a133d5a"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
requests = {version = ">=2.20", markers = "python_version >= \"3.0\""}
|
||||||
|
typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase"
|
name = "supabase"
|
||||||
version = "2.11.0"
|
version = "2.11.0"
|
||||||
|
@ -4432,4 +4448,4 @@ type = ["pytest-mypy"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "711669de9e6d5b81f19286bd41d52f57bc0177ba8ff5f2b477313a5b2d012ae5"
|
content-hash = "341712d286b6a6fae89055bd21a55d8fa918973e446f6c0f0329a8493022cbae"
|
||||||
|
|
|
@ -39,6 +39,7 @@ python-dotenv = "^1.0.1"
|
||||||
redis = "^5.2.0"
|
redis = "^5.2.0"
|
||||||
sentry-sdk = "2.19.2"
|
sentry-sdk = "2.19.2"
|
||||||
strenum = "^0.4.9"
|
strenum = "^0.4.9"
|
||||||
|
stripe = "^11.3.0"
|
||||||
supabase = "2.11.0"
|
supabase = "2.11.0"
|
||||||
tenacity = "^9.0.0"
|
tenacity = "^9.0.0"
|
||||||
tweepy = "^4.14.0"
|
tweepy = "^4.14.0"
|
||||||
|
|
|
@ -32,12 +32,12 @@ model User {
|
||||||
AgentPreset AgentPreset[]
|
AgentPreset AgentPreset[]
|
||||||
UserAgent UserAgent[]
|
UserAgent UserAgent[]
|
||||||
|
|
||||||
Profile Profile[]
|
Profile Profile[]
|
||||||
StoreListing StoreListing[]
|
StoreListing StoreListing[]
|
||||||
StoreListingReview StoreListingReview[]
|
StoreListingReview StoreListingReview[]
|
||||||
StoreListingSubmission StoreListingSubmission[]
|
StoreListingSubmission StoreListingSubmission[]
|
||||||
APIKeys APIKey[]
|
APIKeys APIKey[]
|
||||||
IntegrationWebhooks IntegrationWebhook[]
|
IntegrationWebhooks IntegrationWebhook[]
|
||||||
|
|
||||||
@@index([id])
|
@@index([id])
|
||||||
@@index([email])
|
@@index([email])
|
||||||
|
@ -64,23 +64,23 @@ model AgentGraph {
|
||||||
AgentNodes AgentNode[]
|
AgentNodes AgentNode[]
|
||||||
AgentGraphExecution AgentGraphExecution[]
|
AgentGraphExecution AgentGraphExecution[]
|
||||||
|
|
||||||
AgentPreset AgentPreset[]
|
AgentPreset AgentPreset[]
|
||||||
UserAgent UserAgent[]
|
UserAgent UserAgent[]
|
||||||
StoreListing StoreListing[]
|
StoreListing StoreListing[]
|
||||||
StoreListingVersion StoreListingVersion?
|
StoreListingVersion StoreListingVersion?
|
||||||
|
|
||||||
@@id(name: "graphVersionId", [id, version])
|
@@id(name: "graphVersionId", [id, version])
|
||||||
@@index([userId, isActive])
|
@@index([userId, isActive])
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
//////////////// USER SPECIFIC DATA ////////////////////
|
//////////////// USER SPECIFIC DATA ////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// An AgentPrest is an Agent + User Configuration of that agent.
|
// An AgentPrest is an Agent + User Configuration of that agent.
|
||||||
// For example, if someone has created a weather agent and they want to set it up to
|
// For example, if someone has created a weather agent and they want to set it up to
|
||||||
// Inform them of extreme weather warnings in Texas, the agent with the configuration to set it to
|
// Inform them of extreme weather warnings in Texas, the agent with the configuration to set it to
|
||||||
// monitor texas, along with the cron setup or webhook tiggers, is an AgentPreset
|
// monitor texas, along with the cron setup or webhook tiggers, is an AgentPreset
|
||||||
model AgentPreset {
|
model AgentPreset {
|
||||||
|
@ -102,9 +102,9 @@ model AgentPreset {
|
||||||
agentVersion Int
|
agentVersion Int
|
||||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||||
|
|
||||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||||
UserAgents UserAgent[]
|
UserAgents UserAgent[]
|
||||||
AgentExecution AgentGraphExecution[]
|
AgentExecution AgentGraphExecution[]
|
||||||
|
|
||||||
@@index([userId])
|
@@index([userId])
|
||||||
}
|
}
|
||||||
|
@ -134,11 +134,11 @@ model UserAgent {
|
||||||
@@index([userId])
|
@@index([userId])
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
|
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
|
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
|
||||||
model AgentNode {
|
model AgentNode {
|
||||||
|
@ -207,7 +207,6 @@ model AgentBlock {
|
||||||
|
|
||||||
// Prisma requires explicit back-references.
|
// Prisma requires explicit back-references.
|
||||||
ReferencedByAgentNode AgentNode[]
|
ReferencedByAgentNode AgentNode[]
|
||||||
CreditTransaction CreditTransaction[]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
|
// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
|
||||||
|
@ -345,11 +344,11 @@ model AnalyticsDetails {
|
||||||
@@index([type])
|
@@index([type])
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////// METRICS TRACKING TABLES ////////////////
|
////////////// METRICS TRACKING TABLES ////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
model AnalyticsMetrics {
|
model AnalyticsMetrics {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
|
@ -375,11 +374,11 @@ enum CreditTransactionType {
|
||||||
USAGE
|
USAGE
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
|
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
model CreditTransaction {
|
model CreditTransaction {
|
||||||
transactionKey String @default(uuid())
|
transactionKey String @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
|
@ -387,12 +386,11 @@ model CreditTransaction {
|
||||||
userId String
|
userId String
|
||||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
blockId String?
|
|
||||||
block AgentBlock? @relation(fields: [blockId], references: [id])
|
|
||||||
|
|
||||||
amount Int
|
amount Int
|
||||||
type CreditTransactionType
|
type CreditTransactionType
|
||||||
|
|
||||||
|
runningBalance Int?
|
||||||
|
|
||||||
isActive Boolean @default(true)
|
isActive Boolean @default(true)
|
||||||
metadata Json?
|
metadata Json?
|
||||||
|
|
||||||
|
@ -400,11 +398,11 @@ model CreditTransaction {
|
||||||
@@index([userId, createdAt])
|
@@index([userId, createdAt])
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////// Store TABLES ///////////////////////////
|
////////////// Store TABLES ///////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
model Profile {
|
model Profile {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
|
@ -412,7 +410,7 @@ model Profile {
|
||||||
updatedAt DateTime @default(now()) @updatedAt
|
updatedAt DateTime @default(now()) @updatedAt
|
||||||
|
|
||||||
// Only 1 of user or group can be set.
|
// Only 1 of user or group can be set.
|
||||||
// The user this profile belongs to, if any.
|
// The user this profile belongs to, if any.
|
||||||
userId String?
|
userId String?
|
||||||
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
|
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
@ -526,7 +524,7 @@ model StoreListingVersion {
|
||||||
agentVersion Int
|
agentVersion Int
|
||||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
|
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
|
||||||
|
|
||||||
// The detials for this version of the agent, this allows the author to update the details of the agent,
|
// The details for this version of the agent, this allows the author to update the details of the agent,
|
||||||
// But still allow using old versions of the agent with there original details.
|
// But still allow using old versions of the agent with there original details.
|
||||||
// TODO: Create a database view that shows only the latest version of each store listing.
|
// TODO: Create a database view that shows only the latest version of each store listing.
|
||||||
slug String
|
slug String
|
||||||
|
@ -571,7 +569,6 @@ model StoreListingReview {
|
||||||
comments String?
|
comments String?
|
||||||
|
|
||||||
@@unique([storeListingVersionId, reviewByUserId])
|
@@unique([storeListingVersionId, reviewByUserId])
|
||||||
@@index([storeListingVersionId])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum SubmissionStatus {
|
enum SubmissionStatus {
|
||||||
|
@ -599,7 +596,6 @@ model StoreListingSubmission {
|
||||||
reviewComments String?
|
reviewComments String?
|
||||||
|
|
||||||
@@index([storeListingId])
|
@@index([storeListingId])
|
||||||
@@index([Status])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum APIKeyPermission {
|
enum APIKeyPermission {
|
||||||
|
|
|
@ -4,95 +4,101 @@ import pytest
|
||||||
from prisma.models import CreditTransaction
|
from prisma.models import CreditTransaction
|
||||||
|
|
||||||
from backend.blocks.llm import AITextGeneratorBlock
|
from backend.blocks.llm import AITextGeneratorBlock
|
||||||
from backend.data.credit import UserCredit
|
from backend.data.credit import BetaUserCredit
|
||||||
|
from backend.data.execution import NodeExecutionEntry
|
||||||
from backend.data.user import DEFAULT_USER_ID
|
from backend.data.user import DEFAULT_USER_ID
|
||||||
from backend.integrations.credentials_store import openai_credentials
|
from backend.integrations.credentials_store import openai_credentials
|
||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
|
|
||||||
REFILL_VALUE = 1000
|
REFILL_VALUE = 1000
|
||||||
user_credit = UserCredit(REFILL_VALUE)
|
user_credit = BetaUserCredit(REFILL_VALUE)
|
||||||
|
|
||||||
|
|
||||||
|
async def disable_test_user_transactions():
|
||||||
|
await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_block_credit_usage(server: SpinTestServer):
|
async def test_block_credit_usage(server: SpinTestServer):
|
||||||
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
await disable_test_user_transactions()
|
||||||
|
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
||||||
|
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
|
|
||||||
spending_amount_1 = await user_credit.spend_credits(
|
spending_amount_1 = await user_credit.spend_credits(
|
||||||
DEFAULT_USER_ID,
|
NodeExecutionEntry(
|
||||||
current_credit,
|
user_id=DEFAULT_USER_ID,
|
||||||
AITextGeneratorBlock().id,
|
graph_id="test_graph",
|
||||||
{
|
node_id="test_node",
|
||||||
"model": "gpt-4-turbo",
|
graph_exec_id="test_graph_exec",
|
||||||
"credentials": {
|
node_exec_id="test_node_exec",
|
||||||
"id": openai_credentials.id,
|
block_id=AITextGeneratorBlock().id,
|
||||||
"provider": openai_credentials.provider,
|
data={
|
||||||
"type": openai_credentials.type,
|
"model": "gpt-4-turbo",
|
||||||
|
"credentials": {
|
||||||
|
"id": openai_credentials.id,
|
||||||
|
"provider": openai_credentials.provider,
|
||||||
|
"type": openai_credentials.type,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
),
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
validate_balance=False,
|
|
||||||
)
|
)
|
||||||
assert spending_amount_1 > 0
|
assert spending_amount_1 > 0
|
||||||
|
|
||||||
spending_amount_2 = await user_credit.spend_credits(
|
spending_amount_2 = await user_credit.spend_credits(
|
||||||
DEFAULT_USER_ID,
|
NodeExecutionEntry(
|
||||||
current_credit,
|
user_id=DEFAULT_USER_ID,
|
||||||
AITextGeneratorBlock().id,
|
graph_id="test_graph",
|
||||||
{"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
node_id="test_node",
|
||||||
|
graph_exec_id="test_graph_exec",
|
||||||
|
node_exec_id="test_node_exec",
|
||||||
|
block_id=AITextGeneratorBlock().id,
|
||||||
|
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||||
|
),
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
validate_balance=False,
|
|
||||||
)
|
)
|
||||||
assert spending_amount_2 == 0
|
assert spending_amount_2 == 0
|
||||||
|
|
||||||
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert new_credit == current_credit - spending_amount_1 - spending_amount_2
|
assert new_credit == current_credit - spending_amount_1 - spending_amount_2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_block_credit_top_up(server: SpinTestServer):
|
async def test_block_credit_top_up(server: SpinTestServer):
|
||||||
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
await disable_test_user_transactions()
|
||||||
|
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
|
|
||||||
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
||||||
|
|
||||||
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert new_credit == current_credit + 100
|
assert new_credit == current_credit + 100
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_block_credit_reset(server: SpinTestServer):
|
async def test_block_credit_reset(server: SpinTestServer):
|
||||||
month1 = datetime(2022, 1, 15)
|
await disable_test_user_transactions()
|
||||||
month2 = datetime(2022, 2, 15)
|
month1 = 1
|
||||||
|
month2 = 2
|
||||||
|
|
||||||
user_credit.time_now = lambda: month2
|
# set the calendar to month 2 but use current time from now
|
||||||
month2credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
user_credit.time_now = lambda: datetime.now().replace(month=month2)
|
||||||
|
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
|
|
||||||
# Month 1 result should only affect month 1
|
# Month 1 result should only affect month 1
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: datetime.now().replace(month=month1)
|
||||||
month1credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
||||||
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month1credit + 100
|
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
|
||||||
|
|
||||||
# Month 2 balance is unaffected
|
# Month 2 balance is unaffected
|
||||||
user_credit.time_now = lambda: month2
|
user_credit.time_now = lambda: datetime.now().replace(month=month2)
|
||||||
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month2credit
|
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(scope="session")
|
||||||
async def test_credit_refill(server: SpinTestServer):
|
async def test_credit_refill(server: SpinTestServer):
|
||||||
# Clear all transactions within the month
|
await disable_test_user_transactions()
|
||||||
await CreditTransaction.prisma().update_many(
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
where={
|
|
||||||
"userId": DEFAULT_USER_ID,
|
|
||||||
"createdAt": {
|
|
||||||
"gte": datetime(2022, 2, 1),
|
|
||||||
"lt": datetime(2022, 3, 1),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
data={"isActive": False},
|
|
||||||
)
|
|
||||||
user_credit.time_now = lambda: datetime(2022, 2, 15)
|
|
||||||
|
|
||||||
balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
|
||||||
assert balance == REFILL_VALUE
|
assert balance == REFILL_VALUE
|
||||||
|
|
|
@ -125,7 +125,7 @@ async def test_agent_execution(server: SpinTestServer):
|
||||||
logger.info("Starting test_agent_execution")
|
logger.info("Starting test_agent_execution")
|
||||||
test_user = await create_test_user()
|
test_user = await create_test_user()
|
||||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||||
data = {"input_1": "Hello", "input_2": "World"}
|
data = {"node_input": {"input_1": "Hello", "input_2": "World"}}
|
||||||
graph_exec_id = await execute_graph(
|
graph_exec_id = await execute_graph(
|
||||||
server.agent_server,
|
server.agent_server,
|
||||||
test_graph,
|
test_graph,
|
||||||
|
|
|
@ -298,7 +298,6 @@ async def main():
|
||||||
data={
|
data={
|
||||||
"transactionKey": str(faker.uuid4()),
|
"transactionKey": str(faker.uuid4()),
|
||||||
"userId": user.id,
|
"userId": user.id,
|
||||||
"blockId": block.id,
|
|
||||||
"amount": random.randint(1, 100),
|
"amount": random.randint(1, 100),
|
||||||
"type": (
|
"type": (
|
||||||
prisma.enums.CreditTransactionType.TOP_UP
|
prisma.enums.CreditTransactionType.TOP_UP
|
||||||
|
|
|
@ -5,6 +5,7 @@ NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market
|
||||||
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
||||||
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=
|
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=
|
||||||
NEXT_PUBLIC_APP_ENV=dev
|
NEXT_PUBLIC_APP_ENV=dev
|
||||||
|
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=
|
||||||
|
|
||||||
## Locale settings
|
## Locale settings
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@
|
||||||
"@radix-ui/react-toast": "^1.2.4",
|
"@radix-ui/react-toast": "^1.2.4",
|
||||||
"@radix-ui/react-tooltip": "^1.1.6",
|
"@radix-ui/react-tooltip": "^1.1.6",
|
||||||
"@sentry/nextjs": "^8",
|
"@sentry/nextjs": "^8",
|
||||||
|
"@stripe/stripe-js": "^5.3.0",
|
||||||
"@supabase/ssr": "^0.5.2",
|
"@supabase/ssr": "^0.5.2",
|
||||||
"@supabase/supabase-js": "^2.47.8",
|
"@supabase/supabase-js": "^2.47.8",
|
||||||
"@tanstack/react-table": "^8.20.6",
|
"@tanstack/react-table": "^8.20.6",
|
||||||
|
@ -64,7 +65,7 @@
|
||||||
"launchdarkly-react-client-sdk": "^3.6.0",
|
"launchdarkly-react-client-sdk": "^3.6.0",
|
||||||
"lucide-react": "^0.469.0",
|
"lucide-react": "^0.469.0",
|
||||||
"moment": "^2.30.1",
|
"moment": "^2.30.1",
|
||||||
"next": "^14.2.13",
|
"next": "^14.2.21",
|
||||||
"next-themes": "^0.4.4",
|
"next-themes": "^0.4.4",
|
||||||
"react": "^18",
|
"react": "^18",
|
||||||
"react-day-picker": "^9.5.0",
|
"react-day-picker": "^9.5.0",
|
||||||
|
|
|
@ -98,6 +98,7 @@ export default function PrivatePage() {
|
||||||
// This contains ids for built-in "Use Credits for X" credentials
|
// This contains ids for built-in "Use Credits for X" credentials
|
||||||
const hiddenCredentials = useMemo(
|
const hiddenCredentials = useMemo(
|
||||||
() => [
|
() => [
|
||||||
|
"744fdc56-071a-4761-b5a5-0af0ce10a2b5", // Ollama
|
||||||
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
|
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
|
||||||
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
|
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
|
||||||
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
|
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
|
||||||
|
@ -123,14 +124,22 @@ export default function PrivatePage() {
|
||||||
|
|
||||||
const allCredentials = providers
|
const allCredentials = providers
|
||||||
? Object.values(providers).flatMap((provider) =>
|
? Object.values(providers).flatMap((provider) =>
|
||||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
|
[
|
||||||
|
...provider.savedOAuthCredentials,
|
||||||
|
...provider.savedApiKeys,
|
||||||
|
...provider.savedUserPasswordCredentials,
|
||||||
|
]
|
||||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||||
.map((credentials) => ({
|
.map((credentials) => ({
|
||||||
...credentials,
|
...credentials,
|
||||||
provider: provider.provider,
|
provider: provider.provider,
|
||||||
providerName: provider.providerName,
|
providerName: provider.providerName,
|
||||||
ProviderIcon: providerIcons[provider.provider],
|
ProviderIcon: providerIcons[provider.provider],
|
||||||
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
|
TypeIcon: {
|
||||||
|
oauth2: IconUser,
|
||||||
|
api_key: IconKey,
|
||||||
|
user_password: IconKey,
|
||||||
|
}[credentials.type],
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
: [];
|
: [];
|
||||||
|
@ -175,6 +184,7 @@ export default function PrivatePage() {
|
||||||
{
|
{
|
||||||
oauth2: "OAuth2 credentials",
|
oauth2: "OAuth2 credentials",
|
||||||
api_key: "API key",
|
api_key: "API key",
|
||||||
|
user_password: "User password",
|
||||||
}[cred.type]
|
}[cred.type]
|
||||||
}{" "}
|
}{" "}
|
||||||
- <code>{cred.id}</code>
|
- <code>{cred.id}</code>
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
"use client";
|
||||||
|
import { Button } from "@/components/agptui/Button";
|
||||||
|
import useCredits from "@/hooks/useCredits";
|
||||||
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
|
import { useSearchParams, useRouter } from "next/navigation";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
|
export default function CreditsPage() {
|
||||||
|
const { requestTopUp } = useCredits();
|
||||||
|
const [amount, setAmount] = useState(5);
|
||||||
|
const [patched, setPatched] = useState(false);
|
||||||
|
const searchParams = useSearchParams();
|
||||||
|
const router = useRouter();
|
||||||
|
const topupStatus = searchParams.get("topup");
|
||||||
|
const api = useBackendAPI();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!patched && topupStatus === "success") {
|
||||||
|
api.fulfillCheckout();
|
||||||
|
setPatched(true);
|
||||||
|
}
|
||||||
|
}, [api, patched, topupStatus]);
|
||||||
|
|
||||||
|
const openBillingPortal = async () => {
|
||||||
|
const portal = await api.getUserPaymentPortalLink();
|
||||||
|
router.push(portal.url);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full min-w-[800px] px-4 sm:px-8">
|
||||||
|
<h1 className="font-circular mb-6 text-[28px] font-normal text-neutral-900 dark:text-neutral-100 sm:mb-8 sm:text-[35px]">
|
||||||
|
Credits
|
||||||
|
</h1>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 gap-8 lg:grid-cols-2">
|
||||||
|
{/* Left Column */}
|
||||||
|
<div>
|
||||||
|
<h2 className="text-lg">Top-up Credits</h2>
|
||||||
|
|
||||||
|
<p className="mb-6 text-neutral-600 dark:text-neutral-400">
|
||||||
|
{topupStatus === "success" && (
|
||||||
|
<span className="text-green-500">
|
||||||
|
Your payment was successful. Your credits will be updated
|
||||||
|
shortly.
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
{topupStatus === "cancel" && (
|
||||||
|
<span className="text-red-500">
|
||||||
|
Payment failed. Your payment method has not been charged.
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div className="mb-4 w-full">
|
||||||
|
<label className="text-neutral-700">
|
||||||
|
1 USD = 100 credits, 5 USD is a minimum top-up
|
||||||
|
</label>
|
||||||
|
<div className="rounded-[55px] border border-slate-200 px-4 py-2.5 dark:border-slate-700 dark:bg-slate-800">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
name="displayName"
|
||||||
|
value={amount}
|
||||||
|
placeholder="Top-up amount in USD"
|
||||||
|
min="5"
|
||||||
|
step="1"
|
||||||
|
className="w-full"
|
||||||
|
onChange={(e) => setAmount(parseInt(e.target.value))}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
variant="default"
|
||||||
|
className="font-circular ml-auto"
|
||||||
|
onClick={() => requestTopUp(amount)}
|
||||||
|
>
|
||||||
|
Top-up
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Right Column */}
|
||||||
|
<div>
|
||||||
|
<h2 className="text-lg">Manage Your Payment Methods</h2>
|
||||||
|
<br />
|
||||||
|
<p className="text-neutral-600">
|
||||||
|
You can manage your cards and see your payment history in the
|
||||||
|
billing portal.
|
||||||
|
</p>
|
||||||
|
<br />
|
||||||
|
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
variant="default"
|
||||||
|
className="font-circular ml-auto"
|
||||||
|
onClick={() => openBillingPortal()}
|
||||||
|
>
|
||||||
|
Open Portal
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
|
@ -33,14 +33,14 @@ export default function Page({}: {}) {
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error fetching submissions:", error);
|
console.error("Error fetching submissions:", error);
|
||||||
}
|
}
|
||||||
}, [api, supabase]);
|
}, [api]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!supabase) {
|
if (!supabase) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
fetchData();
|
fetchData();
|
||||||
}, [supabase]);
|
}, [supabase, fetchData]);
|
||||||
|
|
||||||
const onEditSubmission = useCallback((submission: StoreSubmissionRequest) => {
|
const onEditSubmission = useCallback((submission: StoreSubmissionRequest) => {
|
||||||
setSubmissionData(submission);
|
setSubmissionData(submission);
|
||||||
|
@ -56,7 +56,7 @@ export default function Page({}: {}) {
|
||||||
api.deleteStoreSubmission(submission_id);
|
api.deleteStoreSubmission(submission_id);
|
||||||
fetchData();
|
fetchData();
|
||||||
},
|
},
|
||||||
[supabase],
|
[api, supabase, fetchData],
|
||||||
);
|
);
|
||||||
|
|
||||||
const onOpenPopout = useCallback(() => {
|
const onOpenPopout = useCallback(() => {
|
||||||
|
|
|
@ -98,6 +98,7 @@ export default function PrivatePage() {
|
||||||
// This contains ids for built-in "Use Credits for X" credentials
|
// This contains ids for built-in "Use Credits for X" credentials
|
||||||
const hiddenCredentials = useMemo(
|
const hiddenCredentials = useMemo(
|
||||||
() => [
|
() => [
|
||||||
|
"744fdc56-071a-4761-b5a5-0af0ce10a2b5", // Ollama
|
||||||
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
|
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
|
||||||
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
|
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
|
||||||
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
|
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
|
||||||
|
@ -123,14 +124,22 @@ export default function PrivatePage() {
|
||||||
|
|
||||||
const allCredentials = providers
|
const allCredentials = providers
|
||||||
? Object.values(providers).flatMap((provider) =>
|
? Object.values(providers).flatMap((provider) =>
|
||||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
|
[
|
||||||
|
...provider.savedOAuthCredentials,
|
||||||
|
...provider.savedApiKeys,
|
||||||
|
...provider.savedUserPasswordCredentials,
|
||||||
|
]
|
||||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||||
.map((credentials) => ({
|
.map((credentials) => ({
|
||||||
...credentials,
|
...credentials,
|
||||||
provider: provider.provider,
|
provider: provider.provider,
|
||||||
providerName: provider.providerName,
|
providerName: provider.providerName,
|
||||||
ProviderIcon: providerIcons[provider.provider],
|
ProviderIcon: providerIcons[provider.provider],
|
||||||
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
|
TypeIcon: {
|
||||||
|
oauth2: IconUser,
|
||||||
|
api_key: IconKey,
|
||||||
|
user_password: IconKey,
|
||||||
|
}[credentials.type],
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
: [];
|
: [];
|
||||||
|
|
|
@ -7,6 +7,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||||
links: [
|
links: [
|
||||||
{ text: "Creator Dashboard", href: "/store/dashboard" },
|
{ text: "Creator Dashboard", href: "/store/dashboard" },
|
||||||
{ text: "Agent dashboard", href: "/store/agent-dashboard" },
|
{ text: "Agent dashboard", href: "/store/agent-dashboard" },
|
||||||
|
{ text: "Credits", href: "/store/credits" },
|
||||||
{ text: "Integrations", href: "/store/integrations" },
|
{ text: "Integrations", href: "/store/integrations" },
|
||||||
{ text: "API Keys", href: "/store/api_keys" },
|
{ text: "API Keys", href: "/store/api_keys" },
|
||||||
{ text: "Profile", href: "/store/profile" },
|
{ text: "Profile", href: "/store/profile" },
|
||||||
|
|
|
@ -61,7 +61,7 @@ function SearchResults({
|
||||||
};
|
};
|
||||||
|
|
||||||
fetchData();
|
fetchData();
|
||||||
}, [searchTerm, sort]);
|
}, [api, searchTerm, sort]);
|
||||||
|
|
||||||
const agentsCount = agents.length;
|
const agentsCount = agents.length;
|
||||||
const creatorsCount = creators.length;
|
const creatorsCount = creators.length;
|
||||||
|
|
|
@ -8,6 +8,7 @@ import {
|
||||||
IconIntegrations,
|
IconIntegrations,
|
||||||
IconProfile,
|
IconProfile,
|
||||||
IconSliders,
|
IconSliders,
|
||||||
|
IconCoin,
|
||||||
} from "../ui/icons";
|
} from "../ui/icons";
|
||||||
|
|
||||||
interface SidebarLinkGroup {
|
interface SidebarLinkGroup {
|
||||||
|
@ -22,6 +23,10 @@ interface SidebarProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
||||||
|
const stripeAvailable = Boolean(
|
||||||
|
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY,
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Sheet>
|
<Sheet>
|
||||||
|
@ -49,6 +54,17 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
||||||
Creator dashboard
|
Creator dashboard
|
||||||
</div>
|
</div>
|
||||||
</Link>
|
</Link>
|
||||||
|
{stripeAvailable && (
|
||||||
|
<Link
|
||||||
|
href="/store/credits"
|
||||||
|
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||||
|
>
|
||||||
|
<IconCoin className="h-6 w-6" />
|
||||||
|
<div className="p-ui-medium text-base font-medium leading-normal">
|
||||||
|
Credits
|
||||||
|
</div>
|
||||||
|
</Link>
|
||||||
|
)}
|
||||||
<Link
|
<Link
|
||||||
href="/store/integrations"
|
href="/store/integrations"
|
||||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||||
|
@ -102,6 +118,17 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
||||||
Agent dashboard
|
Agent dashboard
|
||||||
</div>
|
</div>
|
||||||
</Link>
|
</Link>
|
||||||
|
{stripeAvailable && (
|
||||||
|
<Link
|
||||||
|
href="/store/credits"
|
||||||
|
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||||
|
>
|
||||||
|
<IconCoin className="h-6 w-6" />
|
||||||
|
<div className="p-ui-medium text-base font-medium leading-normal">
|
||||||
|
Credits
|
||||||
|
</div>
|
||||||
|
</Link>
|
||||||
|
)}
|
||||||
<Link
|
<Link
|
||||||
href="/store/integrations"
|
href="/store/integrations"
|
||||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||||
|
|
|
@ -14,6 +14,7 @@ import {
|
||||||
FaGoogle,
|
FaGoogle,
|
||||||
FaMedium,
|
FaMedium,
|
||||||
FaKey,
|
FaKey,
|
||||||
|
FaHubspot,
|
||||||
} from "react-icons/fa";
|
} from "react-icons/fa";
|
||||||
import { FC, useMemo, useState } from "react";
|
import { FC, useMemo, useState } from "react";
|
||||||
import {
|
import {
|
||||||
|
@ -66,6 +67,7 @@ export const providerIcons: Record<
|
||||||
google_maps: FaGoogle,
|
google_maps: FaGoogle,
|
||||||
jina: fallbackIcon,
|
jina: fallbackIcon,
|
||||||
ideogram: fallbackIcon,
|
ideogram: fallbackIcon,
|
||||||
|
linear: fallbackIcon,
|
||||||
medium: FaMedium,
|
medium: FaMedium,
|
||||||
ollama: fallbackIcon,
|
ollama: fallbackIcon,
|
||||||
openai: fallbackIcon,
|
openai: fallbackIcon,
|
||||||
|
@ -73,13 +75,15 @@ export const providerIcons: Record<
|
||||||
open_router: fallbackIcon,
|
open_router: fallbackIcon,
|
||||||
pinecone: fallbackIcon,
|
pinecone: fallbackIcon,
|
||||||
slant3d: fallbackIcon,
|
slant3d: fallbackIcon,
|
||||||
|
smtp: fallbackIcon,
|
||||||
replicate: fallbackIcon,
|
replicate: fallbackIcon,
|
||||||
|
reddit: fallbackIcon,
|
||||||
fal: fallbackIcon,
|
fal: fallbackIcon,
|
||||||
revid: fallbackIcon,
|
revid: fallbackIcon,
|
||||||
twitter: FaTwitter,
|
twitter: FaTwitter,
|
||||||
unreal_speech: fallbackIcon,
|
unreal_speech: fallbackIcon,
|
||||||
exa: fallbackIcon,
|
exa: fallbackIcon,
|
||||||
hubspot: fallbackIcon,
|
hubspot: FaHubspot,
|
||||||
};
|
};
|
||||||
// --8<-- [end:ProviderIconsEmbed]
|
// --8<-- [end:ProviderIconsEmbed]
|
||||||
|
|
||||||
|
@ -105,6 +109,10 @@ export const CredentialsInput: FC<{
|
||||||
const credentials = useCredentials(selfKey);
|
const credentials = useCredentials(selfKey);
|
||||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||||
useState(false);
|
useState(false);
|
||||||
|
const [
|
||||||
|
isUserPasswordCredentialsModalOpen,
|
||||||
|
setUserPasswordCredentialsModalOpen,
|
||||||
|
] = useState(false);
|
||||||
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
||||||
const [oAuthPopupController, setOAuthPopupController] =
|
const [oAuthPopupController, setOAuthPopupController] =
|
||||||
useState<AbortController | null>(null);
|
useState<AbortController | null>(null);
|
||||||
|
@ -120,8 +128,10 @@ export const CredentialsInput: FC<{
|
||||||
providerName,
|
providerName,
|
||||||
supportsApiKey,
|
supportsApiKey,
|
||||||
supportsOAuth2,
|
supportsOAuth2,
|
||||||
|
supportsUserPassword,
|
||||||
savedApiKeys,
|
savedApiKeys,
|
||||||
savedOAuthCredentials,
|
savedOAuthCredentials,
|
||||||
|
savedUserPasswordCredentials,
|
||||||
oAuthCallback,
|
oAuthCallback,
|
||||||
} = credentials;
|
} = credentials;
|
||||||
|
|
||||||
|
@ -235,6 +245,17 @@ export const CredentialsInput: FC<{
|
||||||
providerName={providerName}
|
providerName={providerName}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<UserPasswordCredentialsModal
|
||||||
|
credentialsFieldName={selfKey}
|
||||||
|
open={isUserPasswordCredentialsModalOpen}
|
||||||
|
onClose={() => setUserPasswordCredentialsModalOpen(false)}
|
||||||
|
onCredentialsCreate={(creds) => {
|
||||||
|
onSelectCredentials(creds);
|
||||||
|
setUserPasswordCredentialsModalOpen(false);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -243,13 +264,18 @@ export const CredentialsInput: FC<{
|
||||||
selectedCredentials &&
|
selectedCredentials &&
|
||||||
!savedApiKeys
|
!savedApiKeys
|
||||||
.concat(savedOAuthCredentials)
|
.concat(savedOAuthCredentials)
|
||||||
|
.concat(savedUserPasswordCredentials)
|
||||||
.some((c) => c.id === selectedCredentials.id)
|
.some((c) => c.id === selectedCredentials.id)
|
||||||
) {
|
) {
|
||||||
onSelectCredentials(undefined);
|
onSelectCredentials(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
// No saved credentials yet
|
// No saved credentials yet
|
||||||
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
|
if (
|
||||||
|
savedApiKeys.length === 0 &&
|
||||||
|
savedOAuthCredentials.length === 0 &&
|
||||||
|
savedUserPasswordCredentials.length === 0
|
||||||
|
) {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="mb-2 flex gap-1">
|
<div className="mb-2 flex gap-1">
|
||||||
|
@ -271,6 +297,12 @@ export const CredentialsInput: FC<{
|
||||||
Enter API key
|
Enter API key
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<Button onClick={() => setUserPasswordCredentialsModalOpen(true)}>
|
||||||
|
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||||
|
Enter username and password
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
{modals}
|
{modals}
|
||||||
{oAuthError && (
|
{oAuthError && (
|
||||||
|
@ -280,12 +312,29 @@ export const CredentialsInput: FC<{
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const singleCredential =
|
const getCredentialCounts = () => ({
|
||||||
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
|
apiKeys: savedApiKeys.length,
|
||||||
? savedApiKeys[0]
|
oauth: savedOAuthCredentials.length,
|
||||||
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
|
userPass: savedUserPasswordCredentials.length,
|
||||||
? savedOAuthCredentials[0]
|
});
|
||||||
: null;
|
|
||||||
|
const getSingleCredential = () => {
|
||||||
|
const counts = getCredentialCounts();
|
||||||
|
const totalCredentials = Object.values(counts).reduce(
|
||||||
|
(sum, count) => sum + count,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (totalCredentials !== 1) return null;
|
||||||
|
|
||||||
|
if (counts.apiKeys === 1) return savedApiKeys[0];
|
||||||
|
if (counts.oauth === 1) return savedOAuthCredentials[0];
|
||||||
|
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const singleCredential = getSingleCredential();
|
||||||
|
|
||||||
if (singleCredential) {
|
if (singleCredential) {
|
||||||
if (!selectedCredentials) {
|
if (!selectedCredentials) {
|
||||||
|
@ -309,6 +358,7 @@ export const CredentialsInput: FC<{
|
||||||
} else {
|
} else {
|
||||||
const selectedCreds = savedApiKeys
|
const selectedCreds = savedApiKeys
|
||||||
.concat(savedOAuthCredentials)
|
.concat(savedOAuthCredentials)
|
||||||
|
.concat(savedUserPasswordCredentials)
|
||||||
.find((c) => c.id == newValue)!;
|
.find((c) => c.id == newValue)!;
|
||||||
|
|
||||||
onSelectCredentials({
|
onSelectCredentials({
|
||||||
|
@ -347,6 +397,13 @@ export const CredentialsInput: FC<{
|
||||||
{credentials.title}
|
{credentials.title}
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
))}
|
||||||
|
{savedUserPasswordCredentials.map((credentials, index) => (
|
||||||
|
<SelectItem key={index} value={credentials.id}>
|
||||||
|
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||||
|
<IconUserPlus className="mr-1.5 inline" />
|
||||||
|
{credentials.title}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
<SelectSeparator />
|
<SelectSeparator />
|
||||||
{supportsOAuth2 && (
|
{supportsOAuth2 && (
|
||||||
<SelectItem value="sign-in">
|
<SelectItem value="sign-in">
|
||||||
|
@ -360,6 +417,12 @@ export const CredentialsInput: FC<{
|
||||||
Add new API key
|
Add new API key
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<SelectItem value="add-user-password">
|
||||||
|
<IconUserPlus className="mr-1.5 inline" />
|
||||||
|
Add new user password
|
||||||
|
</SelectItem>
|
||||||
|
)}
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
{modals}
|
{modals}
|
||||||
|
@ -506,6 +569,130 @@ export const APIKeyCredentialsModal: FC<{
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const UserPasswordCredentialsModal: FC<{
|
||||||
|
credentialsFieldName: string;
|
||||||
|
open: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||||
|
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||||
|
const credentials = useCredentials(credentialsFieldName);
|
||||||
|
|
||||||
|
const formSchema = z.object({
|
||||||
|
username: z.string().min(1, "Username is required"),
|
||||||
|
password: z.string().min(1, "Password is required"),
|
||||||
|
title: z.string().min(1, "Name is required"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const form = useForm<z.infer<typeof formSchema>>({
|
||||||
|
resolver: zodResolver(formSchema),
|
||||||
|
defaultValues: {
|
||||||
|
username: "",
|
||||||
|
password: "",
|
||||||
|
title: "",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (
|
||||||
|
!credentials ||
|
||||||
|
credentials.isLoading ||
|
||||||
|
!credentials.supportsUserPassword
|
||||||
|
) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { schema, provider, providerName, createUserPasswordCredentials } =
|
||||||
|
credentials;
|
||||||
|
|
||||||
|
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||||
|
const newCredentials = await createUserPasswordCredentials({
|
||||||
|
username: values.username,
|
||||||
|
password: values.password,
|
||||||
|
title: values.title,
|
||||||
|
});
|
||||||
|
onCredentialsCreate({
|
||||||
|
provider,
|
||||||
|
id: newCredentials.id,
|
||||||
|
type: "user_password",
|
||||||
|
title: newCredentials.title,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog
|
||||||
|
open={open}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
if (!open) onClose();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>
|
||||||
|
Add new username & password for {providerName}
|
||||||
|
</DialogTitle>
|
||||||
|
</DialogHeader>
|
||||||
|
<Form {...form}>
|
||||||
|
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="username"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Username</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="text"
|
||||||
|
placeholder="Enter username..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="password"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Password</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="Enter password..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="title"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Name</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="text"
|
||||||
|
placeholder="Enter a name for this user login..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<Button type="submit" className="w-full">
|
||||||
|
Save & use this user login
|
||||||
|
</Button>
|
||||||
|
</form>
|
||||||
|
</Form>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export const OAuth2FlowWaitingModal: FC<{
|
export const OAuth2FlowWaitingModal: FC<{
|
||||||
open: boolean;
|
open: boolean;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
|
|
|
@ -5,6 +5,7 @@ import {
|
||||||
CredentialsMetaResponse,
|
CredentialsMetaResponse,
|
||||||
CredentialsProviderName,
|
CredentialsProviderName,
|
||||||
PROVIDER_NAMES,
|
PROVIDER_NAMES,
|
||||||
|
UserPasswordCredentials,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
import { createContext, useCallback, useEffect, useState } from "react";
|
import { createContext, useCallback, useEffect, useState } from "react";
|
||||||
|
@ -20,12 +21,16 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||||
discord: "Discord",
|
discord: "Discord",
|
||||||
d_id: "D-ID",
|
d_id: "D-ID",
|
||||||
e2b: "E2B",
|
e2b: "E2B",
|
||||||
|
exa: "Exa",
|
||||||
|
fal: "FAL",
|
||||||
github: "GitHub",
|
github: "GitHub",
|
||||||
google: "Google",
|
google: "Google",
|
||||||
google_maps: "Google Maps",
|
google_maps: "Google Maps",
|
||||||
groq: "Groq",
|
groq: "Groq",
|
||||||
|
hubspot: "Hubspot",
|
||||||
ideogram: "Ideogram",
|
ideogram: "Ideogram",
|
||||||
jina: "Jina",
|
jina: "Jina",
|
||||||
|
linear: "Linear",
|
||||||
medium: "Medium",
|
medium: "Medium",
|
||||||
notion: "Notion",
|
notion: "Notion",
|
||||||
nvidia: "Nvidia",
|
nvidia: "Nvidia",
|
||||||
|
@ -35,13 +40,12 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||||
open_router: "Open Router",
|
open_router: "Open Router",
|
||||||
pinecone: "Pinecone",
|
pinecone: "Pinecone",
|
||||||
slant3d: "Slant3D",
|
slant3d: "Slant3D",
|
||||||
|
smtp: "SMTP",
|
||||||
|
reddit: "Reddit",
|
||||||
replicate: "Replicate",
|
replicate: "Replicate",
|
||||||
fal: "FAL",
|
|
||||||
revid: "Rev.ID",
|
revid: "Rev.ID",
|
||||||
twitter: "Twitter",
|
twitter: "Twitter",
|
||||||
unreal_speech: "Unreal Speech",
|
unreal_speech: "Unreal Speech",
|
||||||
exa: "Exa",
|
|
||||||
hubspot: "Hubspot",
|
|
||||||
} as const;
|
} as const;
|
||||||
// --8<-- [end:CredentialsProviderNames]
|
// --8<-- [end:CredentialsProviderNames]
|
||||||
|
|
||||||
|
@ -50,11 +54,17 @@ type APIKeyCredentialsCreatable = Omit<
|
||||||
"id" | "provider" | "type"
|
"id" | "provider" | "type"
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
type UserPasswordCredentialsCreatable = Omit<
|
||||||
|
UserPasswordCredentials,
|
||||||
|
"id" | "provider" | "type"
|
||||||
|
>;
|
||||||
|
|
||||||
export type CredentialsProviderData = {
|
export type CredentialsProviderData = {
|
||||||
provider: CredentialsProviderName;
|
provider: CredentialsProviderName;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
savedApiKeys: CredentialsMetaResponse[];
|
savedApiKeys: CredentialsMetaResponse[];
|
||||||
savedOAuthCredentials: CredentialsMetaResponse[];
|
savedOAuthCredentials: CredentialsMetaResponse[];
|
||||||
|
savedUserPasswordCredentials: CredentialsMetaResponse[];
|
||||||
oAuthCallback: (
|
oAuthCallback: (
|
||||||
code: string,
|
code: string,
|
||||||
state_token: string,
|
state_token: string,
|
||||||
|
@ -62,6 +72,9 @@ export type CredentialsProviderData = {
|
||||||
createAPIKeyCredentials: (
|
createAPIKeyCredentials: (
|
||||||
credentials: APIKeyCredentialsCreatable,
|
credentials: APIKeyCredentialsCreatable,
|
||||||
) => Promise<CredentialsMetaResponse>;
|
) => Promise<CredentialsMetaResponse>;
|
||||||
|
createUserPasswordCredentials: (
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
) => Promise<CredentialsMetaResponse>;
|
||||||
deleteCredentials: (
|
deleteCredentials: (
|
||||||
id: string,
|
id: string,
|
||||||
force?: boolean,
|
force?: boolean,
|
||||||
|
@ -106,6 +119,11 @@ export default function CredentialsProvider({
|
||||||
...updatedProvider.savedOAuthCredentials,
|
...updatedProvider.savedOAuthCredentials,
|
||||||
credentials,
|
credentials,
|
||||||
];
|
];
|
||||||
|
} else if (credentials.type === "user_password") {
|
||||||
|
updatedProvider.savedUserPasswordCredentials = [
|
||||||
|
...updatedProvider.savedUserPasswordCredentials,
|
||||||
|
credentials,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -147,6 +165,22 @@ export default function CredentialsProvider({
|
||||||
[api, addCredentials],
|
[api, addCredentials],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
|
||||||
|
const createUserPasswordCredentials = useCallback(
|
||||||
|
async (
|
||||||
|
provider: CredentialsProviderName,
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
): Promise<CredentialsMetaResponse> => {
|
||||||
|
const credsMeta = await api.createUserPasswordCredentials({
|
||||||
|
provider,
|
||||||
|
...credentials,
|
||||||
|
});
|
||||||
|
addCredentials(provider, credsMeta);
|
||||||
|
return credsMeta;
|
||||||
|
},
|
||||||
|
[api, addCredentials],
|
||||||
|
);
|
||||||
|
|
||||||
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
||||||
const deleteCredentials = useCallback(
|
const deleteCredentials = useCallback(
|
||||||
async (
|
async (
|
||||||
|
@ -171,7 +205,10 @@ export default function CredentialsProvider({
|
||||||
updatedProvider.savedOAuthCredentials.filter(
|
updatedProvider.savedOAuthCredentials.filter(
|
||||||
(cred) => cred.id !== id,
|
(cred) => cred.id !== id,
|
||||||
);
|
);
|
||||||
|
updatedProvider.savedUserPasswordCredentials =
|
||||||
|
updatedProvider.savedUserPasswordCredentials.filter(
|
||||||
|
(cred) => cred.id !== id,
|
||||||
|
);
|
||||||
return {
|
return {
|
||||||
...prev,
|
...prev,
|
||||||
[provider]: updatedProvider,
|
[provider]: updatedProvider,
|
||||||
|
@ -190,12 +227,18 @@ export default function CredentialsProvider({
|
||||||
const credentialsByProvider = response.reduce(
|
const credentialsByProvider = response.reduce(
|
||||||
(acc, cred) => {
|
(acc, cred) => {
|
||||||
if (!acc[cred.provider]) {
|
if (!acc[cred.provider]) {
|
||||||
acc[cred.provider] = { oauthCreds: [], apiKeys: [] };
|
acc[cred.provider] = {
|
||||||
|
oauthCreds: [],
|
||||||
|
apiKeys: [],
|
||||||
|
userPasswordCreds: [],
|
||||||
|
};
|
||||||
}
|
}
|
||||||
if (cred.type === "oauth2") {
|
if (cred.type === "oauth2") {
|
||||||
acc[cred.provider].oauthCreds.push(cred);
|
acc[cred.provider].oauthCreds.push(cred);
|
||||||
} else if (cred.type === "api_key") {
|
} else if (cred.type === "api_key") {
|
||||||
acc[cred.provider].apiKeys.push(cred);
|
acc[cred.provider].apiKeys.push(cred);
|
||||||
|
} else if (cred.type === "user_password") {
|
||||||
|
acc[cred.provider].userPasswordCreds.push(cred);
|
||||||
}
|
}
|
||||||
return acc;
|
return acc;
|
||||||
},
|
},
|
||||||
|
@ -204,6 +247,7 @@ export default function CredentialsProvider({
|
||||||
{
|
{
|
||||||
oauthCreds: CredentialsMetaResponse[];
|
oauthCreds: CredentialsMetaResponse[];
|
||||||
apiKeys: CredentialsMetaResponse[];
|
apiKeys: CredentialsMetaResponse[];
|
||||||
|
userPasswordCreds: CredentialsMetaResponse[];
|
||||||
}
|
}
|
||||||
>,
|
>,
|
||||||
);
|
);
|
||||||
|
@ -220,6 +264,8 @@ export default function CredentialsProvider({
|
||||||
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
||||||
savedOAuthCredentials:
|
savedOAuthCredentials:
|
||||||
credentialsByProvider[provider]?.oauthCreds ?? [],
|
credentialsByProvider[provider]?.oauthCreds ?? [],
|
||||||
|
savedUserPasswordCredentials:
|
||||||
|
credentialsByProvider[provider]?.userPasswordCreds ?? [],
|
||||||
oAuthCallback: (code: string, state_token: string) =>
|
oAuthCallback: (code: string, state_token: string) =>
|
||||||
oAuthCallback(
|
oAuthCallback(
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
|
@ -233,6 +279,13 @@ export default function CredentialsProvider({
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
credentials,
|
credentials,
|
||||||
),
|
),
|
||||||
|
createUserPasswordCredentials: (
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
) =>
|
||||||
|
createUserPasswordCredentials(
|
||||||
|
provider as CredentialsProviderName,
|
||||||
|
credentials,
|
||||||
|
),
|
||||||
deleteCredentials: (id: string, force: boolean = false) =>
|
deleteCredentials: (id: string, force: boolean = false) =>
|
||||||
deleteCredentials(
|
deleteCredentials(
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
|
@ -245,7 +298,13 @@ export default function CredentialsProvider({
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [api, createAPIKeyCredentials, deleteCredentials, oAuthCallback]);
|
}, [
|
||||||
|
api,
|
||||||
|
createAPIKeyCredentials,
|
||||||
|
createUserPasswordCredentials,
|
||||||
|
deleteCredentials,
|
||||||
|
oAuthCallback,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<CredentialsProvidersContext.Provider value={providers}>
|
<CredentialsProvidersContext.Provider value={providers}>
|
||||||
|
|
|
@ -1,37 +1,21 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useState, useEffect, useCallback } from "react";
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { IconRefresh } from "@/components/ui/icons";
|
import { IconRefresh } from "@/components/ui/icons";
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import useCredits from "@/hooks/useCredits";
|
||||||
|
|
||||||
export default function CreditButton() {
|
export default function CreditButton() {
|
||||||
const [credit, setCredit] = useState<number | null>(null);
|
const { credits, fetchCredits } = useCredits();
|
||||||
const api = useBackendAPI();
|
|
||||||
|
|
||||||
const fetchCredit = useCallback(async () => {
|
|
||||||
try {
|
|
||||||
const response = await api.getUserCredit();
|
|
||||||
setCredit(response.credits);
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Error fetching credit:", error);
|
|
||||||
setCredit(null);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchCredit();
|
|
||||||
}, [fetchCredit]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
credit !== null && (
|
credits !== null && (
|
||||||
<Button
|
<Button
|
||||||
onClick={fetchCredit}
|
onClick={fetchCredits}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
className="flex items-center space-x-2 rounded-xl bg-gray-200"
|
className="flex items-center space-x-2 rounded-xl bg-gray-200"
|
||||||
>
|
>
|
||||||
<span className="mr-2 flex items-center text-foreground">
|
<span className="mr-2 flex items-center text-foreground">
|
||||||
{credit} <span className="ml-2 text-muted-foreground"> credits</span>
|
{credits} <span className="ml-2 text-muted-foreground"> credits</span>
|
||||||
</span>
|
</span>
|
||||||
<IconRefresh />
|
<IconRefresh />
|
||||||
</Button>
|
</Button>
|
||||||
|
|
|
@ -313,8 +313,6 @@ export const NodeGenericInputField: FC<{
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log("propSchema", propSchema);
|
|
||||||
|
|
||||||
if ("properties" in propSchema) {
|
if ("properties" in propSchema) {
|
||||||
// Render a multi-select for all-boolean sub-schemas with more than 3 properties
|
// Render a multi-select for all-boolean sub-schemas with more than 3 properties
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -323,7 +323,7 @@ export const IconCoin = createIcon((props) => (
|
||||||
viewBox="0 0 24 24"
|
viewBox="0 0 24 24"
|
||||||
fill="none"
|
fill="none"
|
||||||
stroke="currentColor"
|
stroke="currentColor"
|
||||||
strokeWidth="2"
|
strokeWidth="1.25"
|
||||||
strokeLinecap="round"
|
strokeLinecap="round"
|
||||||
strokeLinejoin="round"
|
strokeLinejoin="round"
|
||||||
aria-label="Coin Icon"
|
aria-label="Coin Icon"
|
||||||
|
|
|
@ -862,6 +862,7 @@ export default function useAgentGraph(
|
||||||
title: "Error saving agent",
|
title: "Error saving agent",
|
||||||
description: errorMessage,
|
description: errorMessage,
|
||||||
});
|
});
|
||||||
|
setSaveRunRequest({ request: "save", state: "error" });
|
||||||
}
|
}
|
||||||
}, [_saveAgent, toast]);
|
}, [_saveAgent, toast]);
|
||||||
|
|
||||||
|
@ -874,7 +875,7 @@ export default function useAgentGraph(
|
||||||
request: "save",
|
request: "save",
|
||||||
state: "saving",
|
state: "saving",
|
||||||
});
|
});
|
||||||
}, [saveAgent]);
|
}, [saveAgent, saveRunRequest.state]);
|
||||||
|
|
||||||
const requestSaveAndRun = useCallback(() => {
|
const requestSaveAndRun = useCallback(() => {
|
||||||
saveAgent();
|
saveAgent();
|
||||||
|
|
|
@ -17,12 +17,14 @@ export type CredentialsData =
|
||||||
schema: BlockIOCredentialsSubSchema;
|
schema: BlockIOCredentialsSubSchema;
|
||||||
supportsApiKey: boolean;
|
supportsApiKey: boolean;
|
||||||
supportsOAuth2: boolean;
|
supportsOAuth2: boolean;
|
||||||
|
supportsUserPassword: boolean;
|
||||||
isLoading: true;
|
isLoading: true;
|
||||||
}
|
}
|
||||||
| (CredentialsProviderData & {
|
| (CredentialsProviderData & {
|
||||||
schema: BlockIOCredentialsSubSchema;
|
schema: BlockIOCredentialsSubSchema;
|
||||||
supportsApiKey: boolean;
|
supportsApiKey: boolean;
|
||||||
supportsOAuth2: boolean;
|
supportsOAuth2: boolean;
|
||||||
|
supportsUserPassword: boolean;
|
||||||
isLoading: false;
|
isLoading: false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -72,6 +74,8 @@ export default function useCredentials(
|
||||||
const supportsApiKey =
|
const supportsApiKey =
|
||||||
credentialsSchema.credentials_types.includes("api_key");
|
credentialsSchema.credentials_types.includes("api_key");
|
||||||
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
||||||
|
const supportsUserPassword =
|
||||||
|
credentialsSchema.credentials_types.includes("user_password");
|
||||||
|
|
||||||
// No provider means maybe it's still loading
|
// No provider means maybe it's still loading
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
|
@ -93,13 +97,17 @@ export default function useCredentials(
|
||||||
)
|
)
|
||||||
: provider.savedOAuthCredentials;
|
: provider.savedOAuthCredentials;
|
||||||
|
|
||||||
|
const savedUserPasswordCredentials = provider.savedUserPasswordCredentials;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...provider,
|
...provider,
|
||||||
provider: providerName,
|
provider: providerName,
|
||||||
schema: credentialsSchema,
|
schema: credentialsSchema,
|
||||||
supportsApiKey,
|
supportsApiKey,
|
||||||
supportsOAuth2,
|
supportsOAuth2,
|
||||||
|
supportsUserPassword,
|
||||||
savedOAuthCredentials,
|
savedOAuthCredentials,
|
||||||
|
savedUserPasswordCredentials,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
import AutoGPTServerAPI from "@/lib/autogpt-server-api";
|
||||||
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
|
import { loadStripe } from "@stripe/stripe-js";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
|
||||||
|
const stripePromise = loadStripe(
|
||||||
|
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY!,
|
||||||
|
);
|
||||||
|
|
||||||
|
export default function useCredits(): {
|
||||||
|
credits: number | null;
|
||||||
|
fetchCredits: () => void;
|
||||||
|
requestTopUp: (usd_amount: number) => Promise<void>;
|
||||||
|
} {
|
||||||
|
const [credits, setCredits] = useState<number | null>(null);
|
||||||
|
const api = useMemo(() => new AutoGPTServerAPI(), []);
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
const fetchCredits = useCallback(async () => {
|
||||||
|
const response = await api.getUserCredit();
|
||||||
|
setCredits(response.credits);
|
||||||
|
}, [api]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchCredits();
|
||||||
|
}, [fetchCredits]);
|
||||||
|
|
||||||
|
const requestTopUp = useCallback(
|
||||||
|
async (usd_amount: number) => {
|
||||||
|
const stripe = await stripePromise;
|
||||||
|
|
||||||
|
if (!stripe) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert dollar amount to credit count
|
||||||
|
const response = await api.requestTopUp(usd_amount * 100);
|
||||||
|
router.push(response.checkout_url);
|
||||||
|
},
|
||||||
|
[api, router],
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
credits,
|
||||||
|
fetchCredits,
|
||||||
|
requestTopUp,
|
||||||
|
};
|
||||||
|
}
|
|
@ -15,7 +15,6 @@ import {
|
||||||
GraphUpdateable,
|
GraphUpdateable,
|
||||||
NodeExecutionResult,
|
NodeExecutionResult,
|
||||||
MyAgentsResponse,
|
MyAgentsResponse,
|
||||||
OAuth2Credentials,
|
|
||||||
ProfileDetails,
|
ProfileDetails,
|
||||||
User,
|
User,
|
||||||
StoreAgentsResponse,
|
StoreAgentsResponse,
|
||||||
|
@ -29,6 +28,8 @@ import {
|
||||||
StoreReview,
|
StoreReview,
|
||||||
ScheduleCreatable,
|
ScheduleCreatable,
|
||||||
Schedule,
|
Schedule,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
Credentials,
|
||||||
APIKeyPermission,
|
APIKeyPermission,
|
||||||
CreateAPIKeyResponse,
|
CreateAPIKeyResponse,
|
||||||
APIKey,
|
APIKey,
|
||||||
|
@ -88,6 +89,18 @@ export default class BackendAPI {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestTopUp(amount: number): Promise<{ checkout_url: string }> {
|
||||||
|
return this._request("POST", "/credits", { amount });
|
||||||
|
}
|
||||||
|
|
||||||
|
getUserPaymentPortalLink(): Promise<{ url: string }> {
|
||||||
|
return this._get("/credits/manage");
|
||||||
|
}
|
||||||
|
|
||||||
|
fulfillCheckout(): Promise<void> {
|
||||||
|
return this._request("PATCH", "/credits");
|
||||||
|
}
|
||||||
|
|
||||||
getBlocks(): Promise<Block[]> {
|
getBlocks(): Promise<Block[]> {
|
||||||
return this._get("/blocks");
|
return this._get("/blocks");
|
||||||
}
|
}
|
||||||
|
@ -191,7 +204,17 @@ export default class BackendAPI {
|
||||||
return this._request(
|
return this._request(
|
||||||
"POST",
|
"POST",
|
||||||
`/integrations/${credentials.provider}/credentials`,
|
`/integrations/${credentials.provider}/credentials`,
|
||||||
credentials,
|
{ ...credentials, type: "api_key" },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
createUserPasswordCredentials(
|
||||||
|
credentials: Omit<UserPasswordCredentials, "id" | "type">,
|
||||||
|
): Promise<UserPasswordCredentials> {
|
||||||
|
return this._request(
|
||||||
|
"POST",
|
||||||
|
`/integrations/${credentials.provider}/credentials`,
|
||||||
|
{ ...credentials, type: "user_password" },
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,10 +226,7 @@ export default class BackendAPI {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
getCredentials(
|
getCredentials(provider: string, id: string): Promise<Credentials> {
|
||||||
provider: string,
|
|
||||||
id: string,
|
|
||||||
): Promise<APIKeyCredentials | OAuth2Credentials> {
|
|
||||||
return this._get(`/integrations/${provider}/credentials/${id}`);
|
return this._get(`/integrations/${provider}/credentials/${id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,12 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
|
||||||
default?: boolean;
|
default?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CredentialsType = "api_key" | "oauth2";
|
export type CredentialsType = "api_key" | "oauth2" | "user_password";
|
||||||
|
|
||||||
|
export type Credentials =
|
||||||
|
| APIKeyCredentials
|
||||||
|
| OAuth2Credentials
|
||||||
|
| UserPasswordCredentials;
|
||||||
|
|
||||||
// --8<-- [start:BlockIOCredentialsSubSchema]
|
// --8<-- [start:BlockIOCredentialsSubSchema]
|
||||||
export const PROVIDER_NAMES = {
|
export const PROVIDER_NAMES = {
|
||||||
|
@ -105,12 +110,16 @@ export const PROVIDER_NAMES = {
|
||||||
D_ID: "d_id",
|
D_ID: "d_id",
|
||||||
DISCORD: "discord",
|
DISCORD: "discord",
|
||||||
E2B: "e2b",
|
E2B: "e2b",
|
||||||
|
EXA: "exa",
|
||||||
|
FAL: "fal",
|
||||||
GITHUB: "github",
|
GITHUB: "github",
|
||||||
GOOGLE: "google",
|
GOOGLE: "google",
|
||||||
GOOGLE_MAPS: "google_maps",
|
GOOGLE_MAPS: "google_maps",
|
||||||
GROQ: "groq",
|
GROQ: "groq",
|
||||||
|
HUBSPOT: "hubspot",
|
||||||
IDEOGRAM: "ideogram",
|
IDEOGRAM: "ideogram",
|
||||||
JINA: "jina",
|
JINA: "jina",
|
||||||
|
LINEAR: "linear",
|
||||||
MEDIUM: "medium",
|
MEDIUM: "medium",
|
||||||
NOTION: "notion",
|
NOTION: "notion",
|
||||||
NVIDIA: "nvidia",
|
NVIDIA: "nvidia",
|
||||||
|
@ -120,13 +129,12 @@ export const PROVIDER_NAMES = {
|
||||||
OPEN_ROUTER: "open_router",
|
OPEN_ROUTER: "open_router",
|
||||||
PINECONE: "pinecone",
|
PINECONE: "pinecone",
|
||||||
SLANT3D: "slant3d",
|
SLANT3D: "slant3d",
|
||||||
|
SMTP: "smtp",
|
||||||
|
TWITTER: "twitter",
|
||||||
REPLICATE: "replicate",
|
REPLICATE: "replicate",
|
||||||
FAL: "fal",
|
REDDIT: "reddit",
|
||||||
REVID: "revid",
|
REVID: "revid",
|
||||||
UNREAL_SPEECH: "unreal_speech",
|
UNREAL_SPEECH: "unreal_speech",
|
||||||
EXA: "exa",
|
|
||||||
HUBSPOT: "hubspot",
|
|
||||||
TWITTER: "twitter",
|
|
||||||
} as const;
|
} as const;
|
||||||
// --8<-- [end:BlockIOCredentialsSubSchema]
|
// --8<-- [end:BlockIOCredentialsSubSchema]
|
||||||
|
|
||||||
|
@ -322,8 +330,15 @@ export type APIKeyCredentials = BaseCredentials & {
|
||||||
expires_at?: number;
|
expires_at?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UserPasswordCredentials = BaseCredentials & {
|
||||||
|
type: "user_password";
|
||||||
|
title: string;
|
||||||
|
username: string;
|
||||||
|
password: string;
|
||||||
|
};
|
||||||
|
|
||||||
/* Mirror of backend/data/integrations.py:Webhook */
|
/* Mirror of backend/data/integrations.py:Webhook */
|
||||||
type Webhook = {
|
export type Webhook = {
|
||||||
id: string;
|
id: string;
|
||||||
url: string;
|
url: string;
|
||||||
provider: CredentialsProviderName;
|
provider: CredentialsProviderName;
|
||||||
|
|
|
@ -42,39 +42,75 @@ test.describe("Build", () => { //(1)!
|
||||||
});
|
});
|
||||||
// --8<-- [end:BuildPageExample]
|
// --8<-- [end:BuildPageExample]
|
||||||
|
|
||||||
test("user can add all blocks", async ({ page }, testInfo) => {
|
test("user can add all blocks a-l", async ({ page }, testInfo) => {
|
||||||
// this test is slow af so we 10x the timeout (sorry future me)
|
// this test is slow af so we 10x the timeout (sorry future me)
|
||||||
await test.setTimeout(testInfo.timeout * 10);
|
await test.setTimeout(testInfo.timeout * 100);
|
||||||
await test.expect(buildPage.isLoaded()).resolves.toBeTruthy();
|
await test.expect(buildPage.isLoaded()).resolves.toBeTruthy();
|
||||||
await test.expect(page).toHaveURL(new RegExp("/.*build"));
|
await test.expect(page).toHaveURL(new RegExp("/.*build"));
|
||||||
await buildPage.closeTutorial();
|
await buildPage.closeTutorial();
|
||||||
await buildPage.openBlocksPanel();
|
await buildPage.openBlocksPanel();
|
||||||
const blocks = await buildPage.getBlocks();
|
const blocks = await buildPage.getBlocks();
|
||||||
|
|
||||||
// add all the blocks in order
|
const blocksToSkip = await buildPage.getBlocksToSkip();
|
||||||
|
|
||||||
|
// add all the blocks in order except for the agent executor block
|
||||||
for (const block of blocks) {
|
for (const block of blocks) {
|
||||||
if (block.id !== "e189baac-8c20-45a1-94a7-55177ea42565") {
|
if (block.name[0].toLowerCase() >= "m") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!blocksToSkip.some((b) => b === block.id)) {
|
||||||
await buildPage.addBlock(block);
|
await buildPage.addBlock(block);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await buildPage.closeBlocksPanel();
|
await buildPage.closeBlocksPanel();
|
||||||
// check that all the blocks are visible
|
// check that all the blocks are visible
|
||||||
for (const block of blocks) {
|
for (const block of blocks) {
|
||||||
if (block.id !== "e189baac-8c20-45a1-94a7-55177ea42565") {
|
if (block.name[0].toLowerCase() >= "m") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!blocksToSkip.some((b) => b === block.id)) {
|
||||||
|
console.log("Checking block:", block.name);
|
||||||
await test.expect(buildPage.hasBlock(block)).resolves.toBeTruthy();
|
await test.expect(buildPage.hasBlock(block)).resolves.toBeTruthy();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// fill in the input for the agent input block
|
|
||||||
await buildPage.fillBlockInputByPlaceholder(
|
// check that we can save the agent with all the blocks
|
||||||
blocks.find((b) => b.name === "Agent Input")?.id ?? "",
|
await buildPage.saveAgent("all blocks test", "all blocks test");
|
||||||
"Enter Name",
|
// page should have a url like http://localhost:3000/build?flowID=f4f3a1da-cfb3-430f-a074-a455b047e340
|
||||||
"Agent Input Field",
|
await test.expect(page).toHaveURL(new RegExp("/.*build\\?flowID=.+"));
|
||||||
);
|
});
|
||||||
await buildPage.fillBlockInputByPlaceholder(
|
|
||||||
blocks.find((b) => b.name === "Agent Output")?.id ?? "",
|
test("user can add all blocks m-z", async ({ page }, testInfo) => {
|
||||||
"Enter Name",
|
// this test is slow af so we 10x the timeout (sorry future me)
|
||||||
"Agent Output Field",
|
await test.setTimeout(testInfo.timeout * 100);
|
||||||
);
|
await test.expect(buildPage.isLoaded()).resolves.toBeTruthy();
|
||||||
|
await test.expect(page).toHaveURL(new RegExp("/.*build"));
|
||||||
|
await buildPage.closeTutorial();
|
||||||
|
await buildPage.openBlocksPanel();
|
||||||
|
const blocks = await buildPage.getBlocks();
|
||||||
|
|
||||||
|
const blocksToSkip = await buildPage.getBlocksToSkip();
|
||||||
|
|
||||||
|
// add all the blocks in order except for the agent executor block
|
||||||
|
for (const block of blocks) {
|
||||||
|
if (block.name[0].toLowerCase() < "m") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!blocksToSkip.some((b) => b === block.id)) {
|
||||||
|
await buildPage.addBlock(block);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await buildPage.closeBlocksPanel();
|
||||||
|
// check that all the blocks are visible
|
||||||
|
for (const block of blocks) {
|
||||||
|
if (block.name[0].toLowerCase() < "m") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!blocksToSkip.some((b) => b === block.id)) {
|
||||||
|
await test.expect(buildPage.hasBlock(block)).resolves.toBeTruthy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// check that we can save the agent with all the blocks
|
// check that we can save the agent with all the blocks
|
||||||
await buildPage.saveAgent("all blocks test", "all blocks test");
|
await buildPage.saveAgent("all blocks test", "all blocks test");
|
||||||
// page should have a url like http://localhost:3000/build?flowID=f4f3a1da-cfb3-430f-a074-a455b047e340
|
// page should have a url like http://localhost:3000/build?flowID=f4f3a1da-cfb3-430f-a074-a455b047e340
|
||||||
|
|
|
@ -6,8 +6,7 @@ import { v4 as uuidv4 } from "uuid";
|
||||||
import * as fs from "fs/promises";
|
import * as fs from "fs/promises";
|
||||||
import path from "path";
|
import path from "path";
|
||||||
// --8<-- [start:AttachAgentId]
|
// --8<-- [start:AttachAgentId]
|
||||||
|
test.describe("Monitor", () => {
|
||||||
test.describe.skip("Monitor", () => {
|
|
||||||
let buildPage: BuildPage;
|
let buildPage: BuildPage;
|
||||||
let monitorPage: MonitorPage;
|
let monitorPage: MonitorPage;
|
||||||
|
|
||||||
|
@ -54,21 +53,25 @@ test.describe.skip("Monitor", () => {
|
||||||
await test.expect(agents.length).toBeGreaterThan(0);
|
await test.expect(agents.length).toBeGreaterThan(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
test("user can export and import agents", async ({
|
test.skip("user can export and import agents", async ({
|
||||||
page,
|
page,
|
||||||
}, testInfo: TestInfo) => {
|
}, testInfo: TestInfo) => {
|
||||||
// --8<-- [start:ReadAgentId]
|
// --8<-- [start:ReadAgentId]
|
||||||
if (testInfo.attachments.length === 0 || !testInfo.attachments[0].body) {
|
if (testInfo.attachments.length === 0 || !testInfo.attachments[0].body) {
|
||||||
throw new Error("No agent id attached to the test");
|
throw new Error("No agent id attached to the test");
|
||||||
}
|
}
|
||||||
const id = testInfo.attachments[0].body.toString();
|
const testAttachName = testInfo.attachments[0].body.toString();
|
||||||
// --8<-- [end:ReadAgentId]
|
// --8<-- [end:ReadAgentId]
|
||||||
const agents = await monitorPage.listAgents();
|
const agents = await monitorPage.listAgents();
|
||||||
|
|
||||||
const downloadPromise = page.waitForEvent("download");
|
const downloadPromise = page.waitForEvent("download");
|
||||||
await monitorPage.exportToFile(
|
const agent = agents.find(
|
||||||
agents.find((a: any) => a.id === id) || agents[0],
|
(a: any) => a.name === `test-agent-${testAttachName}`,
|
||||||
);
|
);
|
||||||
|
if (!agent) {
|
||||||
|
throw new Error(`Agent ${testAttachName} not found`);
|
||||||
|
}
|
||||||
|
await monitorPage.exportToFile(agent);
|
||||||
const download = await downloadPromise;
|
const download = await downloadPromise;
|
||||||
|
|
||||||
// Wait for the download process to complete and save the downloaded file somewhere.
|
// Wait for the download process to complete and save the downloaded file somewhere.
|
||||||
|
@ -78,9 +81,6 @@ test.describe.skip("Monitor", () => {
|
||||||
console.log(`downloaded file to ${download.suggestedFilename()}`);
|
console.log(`downloaded file to ${download.suggestedFilename()}`);
|
||||||
await test.expect(download.suggestedFilename()).toBeDefined();
|
await test.expect(download.suggestedFilename()).toBeDefined();
|
||||||
// test-agent-uuid-v1.json
|
// test-agent-uuid-v1.json
|
||||||
if (id) {
|
|
||||||
await test.expect(download.suggestedFilename()).toContain(id);
|
|
||||||
}
|
|
||||||
await test.expect(download.suggestedFilename()).toContain("test-agent-");
|
await test.expect(download.suggestedFilename()).toContain("test-agent-");
|
||||||
await test.expect(download.suggestedFilename()).toContain("v1.json");
|
await test.expect(download.suggestedFilename()).toContain("v1.json");
|
||||||
|
|
||||||
|
@ -89,9 +89,9 @@ test.describe.skip("Monitor", () => {
|
||||||
const filesInFolder = await fs.readdir(
|
const filesInFolder = await fs.readdir(
|
||||||
`${monitorPage.downloadsFolder}/monitor`,
|
`${monitorPage.downloadsFolder}/monitor`,
|
||||||
);
|
);
|
||||||
const importFile = filesInFolder.find((f) => f.includes(id));
|
const importFile = filesInFolder.find((f) => f.includes(testAttachName));
|
||||||
if (!importFile) {
|
if (!importFile) {
|
||||||
throw new Error(`No import file found for agent ${id}`);
|
throw new Error(`No import file found for agent ${testAttachName}`);
|
||||||
}
|
}
|
||||||
const baseName = importFile.split(".")[0];
|
const baseName = importFile.split(".")[0];
|
||||||
await monitorPage.importFromFile(
|
await monitorPage.importFromFile(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { ElementHandle, Locator, Page } from "@playwright/test";
|
import { ElementHandle, Locator, Page } from "@playwright/test";
|
||||||
import { BasePage } from "./base.page";
|
import { BasePage } from "./base.page";
|
||||||
|
|
||||||
interface Block {
|
export interface Block {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
description: string;
|
description: string;
|
||||||
|
@ -378,6 +378,39 @@ export class BuildPage extends BasePage {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getAgentExecutorBlockDetails(): Promise<Block> {
|
||||||
|
return {
|
||||||
|
id: "e189baac-8c20-45a1-94a7-55177ea42565",
|
||||||
|
name: "Agent Executor",
|
||||||
|
description: "Executes an existing agent inside your agent",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async getAgentOutputBlockDetails(): Promise<Block> {
|
||||||
|
return {
|
||||||
|
id: "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||||
|
name: "Agent Output",
|
||||||
|
description: "This block is used to output the result of an agent.",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async getAgentInputBlockDetails(): Promise<Block> {
|
||||||
|
return {
|
||||||
|
id: "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||||
|
name: "Agent Input",
|
||||||
|
description: "This block is used to provide input to the graph.",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async getGithubTriggerBlockDetails(): Promise<Block> {
|
||||||
|
return {
|
||||||
|
id: "6c60ec01-8128-419e-988f-96a063ee2fea",
|
||||||
|
name: "Github Trigger",
|
||||||
|
description:
|
||||||
|
"This block triggers on pull request events and outputs the event type and payload.",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
async nextTutorialStep(): Promise<void> {
|
async nextTutorialStep(): Promise<void> {
|
||||||
console.log(`clicking next tutorial step`);
|
console.log(`clicking next tutorial step`);
|
||||||
await this.page.getByRole("button", { name: "Next" }).click();
|
await this.page.getByRole("button", { name: "Next" }).click();
|
||||||
|
@ -448,6 +481,15 @@ export class BuildPage extends BasePage {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getBlocksToSkip(): Promise<string[]> {
|
||||||
|
return [
|
||||||
|
(await this.getAgentExecutorBlockDetails()).id,
|
||||||
|
(await this.getAgentInputBlockDetails()).id,
|
||||||
|
(await this.getAgentOutputBlockDetails()).id,
|
||||||
|
(await this.getGithubTriggerBlockDetails()).id,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
async waitForRunTutorialButton(): Promise<void> {
|
async waitForRunTutorialButton(): Promise<void> {
|
||||||
console.log(`waiting for run tutorial button`);
|
console.log(`waiting for run tutorial button`);
|
||||||
await this.page.waitForSelector('[id="press-run-label"]');
|
await this.page.waitForSelector('[id="press-run-label"]');
|
||||||
|
|
|
@ -43,9 +43,6 @@ export class MonitorPage extends BasePage {
|
||||||
async isLoaded(): Promise<boolean> {
|
async isLoaded(): Promise<boolean> {
|
||||||
console.log(`checking if monitor page is loaded`);
|
console.log(`checking if monitor page is loaded`);
|
||||||
try {
|
try {
|
||||||
// Wait for network to settle first
|
|
||||||
await this.page.waitForLoadState("networkidle", { timeout: 10_000 });
|
|
||||||
|
|
||||||
// Wait for the monitor page
|
// Wait for the monitor page
|
||||||
await this.page.getByTestId("monitor-page").waitFor({
|
await this.page.getByTestId("monitor-page").waitFor({
|
||||||
state: "visible",
|
state: "visible",
|
||||||
|
@ -55,7 +52,7 @@ export class MonitorPage extends BasePage {
|
||||||
// Wait for table headers to be visible (indicates table structure is ready)
|
// Wait for table headers to be visible (indicates table structure is ready)
|
||||||
await this.page.locator("thead th").first().waitFor({
|
await this.page.locator("thead th").first().waitFor({
|
||||||
state: "visible",
|
state: "visible",
|
||||||
timeout: 5_000,
|
timeout: 15_000,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Wait for either a table row or an empty tbody to be present
|
// Wait for either a table row or an empty tbody to be present
|
||||||
|
@ -63,14 +60,14 @@ export class MonitorPage extends BasePage {
|
||||||
// Wait for at least one row
|
// Wait for at least one row
|
||||||
this.page.locator("tbody tr[data-testid]").first().waitFor({
|
this.page.locator("tbody tr[data-testid]").first().waitFor({
|
||||||
state: "visible",
|
state: "visible",
|
||||||
timeout: 5_000,
|
timeout: 15_000,
|
||||||
}),
|
}),
|
||||||
// OR wait for an empty tbody (indicating no agents but table is loaded)
|
// OR wait for an empty tbody (indicating no agents but table is loaded)
|
||||||
this.page
|
this.page
|
||||||
.locator("tbody[data-testid='agent-flow-list-body']:empty")
|
.locator("tbody[data-testid='agent-flow-list-body']:empty")
|
||||||
.waitFor({
|
.waitFor({
|
||||||
state: "visible",
|
state: "visible",
|
||||||
timeout: 5_000,
|
timeout: 15_000,
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
@ -114,6 +111,13 @@ export class MonitorPage extends BasePage {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
agents.reduce((acc, agent) => {
|
||||||
|
if (!agent.id.includes("flow-run")) {
|
||||||
|
acc.push(agent);
|
||||||
|
}
|
||||||
|
return acc;
|
||||||
|
}, [] as Agent[]);
|
||||||
|
|
||||||
return agents;
|
return agents;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +223,7 @@ export class MonitorPage extends BasePage {
|
||||||
async exportToFile(agent: Agent) {
|
async exportToFile(agent: Agent) {
|
||||||
await this.clickAgent(agent.id);
|
await this.clickAgent(agent.id);
|
||||||
|
|
||||||
console.log(`exporting agent ${agent.id} ${agent.name} to file`);
|
console.log(`exporting agent id: ${agent.id} name: ${agent.name} to file`);
|
||||||
await this.page.getByTestId("export-button").click();
|
await this.page.getByTestId("export-button").click();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1693,10 +1693,10 @@
|
||||||
outvariant "^1.4.3"
|
outvariant "^1.4.3"
|
||||||
strict-event-emitter "^0.5.1"
|
strict-event-emitter "^0.5.1"
|
||||||
|
|
||||||
"@next/env@14.2.20":
|
"@next/env@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/env/-/env-14.2.20.tgz#0be2cc955f4eb837516e7d7382284cd5bc1d5a02"
|
resolved "https://registry.yarnpkg.com/@next/env/-/env-14.2.23.tgz#3003b53693cbc476710b856f83e623c8231a6be9"
|
||||||
integrity sha512-JfDpuOCB0UBKlEgEy/H6qcBSzHimn/YWjUHzKl1jMeUO+QVRdzmTTl8gFJaNO87c8DXmVKhFCtwxQ9acqB3+Pw==
|
integrity sha512-CysUC9IO+2Bh0omJ3qrb47S8DtsTKbFidGm6ow4gXIG6reZybqxbkH2nhdEm1tC8SmgzDdpq3BIML0PWsmyUYA==
|
||||||
|
|
||||||
"@next/eslint-plugin-next@15.1.3":
|
"@next/eslint-plugin-next@15.1.3":
|
||||||
version "15.1.3"
|
version "15.1.3"
|
||||||
|
@ -1705,50 +1705,50 @@
|
||||||
dependencies:
|
dependencies:
|
||||||
fast-glob "3.3.1"
|
fast-glob "3.3.1"
|
||||||
|
|
||||||
"@next/swc-darwin-arm64@14.2.20":
|
"@next/swc-darwin-arm64@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.20.tgz#3c99d318c08362aedde5d2778eec3a50b8085d99"
|
resolved "https://registry.yarnpkg.com/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.23.tgz#6d83f03e35e163e8bbeaf5aeaa6bf55eed23d7a1"
|
||||||
integrity sha512-WDfq7bmROa5cIlk6ZNonNdVhKmbCv38XteVFYsxea1vDJt3SnYGgxLGMTXQNfs5OkFvAhmfKKrwe7Y0Hs+rWOg==
|
integrity sha512-WhtEntt6NcbABA8ypEoFd3uzq5iAnrl9AnZt9dXdO+PZLACE32z3a3qA5OoV20JrbJfSJ6Sd6EqGZTrlRnGxQQ==
|
||||||
|
|
||||||
"@next/swc-darwin-x64@14.2.20":
|
"@next/swc-darwin-x64@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.20.tgz#fd547fad1446a677f29c1160006fdd482bba4052"
|
resolved "https://registry.yarnpkg.com/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.23.tgz#e02abc35d5e36ce1550f674f8676999f293ba54f"
|
||||||
integrity sha512-XIQlC+NAmJPfa2hruLvr1H1QJJeqOTDV+v7tl/jIdoFvqhoihvSNykLU/G6NMgoeo+e/H7p/VeWSOvMUHKtTIg==
|
integrity sha512-vwLw0HN2gVclT/ikO6EcE+LcIN+0mddJ53yG4eZd0rXkuEr/RnOaMH8wg/sYl5iz5AYYRo/l6XX7FIo6kwbw1Q==
|
||||||
|
|
||||||
"@next/swc-linux-arm64-gnu@14.2.20":
|
"@next/swc-linux-arm64-gnu@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.20.tgz#1d6ba1929d3a11b74c0185cdeca1e38b824222ca"
|
resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.23.tgz#f13516ad2d665950951b59e7c239574bb8504d63"
|
||||||
integrity sha512-pnzBrHTPXIMm5QX3QC8XeMkpVuoAYOmyfsO4VlPn+0NrHraNuWjdhe+3xLq01xR++iCvX+uoeZmJDKcOxI201Q==
|
integrity sha512-uuAYwD3At2fu5CH1wD7FpP87mnjAv4+DNvLaR9kiIi8DLStWSW304kF09p1EQfhcbUI1Py2vZlBO2VaVqMRtpg==
|
||||||
|
|
||||||
"@next/swc-linux-arm64-musl@14.2.20":
|
"@next/swc-linux-arm64-musl@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.20.tgz#0fe0c67b5d916f99ca76b39416557af609768f17"
|
resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.23.tgz#10d05a1c161dc8426d54ccf6d9bbed6953a3252a"
|
||||||
integrity sha512-WhJJAFpi6yqmUx1momewSdcm/iRXFQS0HU2qlUGlGE/+98eu7JWLD5AAaP/tkK1mudS/rH2f9E3WCEF2iYDydQ==
|
integrity sha512-Mm5KHd7nGgeJ4EETvVgFuqKOyDh+UMXHXxye6wRRFDr4FdVRI6YTxajoV2aHE8jqC14xeAMVZvLqYqS7isHL+g==
|
||||||
|
|
||||||
"@next/swc-linux-x64-gnu@14.2.20":
|
"@next/swc-linux-x64-gnu@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.20.tgz#6d29fa8cdb6a9f8250c2048aaa24538f0cd0b02d"
|
resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.23.tgz#7f5856df080f58ba058268b30429a2ab52500536"
|
||||||
integrity sha512-ao5HCbw9+iG1Kxm8XsGa3X174Ahn17mSYBQlY6VGsdsYDAbz/ZP13wSLfvlYoIDn1Ger6uYA+yt/3Y9KTIupRg==
|
integrity sha512-Ybfqlyzm4sMSEQO6lDksggAIxnvWSG2cDWnG2jgd+MLbHYn2pvFA8DQ4pT2Vjk3Cwrv+HIg7vXJ8lCiLz79qoQ==
|
||||||
|
|
||||||
"@next/swc-linux-x64-musl@14.2.20":
|
"@next/swc-linux-x64-musl@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.20.tgz#bfc57482bc033fda8455e8aab1c3cbc44f0c4690"
|
resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.23.tgz#d494ebdf26421c91be65f9b1d095df0191c956d8"
|
||||||
integrity sha512-CXm/kpnltKTT7945np6Td3w7shj/92TMRPyI/VvveFe8+YE+/YOJ5hyAWK5rpx711XO1jBCgXl211TWaxOtkaA==
|
integrity sha512-OSQX94sxd1gOUz3jhhdocnKsy4/peG8zV1HVaW6DLEbEmRRtUCUQZcKxUD9atLYa3RZA+YJx+WZdOnTkDuNDNA==
|
||||||
|
|
||||||
"@next/swc-win32-arm64-msvc@14.2.20":
|
"@next/swc-win32-arm64-msvc@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.20.tgz#6f7783e643310510240a981776532ffe0e02af95"
|
resolved "https://registry.yarnpkg.com/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.23.tgz#62786e7ba4822a20b6666e3e03e5a389b0e7eb3b"
|
||||||
integrity sha512-upJn2HGQgKNDbXVfIgmqT2BN8f3z/mX8ddoyi1I565FHbfowVK5pnMEwauvLvaJf4iijvuKq3kw/b6E9oIVRWA==
|
integrity sha512-ezmbgZy++XpIMTcTNd0L4k7+cNI4ET5vMv/oqNfTuSXkZtSA9BURElPFyarjjGtRgZ9/zuKDHoMdZwDZIY3ehQ==
|
||||||
|
|
||||||
"@next/swc-win32-ia32-msvc@14.2.20":
|
"@next/swc-win32-ia32-msvc@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.20.tgz#58c7720687e80a13795e22c29d5860fa142e44fc"
|
resolved "https://registry.yarnpkg.com/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.23.tgz#ef028af91e1c40a4ebba0d2c47b23c1eeb299594"
|
||||||
integrity sha512-igQW/JWciTGJwj3G1ipalD2V20Xfx3ywQy17IV0ciOUBbFhNfyU1DILWsTi32c8KmqgIDviUEulW/yPb2FF90w==
|
integrity sha512-zfHZOGguFCqAJ7zldTKg4tJHPJyJCOFhpoJcVxKL9BSUHScVDnMdDuOU1zPPGdOzr/GWxbhYTjyiEgLEpAoFPA==
|
||||||
|
|
||||||
"@next/swc-win32-x64-msvc@14.2.20":
|
"@next/swc-win32-x64-msvc@14.2.23":
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.20.tgz#689bc7beb8005b73c95d926e7edfb7f73efc78f2"
|
resolved "https://registry.yarnpkg.com/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.23.tgz#c81838f02f2f16a321b7533890fb63c1edec68e1"
|
||||||
integrity sha512-AFmqeLW6LtxeFTuoB+MXFeM5fm5052i3MU6xD0WzJDOwku6SkZaxb1bxjBaRC8uNqTRTSPl0yMFtjNowIVI67w==
|
integrity sha512-xCtq5BD553SzOgSZ7UH5LH+OATQihydObTrCTvVzOro8QiWYKdBVwcB2Mn2MLMo6DGW9yH1LSPw7jS7HhgJgjw==
|
||||||
|
|
||||||
"@next/third-parties@^15.1.3":
|
"@next/third-parties@^15.1.3":
|
||||||
version "15.1.3"
|
version "15.1.3"
|
||||||
|
@ -3257,6 +3257,11 @@
|
||||||
resolved "https://registry.yarnpkg.com/@storybook/theming/-/theming-8.4.7.tgz#c308f6a883999bd35e87826738ab8a76515932b5"
|
resolved "https://registry.yarnpkg.com/@storybook/theming/-/theming-8.4.7.tgz#c308f6a883999bd35e87826738ab8a76515932b5"
|
||||||
integrity sha512-99rgLEjf7iwfSEmdqlHkSG3AyLcK0sfExcr0jnc6rLiAkBhzuIsvcHjjUwkR210SOCgXqBPW0ZA6uhnuyppHLw==
|
integrity sha512-99rgLEjf7iwfSEmdqlHkSG3AyLcK0sfExcr0jnc6rLiAkBhzuIsvcHjjUwkR210SOCgXqBPW0ZA6uhnuyppHLw==
|
||||||
|
|
||||||
|
"@stripe/stripe-js@^5.3.0":
|
||||||
|
version "5.4.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@stripe/stripe-js/-/stripe-js-5.4.0.tgz#847e870ddfe9283432526867857a4c1fba9b11ed"
|
||||||
|
integrity sha512-3tfMbSvLGB+OsJ2MsjWjWo+7sp29dwx+3+9kG/TEnZQJt+EwbF/Nomm43cSK+6oXZA9uhspgyrB+BbrPRumx4g==
|
||||||
|
|
||||||
"@supabase/auth-js@2.67.3":
|
"@supabase/auth-js@2.67.3":
|
||||||
version "2.67.3"
|
version "2.67.3"
|
||||||
resolved "https://registry.yarnpkg.com/@supabase/auth-js/-/auth-js-2.67.3.tgz#a1f5eb22440b0cdbf87fe2ecae662a8dd8bb2028"
|
resolved "https://registry.yarnpkg.com/@supabase/auth-js/-/auth-js-2.67.3.tgz#a1f5eb22440b0cdbf87fe2ecae662a8dd8bb2028"
|
||||||
|
@ -8976,12 +8981,12 @@ next-themes@^0.4.4:
|
||||||
resolved "https://registry.yarnpkg.com/next-themes/-/next-themes-0.4.4.tgz#ce6f68a4af543821bbc4755b59c0d3ced55c2d13"
|
resolved "https://registry.yarnpkg.com/next-themes/-/next-themes-0.4.4.tgz#ce6f68a4af543821bbc4755b59c0d3ced55c2d13"
|
||||||
integrity sha512-LDQ2qIOJF0VnuVrrMSMLrWGjRMkq+0mpgl6e0juCLqdJ+oo8Q84JRWT6Wh11VDQKkMMe+dVzDKLWs5n87T+PkQ==
|
integrity sha512-LDQ2qIOJF0VnuVrrMSMLrWGjRMkq+0mpgl6e0juCLqdJ+oo8Q84JRWT6Wh11VDQKkMMe+dVzDKLWs5n87T+PkQ==
|
||||||
|
|
||||||
next@^14.2.13:
|
next@^14.2.21:
|
||||||
version "14.2.20"
|
version "14.2.23"
|
||||||
resolved "https://registry.yarnpkg.com/next/-/next-14.2.20.tgz#99b551d87ca6505ce63074904cb31a35e21dac9b"
|
resolved "https://registry.yarnpkg.com/next/-/next-14.2.23.tgz#37edc9a4d42c135fd97a4092f829e291e2e7c943"
|
||||||
integrity sha512-yPvIiWsiyVYqJlSQxwmzMIReXn5HxFNq4+tlVQ812N1FbvhmE+fDpIAD7bcS2mGYQwPJ5vAsQouyme2eKsxaug==
|
integrity sha512-mjN3fE6u/tynneLiEg56XnthzuYw+kD7mCujgVqioxyPqbmiotUCGJpIZGS/VaPg3ZDT1tvWxiVyRzeqJFm/kw==
|
||||||
dependencies:
|
dependencies:
|
||||||
"@next/env" "14.2.20"
|
"@next/env" "14.2.23"
|
||||||
"@swc/helpers" "0.5.5"
|
"@swc/helpers" "0.5.5"
|
||||||
busboy "1.6.0"
|
busboy "1.6.0"
|
||||||
caniuse-lite "^1.0.30001579"
|
caniuse-lite "^1.0.30001579"
|
||||||
|
@ -8989,15 +8994,15 @@ next@^14.2.13:
|
||||||
postcss "8.4.31"
|
postcss "8.4.31"
|
||||||
styled-jsx "5.1.1"
|
styled-jsx "5.1.1"
|
||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
"@next/swc-darwin-arm64" "14.2.20"
|
"@next/swc-darwin-arm64" "14.2.23"
|
||||||
"@next/swc-darwin-x64" "14.2.20"
|
"@next/swc-darwin-x64" "14.2.23"
|
||||||
"@next/swc-linux-arm64-gnu" "14.2.20"
|
"@next/swc-linux-arm64-gnu" "14.2.23"
|
||||||
"@next/swc-linux-arm64-musl" "14.2.20"
|
"@next/swc-linux-arm64-musl" "14.2.23"
|
||||||
"@next/swc-linux-x64-gnu" "14.2.20"
|
"@next/swc-linux-x64-gnu" "14.2.23"
|
||||||
"@next/swc-linux-x64-musl" "14.2.20"
|
"@next/swc-linux-x64-musl" "14.2.23"
|
||||||
"@next/swc-win32-arm64-msvc" "14.2.20"
|
"@next/swc-win32-arm64-msvc" "14.2.23"
|
||||||
"@next/swc-win32-ia32-msvc" "14.2.20"
|
"@next/swc-win32-ia32-msvc" "14.2.23"
|
||||||
"@next/swc-win32-x64-msvc" "14.2.20"
|
"@next/swc-win32-x64-msvc" "14.2.23"
|
||||||
|
|
||||||
no-case@^3.0.4:
|
no-case@^3.0.4:
|
||||||
version "3.0.4"
|
version "3.0.4"
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 105 KiB After Width: | Height: | Size: 81 KiB |
|
@ -257,13 +257,13 @@ response = requests.post(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
or use the shortcut `credentials.bearer()`:
|
or use the shortcut `credentials.auth_header()`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# credentials: APIKeyCredentials | OAuth2Credentials
|
# credentials: APIKeyCredentials | OAuth2Credentials
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
headers={"Authorization": credentials.bearer()},
|
headers={"Authorization": credentials.auth_header()},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -45,13 +45,7 @@ Now that both Ollama and the AutoGPT platform are running we can move onto using
|
||||||
2. In the "LLM Model" dropdown, select "llama3.2" (This is the model we downloaded earlier)
|
2. In the "LLM Model" dropdown, select "llama3.2" (This is the model we downloaded earlier)
|
||||||

|

|
||||||
|
|
||||||
3. You will see it ask for "Ollama Credentials", simply press "Enter API key"
|
3. Now we need to add some prompts then save and then run the graph:
|
||||||

|
|
||||||
|
|
||||||
And you will see "Add new API key for Ollama", In the API key field you can enter anything you want as Ollama does not require an API key, I usually just enter a space, for the Name call it "Ollama" then press "Save & use this API key"
|
|
||||||

|
|
||||||
|
|
||||||
4. After that you will now see the block again, add your prompts then save and then run the graph:
|
|
||||||

|

|
||||||
|
|
||||||
That's it! You've successfully setup the AutoGPT platform and made a LLM call to Ollama.
|
That's it! You've successfully setup the AutoGPT platform and made a LLM call to Ollama.
|
||||||
|
|
Loading…
Reference in New Issue