feat(rnd): Add strong pydantic type & composite data extraction for Block input/output schema + add reddit agent-blocks (#7288)
* feat(rnd): Add type hint and strong pydantic type validation for block input/output + add reddit agent-blocks. * feat(rnd): Add type hint and strong pydantic type validation for block input/output + add reddit agent-blocks. * Fix reddit block * Fix serialization * Eliminate deprecated class property * Remove RedditCredentialsBlock * Cache jsonschema computation, add dictionary construction * Add dict_split and list_split to output, add more blocks * Add objc_split for completeness, int both input and output * Update reddit block * Add reddit test (untested) * Resolved json issue on pydantic * Add creds check on client * Add dict <--> pydantic object flexibility * Fix error retry * Skip reddit test * Code cleanup * Chang prompt * Make this work * Fix linting * Hide input_links and output_links from Node * Add docs --------- Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com>pull/7319/head
parent
db0e726954
commit
833944e228
|
@ -0,0 +1,9 @@
|
|||
from autogpt_server.blocks import sample, reddit, text, object, ai
|
||||
from autogpt_server.data.block import Block
|
||||
|
||||
AVAILABLE_BLOCKS = {
|
||||
block.id: block
|
||||
for block in [v() for v in Block.__subclasses__()]
|
||||
}
|
||||
|
||||
__all__ = ["ai", "object", "sample", "reddit", "text", "AVAILABLE_BLOCKS"]
|
|
@ -0,0 +1,107 @@
|
|||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
from autogpt_server.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
openai_gpt4 = "gpt-4-turbo"
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
|
||||
model: LlmModel
|
||||
api_key: str
|
||||
|
||||
|
||||
class LlmCallBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
config: LlmConfig
|
||||
expected_format: dict[str, str]
|
||||
sys_prompt: str = ""
|
||||
usr_prompt: str = ""
|
||||
retry: int = 3
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, str]
|
||||
error: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
input_schema=LlmCallBlock.Input,
|
||||
output_schema=LlmCallBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
openai.api_key = input_data.config.api_key
|
||||
expected_format = [f'"{k}": "{v}"' for k, v in
|
||||
input_data.expected_format.items()]
|
||||
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = f"""
|
||||
|{input_data.sys_prompt}
|
||||
|
|
||||
|Reply in json format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
usr_prompt = f"""
|
||||
|{input_data.usr_prompt}
|
||||
"""
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
|
||||
try:
|
||||
parsed = json.loads(resp)
|
||||
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
|
||||
if miss_keys:
|
||||
return parsed, f"Missing keys: {miss_keys}"
|
||||
return parsed, None
|
||||
except Exception as e:
|
||||
return {}, f"JSON decode error: {e}"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": trim_prompt(sys_prompt)},
|
||||
{"role": "user", "content": trim_prompt(usr_prompt)},
|
||||
]
|
||||
|
||||
logger.warning(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
for retry_count in range(input_data.retry):
|
||||
response = openai.chat.completions.create(
|
||||
model=input_data.config.model,
|
||||
messages=prompt, # type: ignore
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
response_text = response.choices[0].message.content or ""
|
||||
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||
return
|
||||
|
||||
retry_prompt = f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{parsed_error}
|
||||
|--
|
||||
"""
|
||||
prompt.append({"role": "user", "content": trim_prompt(retry_prompt)})
|
||||
|
||||
yield "error", retry_prompt
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Any
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
|
||||
class ObjectParser(Block):
|
||||
class Input(BlockSchema):
|
||||
object: Any
|
||||
field_path: str
|
||||
|
||||
class Output(BlockSchema):
|
||||
field_value: Any
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="be45299a-193b-4852-bda4-510883d21814",
|
||||
input_schema=ObjectParser.Input,
|
||||
output_schema=ObjectParser.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
field_path = input_data.field_path.split(".")
|
||||
field_value = input_data.object
|
||||
for field in field_path:
|
||||
if isinstance(field_value, dict) and field in field_value:
|
||||
field_value = field_value.get(field)
|
||||
elif isinstance(field_value, object) and hasattr(field_value, field):
|
||||
field_value = getattr(field_value, field)
|
||||
else:
|
||||
yield "error", input_data.object
|
||||
return
|
||||
|
||||
yield "field_value", field_value
|
|
@ -0,0 +1,107 @@
|
|||
# type: ignore
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import praw
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
|
||||
class RedditCredentials(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str
|
||||
username: str
|
||||
password: str
|
||||
user_agent: str | None = None
|
||||
|
||||
|
||||
class RedditPost(BaseModel):
|
||||
id: str
|
||||
subreddit: str
|
||||
title: str
|
||||
body: str
|
||||
|
||||
|
||||
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||
client = praw.Reddit(
|
||||
client_id=creds.client_id,
|
||||
client_secret=creds.client_secret,
|
||||
user_agent=creds.user_agent,
|
||||
username=creds.username,
|
||||
password=creds.password,
|
||||
)
|
||||
me = client.user.me()
|
||||
if not me:
|
||||
raise ValueError("Invalid Reddit credentials.")
|
||||
print(f"Logged in as Reddit user: {me.name}")
|
||||
return client
|
||||
|
||||
|
||||
class RedditGetPostsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
creds: RedditCredentials = Field(description="Reddit credentials")
|
||||
subreddit: str = Field(description="Subreddit name")
|
||||
last_minutes: int | None = Field(
|
||||
description="Post time to stop minutes ago while fetching posts",
|
||||
default=None
|
||||
)
|
||||
last_post: str | None = Field(
|
||||
description="Post ID to stop when reached while fetching posts",
|
||||
default=None
|
||||
)
|
||||
post_limit: int | None = Field(
|
||||
description="Number of posts to fetch",
|
||||
default=10
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post: RedditPost = Field(description="Reddit post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
||||
input_schema=RedditGetPostsBlock.Input,
|
||||
output_schema=RedditGetPostsBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
client = get_praw(input_data.creds)
|
||||
subreddit = client.subreddit(input_data.subreddit)
|
||||
for post in subreddit.new(limit=input_data.post_limit):
|
||||
if input_data.last_post and post.created_utc < datetime.now() - \
|
||||
timedelta(minutes=input_data.last_minutes):
|
||||
break
|
||||
|
||||
if input_data.last_post and post.id == input_data.last_post:
|
||||
break
|
||||
|
||||
yield "post", RedditPost(
|
||||
id=post.id,
|
||||
subreddit=subreddit.display_name,
|
||||
title=post.title,
|
||||
body=post.selftext
|
||||
)
|
||||
|
||||
|
||||
class RedditPostCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
creds: RedditCredentials = Field(description="Reddit credentials")
|
||||
post_id: str = Field(description="Reddit post ID")
|
||||
comment: str = Field(description="Comment text")
|
||||
|
||||
class Output(BlockSchema):
|
||||
comment_id: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4a92261b-701e-4ffb-8970-675fd28e261f",
|
||||
input_schema=RedditPostCommentBlock.Input,
|
||||
output_schema=RedditPostCommentBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
client = get_praw(input_data.creds)
|
||||
submission = client.submission(id=input_data.post_id)
|
||||
comment = submission.reply(input_data.comment)
|
||||
yield "comment_id", comment.id
|
|
@ -0,0 +1,40 @@
|
|||
# type: ignore
|
||||
|
||||
from autogpt_server.data.block import Block, BlockSchema, BlockOutput
|
||||
|
||||
|
||||
class ParrotBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: str
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1ff065e9-88e8-4358-9d82-8dc91f622ba9",
|
||||
input_schema=ParrotBlock.Input,
|
||||
output_schema=ParrotBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
yield "output", input_data.input
|
||||
|
||||
|
||||
class PrintingBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
input_schema=PrintingBlock.Input,
|
||||
output_schema=PrintingBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
print(">>>>> Print: ", input_data.text)
|
||||
yield "status", "printed"
|
|
@ -0,0 +1,63 @@
|
|||
import re
|
||||
|
||||
from typing import Any
|
||||
from pydantic import Field
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
|
||||
class TextMatcherBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = Field(description="Text to match")
|
||||
match: str = Field(description="Pattern (Regex) to match")
|
||||
data: Any = Field(description="Data to be forwarded to output")
|
||||
case_sensitive: bool = Field(description="Case sensitive match", default=True)
|
||||
|
||||
class Output(BlockSchema):
|
||||
positive: Any = Field(description="Output data if match is found")
|
||||
negative: Any = Field(description="Output data if match is not found")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3060088f-6ed9-4928-9ba7-9c92823a7ccd",
|
||||
input_schema=TextMatcherBlock.Input,
|
||||
output_schema=TextMatcherBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
output = input_data.data or input_data.text
|
||||
case = 0 if input_data.case_sensitive else re.IGNORECASE
|
||||
if re.search(input_data.match, input_data.text, case):
|
||||
yield "positive", output
|
||||
else:
|
||||
yield "negative", output
|
||||
|
||||
|
||||
class TextFormatterBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
texts: list[str] = Field(
|
||||
description="Texts (list) to format",
|
||||
default=[]
|
||||
)
|
||||
named_texts: dict[str, str] = Field(
|
||||
description="Texts (dict) to format",
|
||||
default={}
|
||||
)
|
||||
format: str = Field(
|
||||
description="Template to format the text using `texts` and `named_texts`",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
input_schema=TextFormatterBlock.Input,
|
||||
output_schema=TextFormatterBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
yield "output", input_data.format.format(
|
||||
texts=input_data.texts,
|
||||
**input_data.named_texts,
|
||||
)
|
|
@ -1,140 +1,112 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generator, ClassVar
|
||||
from typing import Any, cast, ClassVar, Generator, Generic, TypeVar, Type
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from prisma.models import AgentBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.util import json
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
"""
|
||||
A schema for the block input and output data.
|
||||
The dictionary structure is an object-typed `jsonschema`.
|
||||
The top-level properties are the block input/output names.
|
||||
|
||||
cached_jsonschema: ClassVar[dict[str, Any]] = {}
|
||||
|
||||
You can initialize this class by providing a dictionary of properties.
|
||||
The key is the string of the property name, and the value is either
|
||||
a string of the type or a dictionary of the jsonschema.
|
||||
@classmethod
|
||||
def jsonschema(cls) -> dict[str, Any]:
|
||||
if cls.cached_jsonschema:
|
||||
return cls.cached_jsonschema
|
||||
|
||||
You can also provide additional keyword arguments for additional properties.
|
||||
Like `name`, `required` (by default all properties are required), etc.
|
||||
model = jsonref.replace_refs(cls.model_json_schema())
|
||||
|
||||
Example:
|
||||
input_schema = BlockSchema({
|
||||
"system_prompt": "string",
|
||||
"user_prompt": "string",
|
||||
"max_tokens": "integer",
|
||||
"user_info": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
}, required=["system_prompt", "user_prompt"])
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items() if not key.startswith("$")
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
return obj
|
||||
|
||||
output_schema = BlockSchema({
|
||||
"on_complete": "string",
|
||||
"on_failures": "string",
|
||||
})
|
||||
"""
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
return cls.cached_jsonschema
|
||||
|
||||
jsonschema: dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
properties: dict[str, str | dict],
|
||||
required: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
key: {"type": value} if isinstance(value, str) else value
|
||||
for key, value in properties.items()
|
||||
},
|
||||
"required": required or list(properties.keys()),
|
||||
**kwargs,
|
||||
}
|
||||
super().__init__(jsonschema=schema)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.jsonschema)
|
||||
|
||||
def validate_data(self, data: BlockData) -> str | None:
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockData) -> str | None:
|
||||
"""
|
||||
Validate the data against the schema.
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(data, self.jsonschema)
|
||||
jsonschema.validate(data, cls.jsonschema())
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
return str(e)
|
||||
|
||||
def validate_field(self, field_name: str, data: BlockData) -> str | None:
|
||||
@classmethod
|
||||
def validate_field(cls, field_name: str, data: BlockData) -> str | None:
|
||||
"""
|
||||
Validate the data against a specific property (one of the input/output name).
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
property_schema = self.jsonschema["properties"].get(field_name)
|
||||
model_schema = cls.jsonschema().get("properties", {})
|
||||
if not model_schema:
|
||||
return f"Invalid model schema {cls}"
|
||||
|
||||
property_schema = model_schema.get(field_name)
|
||||
if not property_schema:
|
||||
return f"Invalid property name {field_name}"
|
||||
|
||||
try:
|
||||
jsonschema.validate(data, property_schema)
|
||||
jsonschema.validate(json.to_dict(data), property_schema)
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
return str(e)
|
||||
|
||||
def get_fields(self) -> set[str]:
|
||||
return set(self.jsonschema["properties"].keys())
|
||||
@classmethod
|
||||
def get_fields(cls) -> set[str]:
|
||||
return set(cls.model_fields.keys())
|
||||
|
||||
def get_required_fields(self) -> set[str]:
|
||||
return set(self.jsonschema["required"])
|
||||
@classmethod
|
||||
def get_required_fields(cls) -> set[str]:
|
||||
return {
|
||||
field
|
||||
for field, field_info in cls.model_fields.items()
|
||||
if field_info.is_required()
|
||||
}
|
||||
|
||||
|
||||
BlockOutput = Generator[tuple[str, Any], None, None]
|
||||
BlockSchemaInputType = TypeVar('BlockSchemaInputType', bound=BlockSchema)
|
||||
BlockSchemaOutputType = TypeVar('BlockSchemaOutputType', bound=BlockSchema)
|
||||
|
||||
|
||||
class Block(ABC, BaseModel):
|
||||
@classmethod
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(cls) -> str:
|
||||
class EmptySchema(BlockSchema):
|
||||
pass
|
||||
|
||||
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
input_schema: Type[BlockSchemaInputType] = EmptySchema,
|
||||
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
|
||||
):
|
||||
"""
|
||||
The unique identifier for the block, this value will be persisted in the DB.
|
||||
So it should be a unique and constant across the application run.
|
||||
Use the UUID format for the ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(cls) -> BlockSchema:
|
||||
"""
|
||||
The schema for the block input data.
|
||||
The top-level properties are the possible input name expected by the block.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(cls) -> BlockSchema:
|
||||
"""
|
||||
The schema for the block output.
|
||||
The top-level properties are the possible output name produced by the block.
|
||||
"""
|
||||
pass
|
||||
self.id = id
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: BlockData) -> BlockOutput:
|
||||
def run(self, input_data: BlockSchemaInputType) -> BlockOutput:
|
||||
"""
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
|
@ -146,17 +118,16 @@ class Block(ABC, BaseModel):
|
|||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"inputSchema": self.input_schema.jsonschema,
|
||||
"outputSchema": self.output_schema.jsonschema,
|
||||
"inputSchema": self.input_schema.jsonschema(),
|
||||
"outputSchema": self.output_schema.jsonschema(),
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockData) -> BlockOutput:
|
||||
|
@ -165,83 +136,20 @@ class Block(ABC, BaseModel):
|
|||
f"Unable to execute block with invalid input data: {error}"
|
||||
)
|
||||
|
||||
for output_name, output_data in self.run(input_data):
|
||||
for output_name, output_data in self.run(self.input_schema(**input_data)):
|
||||
if error := self.output_schema.validate_field(output_name, output_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid output data: {error}"
|
||||
f"Block produced an invalid output data: {error}"
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
|
||||
# ===================== Inline-Block Implementations ===================== #
|
||||
|
||||
|
||||
class ParrotBlock(Block):
|
||||
id: ClassVar[str] = "1ff065e9-88e8-4358-9d82-8dc91f622ba9" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"input": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"output": "string",
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: BlockData) -> BlockOutput:
|
||||
yield "output", input_data["input"]
|
||||
|
||||
|
||||
class TextFormatterBlock(Block):
|
||||
id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"texts": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
},
|
||||
"format": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"combined_text": "string",
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: BlockData) -> BlockOutput:
|
||||
yield "combined_text", input_data["format"].format(texts=input_data["texts"])
|
||||
|
||||
|
||||
class PrintingBlock(Block):
|
||||
id: ClassVar[str] = "f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"text": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"status": "string",
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: BlockData) -> BlockOutput:
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
# ======================= Block Helper Functions ======================= #
|
||||
|
||||
AVAILABLE_BLOCKS: dict[str, Block] = {}
|
||||
from autogpt_server.blocks import AVAILABLE_BLOCKS # noqa: E402
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
global AVAILABLE_BLOCKS
|
||||
|
||||
AVAILABLE_BLOCKS = {block.id: block() for block in Block.__subclasses__()}
|
||||
|
||||
for block in AVAILABLE_BLOCKS.values():
|
||||
if await AgentBlock.prisma().find_unique(where={"id": block.id}):
|
||||
continue
|
||||
|
@ -250,19 +158,15 @@ async def initialize_blocks() -> None:
|
|||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": str(block.input_schema),
|
||||
"outputSchema": str(block.output_schema),
|
||||
"inputSchema": json.dumps(block.input_schema.jsonschema()),
|
||||
"outputSchema": json.dumps(block.output_schema.jsonschema()),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_blocks() -> list[Block]:
|
||||
if not AVAILABLE_BLOCKS:
|
||||
await initialize_blocks()
|
||||
def get_blocks() -> list[Block]:
|
||||
return list(AVAILABLE_BLOCKS.values())
|
||||
|
||||
|
||||
async def get_block(block_id: str) -> Block | None:
|
||||
if not AVAILABLE_BLOCKS:
|
||||
await initialize_blocks()
|
||||
def get_block(block_id: str) -> Block | None:
|
||||
return AVAILABLE_BLOCKS.get(block_id)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
@ -12,6 +11,8 @@ from prisma.models import (
|
|||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
@ -148,7 +149,6 @@ async def upsert_execution_input(
|
|||
json_data = json.dumps(data)
|
||||
|
||||
if existing_execution:
|
||||
print(f"Adding input {input_name}={data} to execution #{existing_execution.id}")
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data={
|
||||
"name": input_name,
|
||||
|
@ -159,7 +159,6 @@ async def upsert_execution_input(
|
|||
return existing_execution.id
|
||||
|
||||
else:
|
||||
print(f"Creating new execution for input {input_name}={data}")
|
||||
result = await AgentNodeExecution.prisma().create(
|
||||
data={
|
||||
"agentNodeId": node_id,
|
||||
|
@ -240,16 +239,46 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
|
|||
return merge_execution_input(exec_input)
|
||||
|
||||
|
||||
SPLIT = "_$_"
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
|
||||
# Allow extracting partial output data by name.
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
list_input = []
|
||||
for key, value in data.items():
|
||||
if SPLIT not in key:
|
||||
if LIST_SPLIT not in key:
|
||||
continue
|
||||
|
||||
name, index = key.split(SPLIT)
|
||||
name, index = key.split(LIST_SPLIT)
|
||||
if not index.isdigit():
|
||||
list_input.append((name, value, 0))
|
||||
else:
|
||||
|
@ -259,4 +288,21 @@ def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
|||
data[name] = data.get(name, [])
|
||||
data[name].append(value)
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in data.items():
|
||||
if DICT_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(DICT_SPLIT)
|
||||
data[name] = data.get(name, {})
|
||||
data[name][index] = value
|
||||
|
||||
# Merge all input with <input_name>_@_<index> into a single object.
|
||||
for key, value in data.items():
|
||||
if OBJC_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(OBJC_SPLIT)
|
||||
if not isinstance(data[name], object):
|
||||
data[name] = type("Object", (object,), data[name])()
|
||||
setattr(data[name], index, value)
|
||||
|
||||
return data
|
||||
|
|
|
@ -1,65 +1,82 @@
|
|||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from typing import Any
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
name: str
|
||||
node_id: str
|
||||
|
||||
def __init__(self, name: str, node_id: str):
|
||||
super().__init__(name=name, node_id=node_id)
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.name, self.node_id))
|
||||
source_id: str
|
||||
sink_id: str
|
||||
source_name: str
|
||||
sink_name: str
|
||||
|
||||
def __init__(self, source_id: str, sink_id: str, source_name: str, sink_name: str):
|
||||
super().__init__(
|
||||
source_id=source_id,
|
||||
sink_id=sink_id,
|
||||
source_name=source_name,
|
||||
sink_name=sink_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(link: AgentNodeLink):
|
||||
return Link(
|
||||
source_name=link.sourceName,
|
||||
source_id=link.agentNodeSourceId,
|
||||
sink_name=link.sinkName,
|
||||
sink_id=link.agentNodeSinkId,
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.source_id, self.sink_id, self.source_name, self.sink_name))
|
||||
|
||||
|
||||
class Node(BaseDbModel):
|
||||
block_id: str
|
||||
input_default: dict[str, Any] = {} # dict[input_name, default_value]
|
||||
input_nodes: list[Link] = [] # dict[input_name, node_id]
|
||||
output_nodes: list[Link] = [] # dict[output_name, node_id]
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
_input_links: list[Link] = PrivateAttr(default=[])
|
||||
_output_links: list[Link] = PrivateAttr(default=[])
|
||||
|
||||
@property
|
||||
def input_links(self) -> list[Link]:
|
||||
return self._input_links
|
||||
|
||||
@property
|
||||
def output_links(self) -> list[Link]:
|
||||
return self._output_links
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode):
|
||||
if not node.AgentBlock:
|
||||
raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
|
||||
|
||||
return Node(
|
||||
obj = Node(
|
||||
id=node.id,
|
||||
block_id=node.AgentBlock.id,
|
||||
input_default=json.loads(node.constantInput),
|
||||
input_nodes=[
|
||||
Link(v.sinkName, v.agentNodeSourceId)
|
||||
for v in node.Input or []
|
||||
],
|
||||
output_nodes=[
|
||||
Link(v.sourceName, v.agentNodeSinkId)
|
||||
for v in node.Output or []
|
||||
],
|
||||
metadata=json.loads(node.metadata),
|
||||
metadata=json.loads(node.metadata)
|
||||
)
|
||||
|
||||
def connect(self, node: "Node", source_name: str, sink_name: str):
|
||||
self.output_nodes.append(Link(source_name, node.id))
|
||||
node.input_nodes.append(Link(sink_name, self.id))
|
||||
obj._input_links = [Link.from_db(link) for link in node.Input or []]
|
||||
obj._output_links = [Link.from_db(link) for link in node.Output or []]
|
||||
return obj
|
||||
|
||||
|
||||
class Graph(BaseDbModel):
|
||||
name: str
|
||||
description: str
|
||||
nodes: list[Node]
|
||||
links: list[Link]
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
return [node for node in self.nodes if not node.input_nodes]
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
return [node for node in self.nodes if node.id not in outbound_nodes]
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
|
@ -68,6 +85,11 @@ class Graph(BaseDbModel):
|
|||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
nodes=[Node.from_db(node) for node in graph.AgentNodes or []],
|
||||
links=list({
|
||||
Link.from_db(link)
|
||||
for node in graph.AgentNodes or []
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
})
|
||||
)
|
||||
|
||||
|
||||
|
@ -121,27 +143,15 @@ async def create_graph(graph: Graph) -> Graph:
|
|||
}) for node in graph.nodes
|
||||
])
|
||||
|
||||
edge_source_names = {
|
||||
(source_node.id, sink_node_id): output_name
|
||||
for source_node in graph.nodes
|
||||
for output_name, sink_node_id in source_node.output_nodes
|
||||
}
|
||||
edge_sink_names = {
|
||||
(source_node_id, sink_node.id): input_name
|
||||
for sink_node in graph.nodes
|
||||
for input_name, source_node_id in sink_node.input_nodes
|
||||
}
|
||||
|
||||
# TODO: replace bulk creation using create_many
|
||||
await asyncio.gather(*[
|
||||
AgentNodeLink.prisma().create({
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": edge_source_names.get((input_node, output_node), ""),
|
||||
"sinkName": edge_sink_names.get((input_node, output_node), ""),
|
||||
"agentNodeSourceId": input_node,
|
||||
"agentNodeSinkId": output_node,
|
||||
"sourceName": link.source_name,
|
||||
"sinkName": link.sink_name,
|
||||
"agentNodeSourceId": link.source_id,
|
||||
"agentNodeSinkId": link.sink_id,
|
||||
})
|
||||
for input_node, output_node in edge_source_names.keys() | edge_sink_names.keys()
|
||||
for link in graph.links
|
||||
])
|
||||
|
||||
if created_graph := await get_graph(graph.id):
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
|
||||
from prisma.models import AgentGraphExecutionSchedule
|
||||
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
class ExecutionSchedule(BaseDbModel):
|
||||
|
|
|
@ -9,6 +9,7 @@ from autogpt_server.data.execution import (
|
|||
create_graph_execution,
|
||||
get_node_execution_input,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
update_execution_status as execution_update,
|
||||
upsert_execution_output,
|
||||
upsert_execution_input,
|
||||
|
@ -16,7 +17,7 @@ from autogpt_server.data.execution import (
|
|||
ExecutionStatus,
|
||||
ExecutionQueue,
|
||||
)
|
||||
from autogpt_server.data.graph import Node, get_node, get_graph
|
||||
from autogpt_server.data.graph import Link, Node, get_node, get_graph
|
||||
from autogpt_server.util.service import AppService, expose
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -57,7 +58,7 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
|
|||
logger.error(f"Node {node_id} not found.")
|
||||
return
|
||||
|
||||
node_block = wait(get_block(node.block_id))
|
||||
node_block = get_block(node.block_id)
|
||||
if not node_block:
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return
|
||||
|
@ -74,7 +75,11 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
|
|||
wait(upsert_execution_output(node_exec_id, output_name, output_data))
|
||||
|
||||
for execution in enqueue_next_nodes(
|
||||
loop, node, output_name, output_data, graph_exec_id
|
||||
loop=loop,
|
||||
node=node,
|
||||
output=(output_name, output_data),
|
||||
graph_exec_id=graph_exec_id,
|
||||
prefix=prefix,
|
||||
):
|
||||
yield execution
|
||||
except Exception as e:
|
||||
|
@ -88,45 +93,43 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
|
|||
def enqueue_next_nodes(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
node: Node,
|
||||
output_name: str,
|
||||
output_data: Any,
|
||||
output: tuple[str, Any],
|
||||
graph_exec_id: str,
|
||||
prefix: str,
|
||||
) -> list[Execution]:
|
||||
def wait(f: Coroutine[T, Any, T]) -> T:
|
||||
return loop.run_until_complete(f)
|
||||
|
||||
prefix = get_log_prefix(graph_exec_id, node.id)
|
||||
node_id = node.id
|
||||
def get_next_node_execution(node_link: Link) -> Execution | None:
|
||||
next_output_name = node_link.source_name
|
||||
next_input_name = node_link.sink_name
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
# Try to enqueue next eligible nodes
|
||||
next_node_ids = [nid for name, nid in node.output_nodes if name == output_name]
|
||||
if not next_node_ids:
|
||||
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
|
||||
return []
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None:
|
||||
return
|
||||
|
||||
def validate_node_execution(next_node_id: str):
|
||||
next_node = wait(get_node(next_node_id))
|
||||
if not next_node:
|
||||
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
|
||||
return
|
||||
|
||||
next_node_input_name = next(
|
||||
name for name, nid in next_node.input_nodes if nid == node_id
|
||||
)
|
||||
next_node_exec_id = wait(upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_node_input_name,
|
||||
data=output_data
|
||||
input_name=next_input_name,
|
||||
data=next_data
|
||||
))
|
||||
|
||||
next_node_input = wait(get_node_execution_input(next_node_exec_id))
|
||||
is_valid, validation_resp = wait(validate_exec(next_node, next_node_input))
|
||||
|
||||
is_valid, validation_msg = validate_exec(next_node, next_node_input)
|
||||
suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}"
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"{prefix} Skipped {next_node_id}: {validation_resp}")
|
||||
logger.warning(f"{prefix} Skipped queueing {suffix}")
|
||||
return
|
||||
|
||||
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{validation_resp}")
|
||||
logger.warning(f"{prefix} Enqueued {suffix}")
|
||||
return Execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_exec_id=next_node_exec_id,
|
||||
|
@ -134,14 +137,11 @@ def enqueue_next_nodes(
|
|||
data=next_node_input
|
||||
)
|
||||
|
||||
executions = []
|
||||
for nid in next_node_ids:
|
||||
if execution := validate_node_execution(nid):
|
||||
executions.append(execution)
|
||||
return executions
|
||||
executions = [get_next_node_execution(link) for link in node.output_links]
|
||||
return [v for v in executions if v]
|
||||
|
||||
|
||||
async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
|
||||
def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
|
@ -153,21 +153,24 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
|
|||
A tuple of a boolean indicating if the data is valid, and a message if not.
|
||||
Return the executed block name if the data is valid.
|
||||
"""
|
||||
node_block: Block | None = await get_block(node.block_id)
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return False, f"Block for {node.block_id} not found."
|
||||
|
||||
error_message = f"Input data missing for {node_block.name}:"
|
||||
|
||||
input_fields_from_schema = node_block.input_schema.get_required_fields()
|
||||
if not input_fields_from_schema.issubset(data):
|
||||
return False, f"Input data missing: {input_fields_from_schema - set(data)}"
|
||||
return False, f"{error_message} {input_fields_from_schema - set(data)}"
|
||||
|
||||
input_fields_from_nodes = {name for name, _ in node.input_nodes}
|
||||
input_fields_from_nodes = {link.sink_name for link in node.input_links}
|
||||
if not input_fields_from_nodes.issubset(data):
|
||||
return False, f"Input data missing: {input_fields_from_nodes - set(data)}"
|
||||
return False, f"{error_message} {input_fields_from_nodes - set(data)}"
|
||||
|
||||
if error := node_block.input_schema.validate_data(data):
|
||||
logger.error("Input value doesn't match schema: %s", error)
|
||||
return False, f"Input data doesn't match {node_block.name}: {error}"
|
||||
error_message = f"Input data doesn't match {node_block.name}: {error}"
|
||||
logger.error(error_message)
|
||||
return False, error_message
|
||||
|
||||
return True, node_block.name
|
||||
|
||||
|
@ -221,7 +224,7 @@ class ExecutionManager(AppService):
|
|||
# Currently, there is no constraint on the number of root nodes in the graph.
|
||||
for node in graph.starting_nodes:
|
||||
input_data = merge_execution_input({**node.input_default, **data})
|
||||
valid, error = self.run_and_wait(validate_exec(node, input_data))
|
||||
valid, error = validate_exec(node, input_data)
|
||||
if not valid:
|
||||
raise Exception(error)
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ class AgentServer(AppProcess):
|
|||
@asynccontextmanager
|
||||
async def lifespan(self, _: FastAPI):
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
yield
|
||||
await db.disconnect()
|
||||
|
||||
|
@ -112,8 +113,8 @@ class AgentServer(AppProcess):
|
|||
def execution_scheduler_client(self) -> ExecutionScheduler:
|
||||
return get_service_client(ExecutionScheduler)
|
||||
|
||||
async def get_graph_blocks(self) -> list[dict]:
|
||||
return [v.to_dict() for v in await block.get_blocks()]
|
||||
def get_graph_blocks(self) -> list[dict]:
|
||||
return [v.to_dict() for v in block.get_blocks()]
|
||||
|
||||
async def get_graphs(self) -> list[str]:
|
||||
return await get_graph_ids()
|
||||
|
@ -128,10 +129,13 @@ class AgentServer(AppProcess):
|
|||
# TODO: replace uuid generation here to DB generated uuids.
|
||||
graph.id = str(uuid.uuid4())
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
node.input_nodes = [Link(k, id_map[v]) for k, v in node.input_nodes]
|
||||
node.output_nodes = [Link(k, id_map[v]) for k, v in node.output_nodes]
|
||||
|
||||
for link in graph.links:
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
return await create_graph(graph)
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import json
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(data) -> str:
|
||||
return json.dumps(jsonable_encoder(data))
|
||||
|
||||
|
||||
def loads(data) -> dict:
|
||||
return json.loads(data)
|
|
@ -22,7 +22,6 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
|
|||
self._updated_fields.add(name)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
def mark_updated(self, field_name: str) -> None:
|
||||
if field_name in self.model_fields:
|
||||
self._updated_fields.add(field_name)
|
||||
|
@ -33,6 +32,10 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
|
|||
def get_updates(self) -> Dict[str, Any]:
|
||||
return {field: getattr(self, field) for field in self._updated_fields}
|
||||
|
||||
@property
|
||||
def updated_fields(self):
|
||||
return self._updated_fields
|
||||
|
||||
|
||||
class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
"""Config for the server."""
|
||||
|
@ -54,12 +57,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (JsonConfigSettingsSource(settings_cls),)
|
||||
|
||||
|
@ -101,7 +104,7 @@ class Settings(BaseModel):
|
|||
|
||||
# Save updated secrets to individual files
|
||||
secrets_dir = get_secrets_path()
|
||||
for key in self.secrets._updated_fields:
|
||||
for key in self.secrets.updated_fields:
|
||||
secret_file = os.path.join(secrets_dir, key)
|
||||
with open(secret_file, "w") as f:
|
||||
f.write(str(getattr(self.secrets, key)))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
|
@ -91,6 +91,105 @@ files = [
|
|||
{file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.3.2"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
|
||||
{file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
|
||||
{file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
|
||||
{file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"},
|
||||
{file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"},
|
||||
{file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
|
||||
{file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
|
||||
{file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click"
|
||||
version = "8.1.7"
|
||||
|
@ -190,6 +289,17 @@ files = [
|
|||
{file = "cx_Logging-3.2.0.tar.gz", hash = "sha256:bdbad6d2e6a0cc5bef962a34d7aa1232e88ea9f3541d6e2881675b5e7eab5502"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
description = "Distro - an OS platform information API"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
|
||||
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dmgbuild"
|
||||
version = "1.6.1"
|
||||
|
@ -438,6 +548,17 @@ MarkupSafe = ">=2.0"
|
|||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "jsonref"
|
||||
version = "1.1.0"
|
||||
description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"},
|
||||
{file = "jsonref-1.1.0.tar.gz", hash = "sha256:32fe8e1d85af0fdefbebce950af85590b22b60f9e95443176adbde4e1ecea552"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonschema"
|
||||
version = "4.22.0"
|
||||
|
@ -619,6 +740,29 @@ files = [
|
|||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.35.7"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.35.7-py3-none-any.whl", hash = "sha256:3d1e0b0aac9b0db69a972d36dc7efa7563f8e8d65550b27a48f2a0c2ec207e80"},
|
||||
{file = "openai-1.35.7.tar.gz", hash = "sha256:009bfa1504c9c7ef64d87be55936d142325656bbc6d98c68b669d6472e4beb09"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.5.0,<5"
|
||||
distro = ">=1.7.0,<2"
|
||||
httpx = ">=0.23.0,<1"
|
||||
pydantic = ">=1.9.0,<3"
|
||||
sniffio = "*"
|
||||
tqdm = ">4"
|
||||
typing-extensions = ">=4.7,<5"
|
||||
|
||||
[package.extras]
|
||||
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.1"
|
||||
|
@ -693,6 +837,49 @@ tomli = ">=1.2.2"
|
|||
[package.extras]
|
||||
poetry-plugin = ["poetry (>=1.0,<2.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "praw"
|
||||
version = "7.7.1"
|
||||
description = "PRAW, an acronym for \"Python Reddit API Wrapper\", is a Python package that allows for simple access to Reddit's API."
|
||||
optional = false
|
||||
python-versions = "~=3.7"
|
||||
files = [
|
||||
{file = "praw-7.7.1-py3-none-any.whl", hash = "sha256:9ec5dc943db00c175bc6a53f4e089ce625f3fdfb27305564b616747b767d38ef"},
|
||||
{file = "praw-7.7.1.tar.gz", hash = "sha256:f1d7eef414cafe28080dda12ed09253a095a69933d5c8132eca11d4dc8a070bf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
prawcore = ">=2.1,<3"
|
||||
update-checker = ">=0.18"
|
||||
websocket-client = ">=0.54.0"
|
||||
|
||||
[package.extras]
|
||||
ci = ["coveralls"]
|
||||
dev = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "furo", "packaging", "pre-commit", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "sphinx", "urllib3 (==1.26.*)"]
|
||||
lint = ["furo", "pre-commit", "sphinx"]
|
||||
readthedocs = ["furo", "sphinx"]
|
||||
test = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "urllib3 (==1.26.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "prawcore"
|
||||
version = "2.4.0"
|
||||
description = "\"Low-level communication layer for PRAW 4+."
|
||||
optional = false
|
||||
python-versions = "~=3.8"
|
||||
files = [
|
||||
{file = "prawcore-2.4.0-py3-none-any.whl", hash = "sha256:29af5da58d85704b439ad3c820873ad541f4535e00bb98c66f0fbcc8c603065a"},
|
||||
{file = "prawcore-2.4.0.tar.gz", hash = "sha256:b7b2b5a1d04406e086ab4e79988dc794df16059862f329f4c6a43ed09986c335"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = ">=2.6.0,<3.0"
|
||||
|
||||
[package.extras]
|
||||
ci = ["coveralls"]
|
||||
dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
|
||||
lint = ["pre-commit", "ruff (>=0.0.291)"]
|
||||
test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "prisma"
|
||||
version = "0.13.1"
|
||||
|
@ -1081,6 +1268,27 @@ files = [
|
|||
attrs = ">=22.2.0"
|
||||
rpds-py = ">=0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=2017.4.17"
|
||||
charset-normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
[package.extras]
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
version = "0.18.1"
|
||||
|
@ -1334,6 +1542,26 @@ files = [
|
|||
{file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.4"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
|
||||
{file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
|
||||
notebook = ["ipywidgets (>=6)"]
|
||||
slack = ["slack-sdk"]
|
||||
telegram = ["requests"]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.12.2"
|
||||
|
@ -1373,6 +1601,42 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""}
|
|||
[package.extras]
|
||||
devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
|
||||
|
||||
[[package]]
|
||||
name = "update-checker"
|
||||
version = "0.18.0"
|
||||
description = "A python module that will check for package updates."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "update_checker-0.18.0-py3-none-any.whl", hash = "sha256:cbba64760a36fe2640d80d85306e8fe82b6816659190993b7bdabadee4d4bbfd"},
|
||||
{file = "update_checker-0.18.0.tar.gz", hash = "sha256:6a2d45bb4ac585884a6b03f9eade9161cedd9e8111545141e9aa9058932acb13"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = ">=2.3.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "pytest (>=2.7.3)"]
|
||||
lint = ["black", "flake8"]
|
||||
test = ["pytest (>=2.7.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.2.2"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"},
|
||||
{file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
|
||||
h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "uvicorn"
|
||||
version = "0.30.1"
|
||||
|
@ -1574,6 +1838,22 @@ files = [
|
|||
[package.dependencies]
|
||||
anyio = ">=3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "websocket-client"
|
||||
version = "1.8.0"
|
||||
description = "WebSocket client for Python with low level API options"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"},
|
||||
{file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=6.0)", "myst-parser (>=2.0.0)", "sphinx-rtd-theme (>=1.1.0)"]
|
||||
optional = ["python-socks", "wsaccel"]
|
||||
test = ["websockets"]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "12.0"
|
||||
|
@ -1672,4 +1952,4 @@ test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "a243c28b48b60e14513fc18629096a7f9d1c60ae7b05a6c50125c1d4c045033e"
|
||||
content-hash = "ef4326d7688ca5c6c7d3c5189d1058183bd6916f2ba536ad42ad306c9d2ab595"
|
||||
|
|
|
@ -26,6 +26,9 @@ apscheduler = "^3.10.4"
|
|||
croniter = "^2.0.5"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
pydantic-settings = "^2.3.4"
|
||||
praw = "^7.7.1"
|
||||
openai = "^1.35.7"
|
||||
jsonref = "^1.1.0"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_server.data import block, db, execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock
|
||||
|
||||
|
||||
async def create_test_graph() -> graph.Graph:
|
||||
|
@ -17,27 +18,28 @@ async def create_test_graph() -> graph.Graph:
|
|||
ParrotBlock
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(block_id=ParrotBlock().id),
|
||||
graph.Node(block_id=ParrotBlock().id),
|
||||
graph.Node(
|
||||
block_id=block.TextFormatterBlock.id,
|
||||
block_id=TextFormatterBlock().id,
|
||||
input_default={
|
||||
"format": "{texts[0]},{texts[1]},{texts[2]}",
|
||||
"texts_$_3": "!!!",
|
||||
},
|
||||
),
|
||||
graph.Node(block_id=block.PrintingBlock.id),
|
||||
graph.Node(block_id=PrintingBlock().id),
|
||||
]
|
||||
links = [
|
||||
graph.Link(nodes[0].id, nodes[2].id, "output", "texts_$_1"),
|
||||
graph.Link(nodes[1].id, nodes[2].id, "output", "texts_$_2"),
|
||||
graph.Link(nodes[2].id, nodes[3].id, "output", "text"),
|
||||
]
|
||||
nodes[0].connect(nodes[2], "output", "texts_$_1")
|
||||
nodes[1].connect(nodes[2], "output", "texts_$_2")
|
||||
nodes[2].connect(nodes[3], "combined_text", "text")
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
await block.initialize_blocks()
|
||||
result = await graph.create_graph(test_graph)
|
||||
|
||||
# Assertions
|
||||
|
@ -48,7 +50,7 @@ async def create_test_graph() -> graph.Graph:
|
|||
return test_graph
|
||||
|
||||
|
||||
async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph):
|
||||
async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str:
|
||||
# --- Test adding new executions --- #
|
||||
text = "Hello, World!"
|
||||
input_data = {"input": text}
|
||||
|
@ -70,6 +72,12 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
|
|||
|
||||
# Execution queue should be empty
|
||||
assert await is_execution_completed()
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
|
||||
text = "Hello, World!"
|
||||
agent_server = AgentServer()
|
||||
executions = await agent_server.get_executions(test_graph.id, graph_exec_id)
|
||||
|
||||
# Executing ParrotBlock1
|
||||
|
@ -92,7 +100,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
|
|||
exec = executions[2]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.output_data == {"output": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.input_data == {
|
||||
"texts_$_1": "Hello, World!",
|
||||
"texts_$_2": "Hello, World!",
|
||||
|
@ -113,5 +121,7 @@ async def test_agent_execution():
|
|||
with PyroNameServer():
|
||||
with ExecutionManager(1) as test_manager:
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
test_graph = await create_test_graph()
|
||||
await execute_graph(test_manager, test_graph)
|
||||
graph_exec_id = await execute_graph(test_manager, test_graph)
|
||||
await assert_executions(test_graph, graph_exec_id)
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
import time
|
||||
from autogpt_server.data import block, db
|
||||
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
||||
from autogpt_server.data.execution import ExecutionStatus
|
||||
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
|
||||
from autogpt_server.blocks.reddit import (
|
||||
RedditCredentials,
|
||||
RedditGetPostsBlock,
|
||||
RedditPostCommentBlock,
|
||||
)
|
||||
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
|
||||
async def create_test_graph() -> Graph:
|
||||
# /--- post_id -----------\ /--- post_id ---\
|
||||
# subreddit --> RedditGetPostsBlock ---- post_body -------- TextFormatterBlock ----- LlmCallBlock / TextRelevancy --- relevant/not -- TextMatcherBlock -- Yes {postid, text} --- RedditPostCommentBlock
|
||||
# \--- post_title -------/ \--- marketing_text ---/ -- No
|
||||
|
||||
# Creds
|
||||
reddit_creds = RedditCredentials(
|
||||
client_id="TODO_FILL_OUT_THIS",
|
||||
client_secret="TODO_FILL_OUT_THIS",
|
||||
username="TODO_FILL_OUT_THIS",
|
||||
password="TODO_FILL_OUT_THIS",
|
||||
user_agent="TODO_FILL_OUT_THIS",
|
||||
)
|
||||
openai_creds = LlmConfig(
|
||||
model=LlmModel.openai_gpt4,
|
||||
api_key="TODO_FILL_OUT_THIS",
|
||||
)
|
||||
|
||||
# Hardcoded inputs
|
||||
reddit_get_post_input = {
|
||||
"creds": reddit_creds,
|
||||
"last_minutes": 60,
|
||||
"post_limit": 3,
|
||||
}
|
||||
text_formatter_input = {
|
||||
"format": """
|
||||
Based on the following post, write your marketing comment:
|
||||
* Post ID: {id}
|
||||
* Post Subreddit: {subreddit}
|
||||
* Post Title: {title}
|
||||
* Post Body: {body}""".strip(),
|
||||
}
|
||||
llm_call_input = {
|
||||
"sys_prompt": """
|
||||
You are an expert at marketing, and have been tasked with picking Reddit posts that are relevant to your product.
|
||||
The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT model.
|
||||
You reply the post that you find it relevant to be replied with marketing text.
|
||||
Make sure to only comment on a relevant post.
|
||||
""",
|
||||
"config": openai_creds,
|
||||
"expected_format": {
|
||||
"post_id": "str, the reddit post id",
|
||||
"is_relevant": "bool, whether the post is relevant for marketing",
|
||||
"marketing_text": "str, marketing text, this is empty on irrelevant posts",
|
||||
},
|
||||
}
|
||||
text_matcher_input = {"match": "true", "case_sensitive": False}
|
||||
reddit_comment_input = {"creds": reddit_creds}
|
||||
|
||||
# Nodes
|
||||
reddit_get_post_node = Node(
|
||||
block_id=RedditGetPostsBlock().id,
|
||||
input_default=reddit_get_post_input,
|
||||
)
|
||||
text_formatter_node = Node(
|
||||
block_id=TextFormatterBlock().id,
|
||||
input_default=text_formatter_input,
|
||||
)
|
||||
llm_call_node = Node(
|
||||
block_id=LlmCallBlock().id,
|
||||
input_default=llm_call_input
|
||||
)
|
||||
text_matcher_node = Node(
|
||||
block_id=TextMatcherBlock().id,
|
||||
input_default=text_matcher_input,
|
||||
)
|
||||
reddit_comment_node = Node(
|
||||
block_id=RedditPostCommentBlock().id,
|
||||
input_default=reddit_comment_input,
|
||||
)
|
||||
|
||||
nodes = [
|
||||
reddit_get_post_node,
|
||||
text_formatter_node,
|
||||
llm_call_node,
|
||||
text_matcher_node,
|
||||
reddit_comment_node,
|
||||
]
|
||||
|
||||
# Links
|
||||
links = [
|
||||
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
||||
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
||||
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",
|
||||
"post_id"),
|
||||
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_marketing_text",
|
||||
"comment"),
|
||||
]
|
||||
|
||||
# Create graph
|
||||
test_graph = Graph(
|
||||
name="RedditMarketingAgent",
|
||||
description="Reddit marketing agent",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
return await create_graph(test_graph)
|
||||
|
||||
|
||||
async def wait_execution(test_manager, graph_id, graph_exec_id) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_executions(graph_id, graph_exec_id)
|
||||
"""
|
||||
List of execution:
|
||||
reddit_get_post_node 1 (produced 3 posts)
|
||||
text_formatter_node 3
|
||||
llm_call_node 3 (assume 3 of them relevant)
|
||||
text_matcher_node 3
|
||||
reddit_comment_node 3
|
||||
Total: 13
|
||||
"""
|
||||
print("--------> Execution count: ", len(execs), [str(v.status) for v in execs])
|
||||
return test_manager.queue.empty() and len(execs) == 13 and all(
|
||||
v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
|
||||
for v in execs
|
||||
)
|
||||
|
||||
# Wait for the executions to complete
|
||||
for i in range(120):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_executions(graph_id, graph_exec_id)
|
||||
time.sleep(1)
|
||||
|
||||
assert False, "Execution did not complete in time."
|
||||
|
||||
|
||||
async def reddit_marketing_agent():
|
||||
with PyroNameServer():
|
||||
with ExecutionManager(1) as test_manager:
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
test_graph = await create_test_graph()
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await AgentServer().execute_graph(test_graph.id, input_data)
|
||||
print(response)
|
||||
result = await wait_execution(test_manager, test_graph.id, response["id"])
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(reddit_marketing_agent())
|
Loading…
Reference in New Issue