feat(rnd): Add strong pydantic type & composite data extraction for Block input/output schema + add reddit agent-blocks (#7288)

* feat(rnd): Add type hint and strong pydantic type validation for block input/output + add reddit agent-blocks.

* feat(rnd): Add type hint and strong pydantic type validation for block input/output + add reddit agent-blocks.

* Fix reddit block

* Fix serialization

* Eliminate deprecated class property

* Remove RedditCredentialsBlock

* Cache jsonschema computation, add dictionary construction

* Add dict_split and list_split to output, add more blocks

* Add objc_split for completeness, in both input and output

* Update reddit block

* Add reddit test (untested)

* Resolved json issue on pydantic

* Add creds check on client

* Add dict <--> pydantic object flexibility

* Fix error retry

* Skip reddit test

* Code cleanup

* Change prompt

* Make this work

* Fix linting

* Hide input_links and output_links from Node

* Add docs

---------

Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com>
pull/7319/head
Zamil Majdy 2024-07-04 14:37:28 +04:00 committed by GitHub
parent db0e726954
commit 833944e228
21 changed files with 1080 additions and 283 deletions

View File

@ -0,0 +1,9 @@
from autogpt_server.blocks import sample, reddit, text, object, ai
from autogpt_server.data.block import Block
AVAILABLE_BLOCKS = {
block.id: block
for block in [v() for v in Block.__subclasses__()]
}
__all__ = ["ai", "object", "sample", "reddit", "text", "AVAILABLE_BLOCKS"]
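
Not part of the diff: a minimal usage sketch of this registry, assuming the package from this PR (with its generated Prisma client) is importable. AVAILABLE_BLOCKS maps each block's persistent UUID to a ready-to-use instance.

from autogpt_server.blocks import AVAILABLE_BLOCKS
from autogpt_server.blocks.sample import ParrotBlock

# Look up a block by its id and run it; execute() yields (output_name, output_data).
parrot = AVAILABLE_BLOCKS[ParrotBlock().id]
for output_name, output_data in parrot.execute({"input": "hello"}):
    print(output_name, output_data)  # -> output hello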

View File

@ -0,0 +1,107 @@
import logging
from enum import Enum
import openai
from pydantic import BaseModel
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
from autogpt_server.util import json
logger = logging.getLogger(__name__)
class LlmModel(str, Enum):
openai_gpt4 = "gpt-4-turbo"
class LlmConfig(BaseModel):
model: LlmModel
api_key: str
class LlmCallBlock(Block):
class Input(BlockSchema):
config: LlmConfig
expected_format: dict[str, str]
sys_prompt: str = ""
usr_prompt: str = ""
retry: int = 3
class Output(BlockSchema):
response: dict[str, str]
error: str
def __init__(self):
super().__init__(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
input_schema=LlmCallBlock.Input,
output_schema=LlmCallBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
openai.api_key = input_data.config.api_key
expected_format = [f'"{k}": "{v}"' for k, v in
input_data.expected_format.items()]
format_prompt = ",\n ".join(expected_format)
sys_prompt = f"""
|{input_data.sys_prompt}
|
|Reply in json format:
|{{
| {format_prompt}
|}}
"""
usr_prompt = f"""
|{input_data.usr_prompt}
"""
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines])
def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
try:
parsed = json.loads(resp)
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
if miss_keys:
return parsed, f"Missing keys: {miss_keys}"
return parsed, None
except Exception as e:
return {}, f"JSON decode error: {e}"
prompt = [
{"role": "system", "content": trim_prompt(sys_prompt)},
{"role": "user", "content": trim_prompt(usr_prompt)},
]
logger.warning(f"LLM request: {prompt}")
retry_prompt = ""
for retry_count in range(input_data.retry):
response = openai.chat.completions.create(
model=input_data.config.model,
messages=prompt, # type: ignore
response_format={"type": "json_object"},
)
response_text = response.choices[0].message.content or ""
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
parsed_dict, parsed_error = parse_response(response_text)
if not parsed_error:
yield "response", {k: str(v) for k, v in parsed_dict.items()}
return
retry_prompt = f"""
|This is your previous error response:
|--
|{response_text}
|--
|
|And this is the error:
|--
|{parsed_error}
|--
"""
prompt.append({"role": "user", "content": trim_prompt(retry_prompt)})
yield "error", retry_prompt

View File

@ -0,0 +1,33 @@
from typing import Any
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
class ObjectParser(Block):
class Input(BlockSchema):
object: Any
field_path: str
class Output(BlockSchema):
field_value: Any
error: Any
def __init__(self):
super().__init__(
id="be45299a-193b-4852-bda4-510883d21814",
input_schema=ObjectParser.Input,
output_schema=ObjectParser.Output,
)
def run(self, input_data: Input) -> BlockOutput:
field_path = input_data.field_path.split(".")
field_value = input_data.object
for field in field_path:
if isinstance(field_value, dict) and field in field_value:
field_value = field_value.get(field)
elif isinstance(field_value, object) and hasattr(field_value, field):
field_value = getattr(field_value, field)
else:
yield "error", input_data.object
return
yield "field_value", field_value

View File

@ -0,0 +1,107 @@
# type: ignore
from datetime import datetime, timedelta
import praw
from pydantic import BaseModel, Field
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
class RedditCredentials(BaseModel):
client_id: str
client_secret: str
username: str
password: str
user_agent: str | None = None
class RedditPost(BaseModel):
id: str
subreddit: str
title: str
body: str
def get_praw(creds: RedditCredentials) -> praw.Reddit:
client = praw.Reddit(
client_id=creds.client_id,
client_secret=creds.client_secret,
user_agent=creds.user_agent,
username=creds.username,
password=creds.password,
)
me = client.user.me()
if not me:
raise ValueError("Invalid Reddit credentials.")
print(f"Logged in as Reddit user: {me.name}")
return client
class RedditGetPostsBlock(Block):
class Input(BlockSchema):
creds: RedditCredentials = Field(description="Reddit credentials")
subreddit: str = Field(description="Subreddit name")
last_minutes: int | None = Field(
description="Post time to stop minutes ago while fetching posts",
default=None
)
last_post: str | None = Field(
description="Post ID to stop when reached while fetching posts",
default=None
)
post_limit: int | None = Field(
description="Number of posts to fetch",
default=10
)
class Output(BlockSchema):
post: RedditPost = Field(description="Reddit post")
def __init__(self):
super().__init__(
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
input_schema=RedditGetPostsBlock.Input,
output_schema=RedditGetPostsBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
client = get_praw(input_data.creds)
subreddit = client.subreddit(input_data.subreddit)
for post in subreddit.new(limit=input_data.post_limit):
if input_data.last_minutes and post.created_utc < (
datetime.now() - timedelta(minutes=input_data.last_minutes)
).timestamp():
break
if input_data.last_post and post.id == input_data.last_post:
break
yield "post", RedditPost(
id=post.id,
subreddit=subreddit.display_name,
title=post.title,
body=post.selftext
)
class RedditPostCommentBlock(Block):
class Input(BlockSchema):
creds: RedditCredentials = Field(description="Reddit credentials")
post_id: str = Field(description="Reddit post ID")
comment: str = Field(description="Comment text")
class Output(BlockSchema):
comment_id: str
def __init__(self):
super().__init__(
id="4a92261b-701e-4ffb-8970-675fd28e261f",
input_schema=RedditPostCommentBlock.Input,
output_schema=RedditPostCommentBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
client = get_praw(input_data.creds)
submission = client.submission(id=input_data.post_id)
comment = submission.reply(input_data.comment)
yield "comment_id", comment.id

View File

@ -0,0 +1,40 @@
# type: ignore
from autogpt_server.data.block import Block, BlockSchema, BlockOutput
class ParrotBlock(Block):
class Input(BlockSchema):
input: str
class Output(BlockSchema):
output: str
def __init__(self):
super().__init__(
id="1ff065e9-88e8-4358-9d82-8dc91f622ba9",
input_schema=ParrotBlock.Input,
output_schema=ParrotBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
yield "output", input_data.input
class PrintingBlock(Block):
class Input(BlockSchema):
text: str
class Output(BlockSchema):
status: str
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
input_schema=PrintingBlock.Input,
output_schema=PrintingBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
print(">>>>> Print: ", input_data.text)
yield "status", "printed"

View File

@ -0,0 +1,63 @@
import re
from typing import Any
from pydantic import Field
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
class TextMatcherBlock(Block):
class Input(BlockSchema):
text: str = Field(description="Text to match")
match: str = Field(description="Pattern (Regex) to match")
data: Any = Field(description="Data to be forwarded to output")
case_sensitive: bool = Field(description="Case sensitive match", default=True)
class Output(BlockSchema):
positive: Any = Field(description="Output data if match is found")
negative: Any = Field(description="Output data if match is not found")
def __init__(self):
super().__init__(
id="3060088f-6ed9-4928-9ba7-9c92823a7ccd",
input_schema=TextMatcherBlock.Input,
output_schema=TextMatcherBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
output = input_data.data or input_data.text
case = 0 if input_data.case_sensitive else re.IGNORECASE
if re.search(input_data.match, input_data.text, case):
yield "positive", output
else:
yield "negative", output
class TextFormatterBlock(Block):
class Input(BlockSchema):
texts: list[str] = Field(
description="Texts (list) to format",
default=[]
)
named_texts: dict[str, str] = Field(
description="Texts (dict) to format",
default={}
)
format: str = Field(
description="Template to format the text using `texts` and `named_texts`",
)
class Output(BlockSchema):
output: str
def __init__(self):
super().__init__(
id="db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
input_schema=TextFormatterBlock.Input,
output_schema=TextFormatterBlock.Output,
)
def run(self, input_data: Input) -> BlockOutput:
yield "output", input_data.format.format(
texts=input_data.texts,
**input_data.named_texts,
)
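
Not part of the commit: a small sketch of how the formatter combines the positional `texts` list with `named_texts` keys inside the template.

from autogpt_server.blocks.text import TextFormatterBlock

block = TextFormatterBlock()
result = list(block.execute({
    "texts": ["hello", "world"],
    "named_texts": {"name": "AutoGPT"},
    "format": "{name} says: {texts[0]} {texts[1]}",
}))
print(result)  # [("output", "AutoGPT says: hello world")]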

View File

@ -1,140 +1,112 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Generator, ClassVar
from typing import Any, cast, ClassVar, Generator, Generic, TypeVar, Type
import jsonref
import jsonschema
from prisma.models import AgentBlock
from pydantic import BaseModel
from autogpt_server.util import json
BlockData = dict[str, Any]
class BlockSchema(BaseModel):
"""
A schema for the block input and output data.
The dictionary structure is an object-typed `jsonschema`.
The top-level properties are the block input/output names.
cached_jsonschema: ClassVar[dict[str, Any]] = {}
You can initialize this class by providing a dictionary of properties.
The key is the string of the property name, and the value is either
a string of the type or a dictionary of the jsonschema.
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
You can also provide additional keyword arguments for additional properties.
Like `name`, `required` (by default all properties are required), etc.
model = jsonref.replace_refs(cls.model_json_schema())
Example:
input_schema = BlockSchema({
"system_prompt": "string",
"user_prompt": "string",
"max_tokens": "integer",
"user_info": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
"required": ["name"],
},
}, required=["system_prompt", "user_prompt"])
def ref_to_dict(obj):
if isinstance(obj, dict):
return {
key: ref_to_dict(value)
for key, value in obj.items() if not key.startswith("$")
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
output_schema = BlockSchema({
"on_complete": "string",
"on_failures": "string",
})
"""
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
jsonschema: dict[str, Any]
def __init__(
self,
properties: dict[str, str | dict],
required: list[str] | None = None,
**kwargs: Any,
):
schema = {
"type": "object",
"properties": {
key: {"type": value} if isinstance(value, str) else value
for key, value in properties.items()
},
"required": required or list(properties.keys()),
**kwargs,
}
super().__init__(jsonschema=schema)
def __str__(self) -> str:
return json.dumps(self.jsonschema)
def validate_data(self, data: BlockData) -> str | None:
@classmethod
def validate_data(cls, data: BlockData) -> str | None:
"""
Validate the data against the schema.
Returns the validation error message if the data does not match the schema.
"""
try:
jsonschema.validate(data, self.jsonschema)
jsonschema.validate(data, cls.jsonschema())
return None
except jsonschema.ValidationError as e:
return str(e)
def validate_field(self, field_name: str, data: BlockData) -> str | None:
@classmethod
def validate_field(cls, field_name: str, data: BlockData) -> str | None:
"""
Validate the data against a specific property (one of the input/output names).
Returns the validation error message if the data does not match the schema.
"""
property_schema = self.jsonschema["properties"].get(field_name)
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
return f"Invalid model schema {cls}"
property_schema = model_schema.get(field_name)
if not property_schema:
return f"Invalid property name {field_name}"
try:
jsonschema.validate(data, property_schema)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
def get_fields(self) -> set[str]:
return set(self.jsonschema["properties"].keys())
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
def get_required_fields(self) -> set[str]:
return set(self.jsonschema["required"])
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
BlockOutput = Generator[tuple[str, Any], None, None]
BlockSchemaInputType = TypeVar('BlockSchemaInputType', bound=BlockSchema)
BlockSchemaOutputType = TypeVar('BlockSchemaOutputType', bound=BlockSchema)
class Block(ABC, BaseModel):
@classmethod
@property
@abstractmethod
def id(cls) -> str:
class EmptySchema(BlockSchema):
pass
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
input_schema: Type[BlockSchemaInputType] = EmptySchema,
output_schema: Type[BlockSchemaOutputType] = EmptySchema,
):
"""
The unique identifier for the block; this value will be persisted in the DB.
So it should be unique and constant across application runs.
Use the UUID format for the ID.
"""
pass
@classmethod
@property
@abstractmethod
def input_schema(cls) -> BlockSchema:
"""
The schema for the block input data.
The top-level properties are the possible input name expected by the block.
"""
pass
@classmethod
@property
@abstractmethod
def output_schema(cls) -> BlockSchema:
"""
The schema for the block output.
The top-level properties are the possible output name produced by the block.
"""
pass
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
@abstractmethod
def run(self, input_data: BlockData) -> BlockOutput:
def run(self, input_data: BlockSchemaInputType) -> BlockOutput:
"""
Run the block with the given input data.
Args:
@ -146,17 +118,16 @@ class Block(ABC, BaseModel):
"""
pass
@classmethod
@property
def name(cls):
return cls.__name__
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema,
"outputSchema": self.output_schema.jsonschema,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
}
def execute(self, input_data: BlockData) -> BlockOutput:
@ -165,83 +136,20 @@ class Block(ABC, BaseModel):
f"Unable to execute block with invalid input data: {error}"
)
for output_name, output_data in self.run(input_data):
for output_name, output_data in self.run(self.input_schema(**input_data)):
if error := self.output_schema.validate_field(output_name, output_data):
raise ValueError(
f"Unable to execute block with invalid output data: {error}"
f"Block produced an invalid output data: {error}"
)
yield output_name, output_data
# ===================== Inline-Block Implementations ===================== #
class ParrotBlock(Block):
id: ClassVar[str] = "1ff065e9-88e8-4358-9d82-8dc91f622ba9" # type: ignore
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"input": "string",
}
)
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"output": "string",
}
)
def run(self, input_data: BlockData) -> BlockOutput:
yield "output", input_data["input"]
class TextFormatterBlock(Block):
id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"texts": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
},
"format": "string",
}
)
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"combined_text": "string",
}
)
def run(self, input_data: BlockData) -> BlockOutput:
yield "combined_text", input_data["format"].format(texts=input_data["texts"])
class PrintingBlock(Block):
id: ClassVar[str] = "f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c" # type: ignore
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"text": "string",
}
)
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
{
"status": "string",
}
)
def run(self, input_data: BlockData) -> BlockOutput:
yield "status", "printed"
# ======================= Block Helper Functions ======================= #
AVAILABLE_BLOCKS: dict[str, Block] = {}
from autogpt_server.blocks import AVAILABLE_BLOCKS # noqa: E402
async def initialize_blocks() -> None:
global AVAILABLE_BLOCKS
AVAILABLE_BLOCKS = {block.id: block() for block in Block.__subclasses__()}
for block in AVAILABLE_BLOCKS.values():
if await AgentBlock.prisma().find_unique(where={"id": block.id}):
continue
@ -250,19 +158,15 @@ async def initialize_blocks() -> None:
data={
"id": block.id,
"name": block.name,
"inputSchema": str(block.input_schema),
"outputSchema": str(block.output_schema),
"inputSchema": json.dumps(block.input_schema.jsonschema()),
"outputSchema": json.dumps(block.output_schema.jsonschema()),
}
)
async def get_blocks() -> list[Block]:
if not AVAILABLE_BLOCKS:
await initialize_blocks()
def get_blocks() -> list[Block]:
return list(AVAILABLE_BLOCKS.values())
async def get_block(block_id: str) -> Block | None:
if not AVAILABLE_BLOCKS:
await initialize_blocks()
def get_block(block_id: str) -> Block | None:
return AVAILABLE_BLOCKS.get(block_id)
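
For orientation (not part of the diff): under the new API a block declares its input/output as pydantic BlockSchema classes and yields named outputs. A minimal hypothetical block with a placeholder UUID:

from autogpt_server.data.block import Block, BlockOutput, BlockSchema

class WordCountBlock(Block):
    class Input(BlockSchema):
        text: str

    class Output(BlockSchema):
        count: int

    def __init__(self):
        super().__init__(
            id="00000000-0000-4000-8000-000000000000",  # placeholder UUID
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        # Outputs are yielded as (name, value) pairs and validated against Output.
        yield "count", len(input_data.text.split())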

View File

@ -1,4 +1,3 @@
import json
from collections import defaultdict
from datetime import datetime
from enum import Enum
@ -12,6 +11,8 @@ from prisma.models import (
)
from pydantic import BaseModel
from autogpt_server.util import json
class NodeExecution(BaseModel):
graph_exec_id: str
@ -148,7 +149,6 @@ async def upsert_execution_input(
json_data = json.dumps(data)
if existing_execution:
print(f"Adding input {input_name}={data} to execution #{existing_execution.id}")
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": input_name,
@ -159,7 +159,6 @@ async def upsert_execution_input(
return existing_execution.id
else:
print(f"Creating new execution for input {input_name}={data}")
result = await AgentNodeExecution.prisma().create(
data={
"agentNodeId": node_id,
@ -240,16 +239,46 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
return merge_execution_input(exec_input)
SPLIT = "_$_"
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
# Allow extracting partial output data by name.
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[index]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
# Merge all input with <input_name>_$_<index> into a single list.
list_input = []
for key, value in data.items():
if SPLIT not in key:
if LIST_SPLIT not in key:
continue
name, index = key.split(SPLIT)
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
list_input.append((name, value, 0))
else:
@ -259,4 +288,21 @@ def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
data[name] = data.get(name, [])
data[name].append(value)
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in list(data.items()):
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in list(data.items()):
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not hasattr(data[name], "__dict__"):
data[name] = type("Object", (object,), {})()
setattr(data[name], index, value)
return data
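
A quick sketch (not in the diff) of how the split markers address parts of an upstream output: `_$_` indexes into a list, `_#_` into a dict, `_@_` into an object attribute.

from autogpt_server.data.execution import parse_execution_output

output = ("response", {"summary": "ok", "sentiment": "positive"})
print(parse_execution_output(output, "response"))            # the whole dict
print(parse_execution_output(output, "response_#_summary"))  # "ok"

output = ("items", ["a", "b", "c"])
print(parse_execution_output(output, "items_$_1"))           # "b"
print(parse_execution_output(output, "items_$_9"))           # None (index out of range)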

View File

@ -1,65 +1,82 @@
import asyncio
import json
import uuid
from typing import Any
from typing import Any
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
from autogpt_server.data.db import BaseDbModel
from autogpt_server.util import json
class Link(BaseModel):
name: str
node_id: str
def __init__(self, name: str, node_id: str):
super().__init__(name=name, node_id=node_id)
def __iter__(self):
return iter((self.name, self.node_id))
source_id: str
sink_id: str
source_name: str
sink_name: str
def __init__(self, source_id: str, sink_id: str, source_name: str, sink_name: str):
super().__init__(
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
)
@staticmethod
def from_db(link: AgentNodeLink):
return Link(
source_name=link.sourceName,
source_id=link.agentNodeSourceId,
sink_name=link.sinkName,
sink_id=link.agentNodeSinkId,
)
def __hash__(self):
return hash((self.source_id, self.sink_id, self.source_name, self.sink_name))
class Node(BaseDbModel):
block_id: str
input_default: dict[str, Any] = {} # dict[input_name, default_value]
input_nodes: list[Link] = [] # dict[input_name, node_id]
output_nodes: list[Link] = [] # dict[output_name, node_id]
metadata: dict[str, Any] = {}
_input_links: list[Link] = PrivateAttr(default=[])
_output_links: list[Link] = PrivateAttr(default=[])
@property
def input_links(self) -> list[Link]:
return self._input_links
@property
def output_links(self) -> list[Link]:
return self._output_links
@staticmethod
def from_db(node: AgentNode):
if not node.AgentBlock:
raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
return Node(
obj = Node(
id=node.id,
block_id=node.AgentBlock.id,
input_default=json.loads(node.constantInput),
input_nodes=[
Link(v.sinkName, v.agentNodeSourceId)
for v in node.Input or []
],
output_nodes=[
Link(v.sourceName, v.agentNodeSinkId)
for v in node.Output or []
],
metadata=json.loads(node.metadata),
metadata=json.loads(node.metadata)
)
def connect(self, node: "Node", source_name: str, sink_name: str):
self.output_nodes.append(Link(source_name, node.id))
node.input_nodes.append(Link(sink_name, self.id))
obj._input_links = [Link.from_db(link) for link in node.Input or []]
obj._output_links = [Link.from_db(link) for link in node.Output or []]
return obj
class Graph(BaseDbModel):
name: str
description: str
nodes: list[Node]
links: list[Link]
@property
def starting_nodes(self) -> list[Node]:
return [node for node in self.nodes if not node.input_nodes]
outbound_nodes = {link.sink_id for link in self.links}
return [node for node in self.nodes if node.id not in outbound_nodes]
@staticmethod
def from_db(graph: AgentGraph):
@ -68,6 +85,11 @@ class Graph(BaseDbModel):
name=graph.name or "",
description=graph.description or "",
nodes=[Node.from_db(node) for node in graph.AgentNodes or []],
links=list({
Link.from_db(link)
for node in graph.AgentNodes or []
for link in (node.Input or []) + (node.Output or [])
})
)
@ -121,27 +143,15 @@ async def create_graph(graph: Graph) -> Graph:
}) for node in graph.nodes
])
edge_source_names = {
(source_node.id, sink_node_id): output_name
for source_node in graph.nodes
for output_name, sink_node_id in source_node.output_nodes
}
edge_sink_names = {
(source_node_id, sink_node.id): input_name
for sink_node in graph.nodes
for input_name, source_node_id in sink_node.input_nodes
}
# TODO: replace bulk creation using create_many
await asyncio.gather(*[
AgentNodeLink.prisma().create({
"id": str(uuid.uuid4()),
"sourceName": edge_source_names.get((input_node, output_node), ""),
"sinkName": edge_sink_names.get((input_node, output_node), ""),
"agentNodeSourceId": input_node,
"agentNodeSinkId": output_node,
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
})
for input_node, output_node in edge_source_names.keys() | edge_sink_names.keys()
for link in graph.links
])
if created_graph := await get_graph(graph.id):
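
Not part of the commit: a sketch of building a graph under the new model, where links live on the Graph and carry both node ids plus the source/sink pin names, replacing the old Node.connect().

from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
from autogpt_server.data import graph

nodes = [
    graph.Node(block_id=ParrotBlock().id),
    graph.Node(block_id=PrintingBlock().id),
]
links = [
    # source_id, sink_id, source_name (output pin), sink_name (input pin)
    graph.Link(nodes[0].id, nodes[1].id, "output", "text"),
]
demo = graph.Graph(name="Demo", description="Parrot -> Print", nodes=nodes, links=links)
print([n.id for n in demo.starting_nodes])  # only the Parrot node has no inbound link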

View File

@ -1,10 +1,10 @@
import json
from datetime import datetime
from typing import Optional, Any
from prisma.models import AgentGraphExecutionSchedule
from autogpt_server.data.db import BaseDbModel
from autogpt_server.util import json
class ExecutionSchedule(BaseDbModel):

View File

@ -9,6 +9,7 @@ from autogpt_server.data.execution import (
create_graph_execution,
get_node_execution_input,
merge_execution_input,
parse_execution_output,
update_execution_status as execution_update,
upsert_execution_output,
upsert_execution_input,
@ -16,7 +17,7 @@ from autogpt_server.data.execution import (
ExecutionStatus,
ExecutionQueue,
)
from autogpt_server.data.graph import Node, get_node, get_graph
from autogpt_server.data.graph import Link, Node, get_node, get_graph
from autogpt_server.util.service import AppService, expose
logger = logging.getLogger(__name__)
@ -57,7 +58,7 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
logger.error(f"Node {node_id} not found.")
return
node_block = wait(get_block(node.block_id))
node_block = get_block(node.block_id)
if not node_block:
logger.error(f"Block {node.block_id} not found.")
return
@ -74,7 +75,11 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
wait(upsert_execution_output(node_exec_id, output_name, output_data))
for execution in enqueue_next_nodes(
loop, node, output_name, output_data, graph_exec_id
loop=loop,
node=node,
output=(output_name, output_data),
graph_exec_id=graph_exec_id,
prefix=prefix,
):
yield execution
except Exception as e:
@ -88,45 +93,43 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
def enqueue_next_nodes(
loop: asyncio.AbstractEventLoop,
node: Node,
output_name: str,
output_data: Any,
output: tuple[str, Any],
graph_exec_id: str,
prefix: str,
) -> list[Execution]:
def wait(f: Coroutine[T, Any, T]) -> T:
return loop.run_until_complete(f)
prefix = get_log_prefix(graph_exec_id, node.id)
node_id = node.id
def get_next_node_execution(node_link: Link) -> Execution | None:
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
next_node_id = node_link.sink_id
# Try to enqueue next eligible nodes
next_node_ids = [nid for name, nid in node.output_nodes if name == output_name]
if not next_node_ids:
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
return []
next_data = parse_execution_output(output, next_output_name)
if next_data is None:
return
def validate_node_execution(next_node_id: str):
next_node = wait(get_node(next_node_id))
if not next_node:
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return
next_node_input_name = next(
name for name, nid in next_node.input_nodes if nid == node_id
)
next_node_exec_id = wait(upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_node_input_name,
data=output_data
input_name=next_input_name,
data=next_data
))
next_node_input = wait(get_node_execution_input(next_node_exec_id))
is_valid, validation_resp = wait(validate_exec(next_node, next_node_input))
is_valid, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}"
if not is_valid:
logger.warning(f"{prefix} Skipped {next_node_id}: {validation_resp}")
logger.warning(f"{prefix} Skipped queueing {suffix}")
return
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{validation_resp}")
logger.warning(f"{prefix} Enqueued {suffix}")
return Execution(
graph_exec_id=graph_exec_id,
node_exec_id=next_node_exec_id,
@ -134,14 +137,11 @@ def enqueue_next_nodes(
data=next_node_input
)
executions = []
for nid in next_node_ids:
if execution := validate_node_execution(nid):
executions.append(execution)
return executions
executions = [get_next_node_execution(link) for link in node.output_links]
return [v for v in executions if v]
async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
"""
Validate the input data for a node execution.
@ -153,21 +153,24 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
A tuple of a boolean indicating if the data is valid, and a message if not.
Return the executed block name if the data is valid.
"""
node_block: Block | None = await get_block(node.block_id)
node_block: Block | None = get_block(node.block_id)
if not node_block:
return False, f"Block for {node.block_id} not found."
error_message = f"Input data missing for {node_block.name}:"
input_fields_from_schema = node_block.input_schema.get_required_fields()
if not input_fields_from_schema.issubset(data):
return False, f"Input data missing: {input_fields_from_schema - set(data)}"
return False, f"{error_message} {input_fields_from_schema - set(data)}"
input_fields_from_nodes = {name for name, _ in node.input_nodes}
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return False, f"Input data missing: {input_fields_from_nodes - set(data)}"
return False, f"{error_message} {input_fields_from_nodes - set(data)}"
if error := node_block.input_schema.validate_data(data):
logger.error("Input value doesn't match schema: %s", error)
return False, f"Input data doesn't match {node_block.name}: {error}"
error_message = f"Input data doesn't match {node_block.name}: {error}"
logger.error(error_message)
return False, error_message
return True, node_block.name
@ -221,7 +224,7 @@ class ExecutionManager(AppService):
# Currently, there is no constraint on the number of root nodes in the graph.
for node in graph.starting_nodes:
input_data = merge_execution_input({**node.input_default, **data})
valid, error = self.run_and_wait(validate_exec(node, input_data))
valid, error = validate_exec(node, input_data)
if not valid:
raise Exception(error)
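
For context (not from the diff): validate_exec is now synchronous and reports which required inputs are still missing before a node is queued. The module path below is assumed.

from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import graph
from autogpt_server.executor.manager import validate_exec  # path assumed

node = graph.Node(block_id=TextFormatterBlock().id)
ok, message = validate_exec(node, {"texts": ["a"]})
print(ok, message)  # False, with a message naming the missing "format" input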

View File

@ -27,6 +27,7 @@ class AgentServer(AppProcess):
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
await block.initialize_blocks()
yield
await db.disconnect()
@ -112,8 +113,8 @@ class AgentServer(AppProcess):
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
async def get_graph_blocks(self) -> list[dict]:
return [v.to_dict() for v in await block.get_blocks()]
def get_graph_blocks(self) -> list[dict]:
return [v.to_dict() for v in block.get_blocks()]
async def get_graphs(self) -> list[str]:
return await get_graph_ids()
@ -128,10 +129,13 @@ class AgentServer(AppProcess):
# TODO: replace uuid generation here to DB generated uuids.
graph.id = str(uuid.uuid4())
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
node.input_nodes = [Link(k, id_map[v]) for k, v in node.input_nodes]
node.output_nodes = [Link(k, id_map[v]) for k, v in node.output_nodes]
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
return await create_graph(graph)

View File

@ -0,0 +1,14 @@
import json
from fastapi.encoders import jsonable_encoder
def to_dict(data) -> dict:
return jsonable_encoder(data)
def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
def loads(data) -> dict:
return json.loads(data)
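
Not in the diff: this helper wraps FastAPI's jsonable_encoder so pydantic models and nested objects serialize consistently. A tiny sketch:

from pydantic import BaseModel

from autogpt_server.util import json

class Post(BaseModel):
    id: str
    title: str

print(json.to_dict(Post(id="1", title="hi")))          # {"id": "1", "title": "hi"}
print(json.dumps({"post": Post(id="1", title="hi")}))  # '{"post": {"id": "1", "title": "hi"}}'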

View File

@ -22,7 +22,6 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
self._updated_fields.add(name)
super().__setattr__(name, value)
def mark_updated(self, field_name: str) -> None:
if field_name in self.model_fields:
self._updated_fields.add(field_name)
@ -33,6 +32,10 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
def get_updates(self) -> Dict[str, Any]:
return {field: getattr(self, field) for field in self._updated_fields}
@property
def updated_fields(self):
return self._updated_fields
class Config(UpdateTrackingModel["Config"], BaseSettings):
"""Config for the server."""
@ -54,12 +57,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (JsonConfigSettingsSource(settings_cls),)
@ -101,7 +104,7 @@ class Settings(BaseModel):
# Save updated secrets to individual files
secrets_dir = get_secrets_path()
for key in self.secrets._updated_fields:
for key in self.secrets.updated_fields:
secret_file = os.path.join(secrets_dir, key)
with open(secret_file, "w") as f:
f.write(str(getattr(self.secrets, key)))

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -91,6 +91,105 @@ files = [
{file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"},
]
[[package]]
name = "charset-normalizer"
version = "3.3.2"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
optional = false
python-versions = ">=3.7.0"
files = [
{file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
{file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
{file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
{file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
{file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
{file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
{file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
{file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
{file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
{file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
{file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
{file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
{file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"},
{file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"},
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"},
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"},
{file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"},
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"},
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"},
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"},
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"},
{file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"},
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"},
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"},
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"},
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"},
{file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"},
{file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"},
{file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"},
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
{file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
{file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
{file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
{file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
{file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
]
[[package]]
name = "click"
version = "8.1.7"
@ -190,6 +289,17 @@ files = [
{file = "cx_Logging-3.2.0.tar.gz", hash = "sha256:bdbad6d2e6a0cc5bef962a34d7aa1232e88ea9f3541d6e2881675b5e7eab5502"},
]
[[package]]
name = "distro"
version = "1.9.0"
description = "Distro - an OS platform information API"
optional = false
python-versions = ">=3.6"
files = [
{file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]]
name = "dmgbuild"
version = "1.6.1"
@ -438,6 +548,17 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "jsonref"
version = "1.1.0"
description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python."
optional = false
python-versions = ">=3.7"
files = [
{file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"},
{file = "jsonref-1.1.0.tar.gz", hash = "sha256:32fe8e1d85af0fdefbebce950af85590b22b60f9e95443176adbde4e1ecea552"},
]
[[package]]
name = "jsonschema"
version = "4.22.0"
@ -619,6 +740,29 @@ files = [
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
]
[[package]]
name = "openai"
version = "1.35.7"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
{file = "openai-1.35.7-py3-none-any.whl", hash = "sha256:3d1e0b0aac9b0db69a972d36dc7efa7563f8e8d65550b27a48f2a0c2ec207e80"},
{file = "openai-1.35.7.tar.gz", hash = "sha256:009bfa1504c9c7ef64d87be55936d142325656bbc6d98c68b669d6472e4beb09"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
tqdm = ">4"
typing-extensions = ">=4.7,<5"
[package.extras]
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
[[package]]
name = "packaging"
version = "24.1"
@ -693,6 +837,49 @@ tomli = ">=1.2.2"
[package.extras]
poetry-plugin = ["poetry (>=1.0,<2.0)"]
[[package]]
name = "praw"
version = "7.7.1"
description = "PRAW, an acronym for \"Python Reddit API Wrapper\", is a Python package that allows for simple access to Reddit's API."
optional = false
python-versions = "~=3.7"
files = [
{file = "praw-7.7.1-py3-none-any.whl", hash = "sha256:9ec5dc943db00c175bc6a53f4e089ce625f3fdfb27305564b616747b767d38ef"},
{file = "praw-7.7.1.tar.gz", hash = "sha256:f1d7eef414cafe28080dda12ed09253a095a69933d5c8132eca11d4dc8a070bf"},
]
[package.dependencies]
prawcore = ">=2.1,<3"
update-checker = ">=0.18"
websocket-client = ">=0.54.0"
[package.extras]
ci = ["coveralls"]
dev = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "furo", "packaging", "pre-commit", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "sphinx", "urllib3 (==1.26.*)"]
lint = ["furo", "pre-commit", "sphinx"]
readthedocs = ["furo", "sphinx"]
test = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "urllib3 (==1.26.*)"]
[[package]]
name = "prawcore"
version = "2.4.0"
description = "\"Low-level communication layer for PRAW 4+."
optional = false
python-versions = "~=3.8"
files = [
{file = "prawcore-2.4.0-py3-none-any.whl", hash = "sha256:29af5da58d85704b439ad3c820873ad541f4535e00bb98c66f0fbcc8c603065a"},
{file = "prawcore-2.4.0.tar.gz", hash = "sha256:b7b2b5a1d04406e086ab4e79988dc794df16059862f329f4c6a43ed09986c335"},
]
[package.dependencies]
requests = ">=2.6.0,<3.0"
[package.extras]
ci = ["coveralls"]
dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
lint = ["pre-commit", "ruff (>=0.0.291)"]
test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]
[[package]]
name = "prisma"
version = "0.13.1"
@ -1081,6 +1268,27 @@ files = [
attrs = ">=22.2.0"
rpds-py = ">=0.7.0"
[[package]]
name = "requests"
version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
certifi = ">=2017.4.17"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<3"
[package.extras]
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rpds-py"
version = "0.18.1"
@ -1334,6 +1542,26 @@ files = [
{file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"},
]
[[package]]
name = "tqdm"
version = "4.66.4"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
{file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typing-extensions"
version = "4.12.2"
@ -1373,6 +1601,42 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
[[package]]
name = "update-checker"
version = "0.18.0"
description = "A python module that will check for package updates."
optional = false
python-versions = "*"
files = [
{file = "update_checker-0.18.0-py3-none-any.whl", hash = "sha256:cbba64760a36fe2640d80d85306e8fe82b6816659190993b7bdabadee4d4bbfd"},
{file = "update_checker-0.18.0.tar.gz", hash = "sha256:6a2d45bb4ac585884a6b03f9eade9161cedd9e8111545141e9aa9058932acb13"},
]
[package.dependencies]
requests = ">=2.3.0"
[package.extras]
dev = ["black", "flake8", "pytest (>=2.7.3)"]
lint = ["black", "flake8"]
test = ["pytest (>=2.7.3)"]
[[package]]
name = "urllib3"
version = "2.2.2"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=3.8"
files = [
{file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"},
{file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "uvicorn"
version = "0.30.1"
@ -1574,6 +1838,22 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "websocket-client"
version = "1.8.0"
description = "WebSocket client for Python with low level API options"
optional = false
python-versions = ">=3.8"
files = [
{file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"},
{file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"},
]
[package.extras]
docs = ["Sphinx (>=6.0)", "myst-parser (>=2.0.0)", "sphinx-rtd-theme (>=1.1.0)"]
optional = ["python-socks", "wsaccel"]
test = ["websockets"]
[[package]]
name = "websockets"
version = "12.0"
@@ -1672,4 +1952,4 @@ test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
- content-hash = "a243c28b48b60e14513fc18629096a7f9d1c60ae7b05a6c50125c1d4c045033e"
+ content-hash = "ef4326d7688ca5c6c7d3c5189d1058183bd6916f2ba536ad42ad306c9d2ab595"

View File

@@ -26,6 +26,9 @@ apscheduler = "^3.10.4"
croniter = "^2.0.5"
pytest-asyncio = "^0.23.7"
pydantic-settings = "^2.3.4"
praw = "^7.7.1"
openai = "^1.35.7"
jsonref = "^1.1.0"
[tool.poetry.group.dev.dependencies]

View File

View File

@@ -1,11 +1,12 @@
import time
import pytest
from autogpt_server.data import block, db, execution, graph
from autogpt_server.executor import ExecutionManager
from autogpt_server.server import AgentServer
from autogpt_server.util.service import PyroNameServer
from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
async def create_test_graph() -> graph.Graph:
@@ -17,27 +18,28 @@ async def create_test_graph() -> graph.Graph:
ParrotBlock
"""
nodes = [
- graph.Node(block_id=block.ParrotBlock.id),
- graph.Node(block_id=block.ParrotBlock.id),
+ graph.Node(block_id=ParrotBlock().id),
+ graph.Node(block_id=ParrotBlock().id),
graph.Node(
- block_id=block.TextFormatterBlock.id,
+ block_id=TextFormatterBlock().id,
input_default={
"format": "{texts[0]},{texts[1]},{texts[2]}",
"texts_$_3": "!!!",
},
),
- graph.Node(block_id=block.PrintingBlock.id),
+ graph.Node(block_id=PrintingBlock().id),
]
+ links = [
+     graph.Link(nodes[0].id, nodes[2].id, "output", "texts_$_1"),
+     graph.Link(nodes[1].id, nodes[2].id, "output", "texts_$_2"),
+     graph.Link(nodes[2].id, nodes[3].id, "output", "text"),
+ ]
- nodes[0].connect(nodes[2], "output", "texts_$_1")
- nodes[1].connect(nodes[2], "output", "texts_$_2")
- nodes[2].connect(nodes[3], "combined_text", "text")
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
+ links=links,
)
- await block.initialize_blocks()
result = await graph.create_graph(test_graph)
# Assertions
@@ -48,7 +50,7 @@ async def create_test_graph() -> graph.Graph:
return test_graph
- async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph):
+ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str:
# --- Test adding new executions --- #
text = "Hello, World!"
input_data = {"input": text}
@@ -70,6 +72,12 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
# Execution queue should be empty
assert await is_execution_completed()
return graph_exec_id
async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
text = "Hello, World!"
agent_server = AgentServer()
executions = await agent_server.get_executions(test_graph.id, graph_exec_id)
# Executing ParrotBlock1
@@ -92,7 +100,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
exec = executions[2]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
- assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]}
+ assert exec.output_data == {"output": ["Hello, World!,Hello, World!,!!!"]}
assert exec.input_data == {
"texts_$_1": "Hello, World!",
"texts_$_2": "Hello, World!",
@@ -113,5 +121,7 @@ async def test_agent_execution():
with PyroNameServer():
with ExecutionManager(1) as test_manager:
await db.connect()
+ await block.initialize_blocks()
test_graph = await create_test_graph()
- await execute_graph(test_manager, test_graph)
+ graph_exec_id = await execute_graph(test_manager, test_graph)
+ await assert_executions(test_graph, graph_exec_id)

View File

@@ -0,0 +1,161 @@
import time
from autogpt_server.data import block, db
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.execution import ExecutionStatus
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
from autogpt_server.blocks.reddit import (
RedditCredentials,
RedditGetPostsBlock,
RedditPostCommentBlock,
)
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
from autogpt_server.executor import ExecutionManager
from autogpt_server.server import AgentServer
from autogpt_server.util.service import PyroNameServer
async def create_test_graph() -> Graph:
# Graph layout:
#   subreddit --> RedditGetPostsBlock --(post_id, post_title, post_body)--> TextFormatterBlock
#   TextFormatterBlock --(formatted post)--> LlmCallBlock / TextRelevancy --(post_id, is_relevant, marketing_text)--> TextMatcherBlock
#   TextMatcherBlock -- relevant? -- Yes --(post_id, marketing_text)--> RedditPostCommentBlock
#                                 \- No  --> stop
# Creds
reddit_creds = RedditCredentials(
client_id="TODO_FILL_OUT_THIS",
client_secret="TODO_FILL_OUT_THIS",
username="TODO_FILL_OUT_THIS",
password="TODO_FILL_OUT_THIS",
user_agent="TODO_FILL_OUT_THIS",
)
openai_creds = LlmConfig(
model=LlmModel.openai_gpt4,
api_key="TODO_FILL_OUT_THIS",
)
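    # NOTE: both credential objects above are placeholders ("TODO_FILL_OUT_THIS");
    # this script only does something useful once they are replaced with real Reddit
    # app credentials (for praw) and an OpenAI API key.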
# Hardcoded inputs
reddit_get_post_input = {
"creds": reddit_creds,
"last_minutes": 60,
"post_limit": 3,
}
text_formatter_input = {
"format": """
Based on the following post, write your marketing comment:
* Post ID: {id}
* Post Subreddit: {subreddit}
* Post Title: {title}
* Post Body: {body}""".strip(),
}
llm_call_input = {
"sys_prompt": """
You are a marketing expert tasked with picking Reddit posts that are relevant to your product.
The product you are marketing is Auto-GPT, an autonomous AI agent utilizing GPT models.
Reply to the posts you find relevant with a fitting marketing comment.
Make sure to only comment on relevant posts.
""",
"config": openai_creds,
"expected_format": {
"post_id": "str, the reddit post id",
"is_relevant": "bool, whether the post is relevant for marketing",
"marketing_text": "str, marketing text, this is empty on irrelevant posts",
},
}
text_matcher_input = {"match": "true", "case_sensitive": False}
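    # The matcher presumably compares the stringified "is_relevant" value (wired into
    # its "text" input below) against "true", case-insensitively, and only forwards the
    # LLM response on its "positive" outputs when the post was judged relevant.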
reddit_comment_input = {"creds": reddit_creds}
# Nodes
reddit_get_post_node = Node(
block_id=RedditGetPostsBlock().id,
input_default=reddit_get_post_input,
)
text_formatter_node = Node(
block_id=TextFormatterBlock().id,
input_default=text_formatter_input,
)
llm_call_node = Node(
block_id=LlmCallBlock().id,
input_default=llm_call_input
)
text_matcher_node = Node(
block_id=TextMatcherBlock().id,
input_default=text_matcher_input,
)
reddit_comment_node = Node(
block_id=RedditPostCommentBlock().id,
input_default=reddit_comment_input,
)
nodes = [
reddit_get_post_node,
text_formatter_node,
llm_call_node,
text_matcher_node,
reddit_comment_node,
]
# Links
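    # Link arguments appear to follow (source_node_id, sink_node_id, source_name, sink_name).
    # The "_#_" suffix is this PR's composite data extraction for dict outputs, e.g.
    # "response_#_is_relevant" forwards only the "is_relevant" key of the LLM response;
    # "_$_" plays the same role for list elements (see "texts_$_1" in the other test).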
links = [
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",
"post_id"),
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_marketing_text",
"comment"),
]
# Create graph
test_graph = Graph(
name="RedditMarketingAgent",
description="Reddit marketing agent",
nodes=nodes,
links=links,
)
return await create_graph(test_graph)
async def wait_execution(test_manager, graph_id, graph_exec_id) -> list:
async def is_execution_completed():
execs = await AgentServer().get_executions(graph_id, graph_exec_id)
"""
Expected executions:
    reddit_get_post_node: 1 (produces 3 posts)
    text_formatter_node: 3
    llm_call_node: 3
    text_matcher_node: 3
    reddit_comment_node: 3 (assuming all 3 posts are deemed relevant)
Total: 13
"""
print("--------> Execution count: ", len(execs), [str(v.status) for v in execs])
return test_manager.queue.empty() and len(execs) == 13 and all(
v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
for v in execs
)
# Wait for the executions to complete
for i in range(120):
if await is_execution_completed():
return await AgentServer().get_executions(graph_id, graph_exec_id)
time.sleep(1)
assert False, "Execution did not complete in time."
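# NOTE: this agent run is meant to be executed manually via the __main__ block below
# rather than collected by pytest; the reddit test stays skipped until the placeholder
# credentials above are filled in.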
async def reddit_marketing_agent():
with PyroNameServer():
with ExecutionManager(1) as test_manager:
await db.connect()
await block.initialize_blocks()
test_graph = await create_test_graph()
input_data = {"subreddit": "AutoGPT"}
response = await AgentServer().execute_graph(test_graph.id, input_data)
print(response)
result = await wait_execution(test_manager, test_graph.id, response["id"])
print(result)
if __name__ == "__main__":
import asyncio
asyncio.run(reddit_marketing_agent())