feat(rnd): Add support for dynamic input as list for AgentServer Block (#7268)

On AgentServer, to create a Block like StringFormatterBlock or LlmCallBlock, we need a way to dynamically link input pins and aggregate them into a single list input. This gives the user a better experience when constructing an input and linking it to the outputs of other nodes. The scope of this change is to add support for that in the least intrusive way.

Proposal
To differentiate the list input name from its individual entries, we use the _$_<index> suffix. For example:
For an input items: list[int], you can set a pin named items with a value like [1,2,3,4]. But you can also add input pins like items_$_0 or items_$_4 with values 1 or 2, which will be appended to the items input in ascending index order.
The execution engine guarantees that execution waits until every input pin value has been produced, so a list input assembled from these pins will always have a fixed size.
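
As a rough sketch of the merge rule (illustrative only: the helper name below is made up, and the real implementation is merge_execution_input in the diff further down):

SPLIT = "_$_"

def merge_list_pins(data: dict) -> dict:
    # Collect every pin whose name carries the "_$_<index>" suffix.
    entries = []
    for key, value in data.items():
        if SPLIT not in key:
            continue
        name, index = key.split(SPLIT)
        entries.append((name, int(index) if index.isdigit() else 0, value))
    # Append the values to the base list input, ordered by numeric index.
    for name, _, value in sorted(entries, key=lambda e: e[1]):
        data.setdefault(name, []).append(value)
    return data

merged = merge_list_pins({"items_$_0": 1, "items_$_4": 2, "items_$_2": 3})
assert merged["items"] == [1, 3, 2]  # ordered by index, not by insertion order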
Zamil Majdy 2024-06-27 18:51:34 +04:00 committed by GitHub
parent cdc658695f
commit 2bc22c5450
4 changed files with 61 additions and 26 deletions

View File

@@ -94,6 +94,9 @@ class BlockSchema(BaseModel):
     def get_fields(self) -> set[str]:
         return set(self.jsonschema["properties"].keys())
+    def get_required_fields(self) -> set[str]:
+        return set(self.jsonschema["required"])
 BlockOutput = Generator[tuple[str, Any], None, None]
@@ -190,12 +193,15 @@ class ParrotBlock(Block):
         yield "output", input_data["input"]
-class TextCombinerBlock(Block):
+class TextFormatterBlock(Block):
     id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore
     input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
         {
-            "text1": "string",
-            "text2": "string",
+            "texts": {
+                "type": "array",
+                "items": {"type": "string"},
+                "minItems": 1,
+            },
             "format": "string",
         }
     )
@@ -206,10 +212,7 @@ class TextCombinerBlock(Block):
     )
     def run(self, input_data: BlockData) -> BlockOutput:
-        yield "combined_text", input_data["format"].format(
-            text1=input_data["text1"],
-            text2=input_data["text2"],
-        )
+        yield "combined_text", input_data["format"].format(texts=input_data["texts"])
 class PrintingBlock(Block):
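
For clarity, the new format field relies on plain str.format indexing into the aggregated list; with the values used by the test graph below:

"{texts[0]},{texts[1]},{texts[2]}".format(texts=["Hello, World!", "Hello, World!", "!!!"])
# -> 'Hello, World!,Hello, World!,!!!'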

View File

@@ -236,4 +236,27 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
     exec_input = json.loads(execution.AgentNode.constantInput)
     for input_data in execution.Input or []:
         exec_input[input_data.name] = json.loads(input_data.data)
-    return exec_input
+    return merge_execution_input(exec_input)
+SPLIT = "_$_"
+def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
+    list_input = []
+    for key, value in data.items():
+        if SPLIT not in key:
+            continue
+        name, index = key.split(SPLIT)
+        if not index.isdigit():
+            list_input.append((name, value, 0))
+        else:
+            list_input.append((name, value, int(index)))
+    for name, value, _ in sorted(list_input, key=lambda x: x[2]):
+        data[name] = data.get(name, [])
+        data[name].append(value)
+    return data
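
For reference, with the inputs used by the test graph further down, this merge step behaves roughly like this (a sketch; keys other than the aggregated list are left untouched):

exec_input = merge_execution_input({
    "format": "{texts[0]},{texts[1]},{texts[2]}",
    "texts_$_1": "Hello, World!",
    "texts_$_2": "Hello, World!",
    "texts_$_3": "!!!",
})
# exec_input now also contains "texts": ["Hello, World!", "Hello, World!", "!!!"],
# while the original "texts_$_N" keys remain in the dict.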

View File

@@ -6,8 +6,9 @@ from typing import Any, Coroutine, Generator, TypeVar
 from autogpt_server.data import db
 from autogpt_server.data.block import Block, get_block
 from autogpt_server.data.execution import (
-    get_node_execution_input,
     create_graph_execution,
+    get_node_execution_input,
+    merge_execution_input,
     update_execution_status as execution_update,
     upsert_execution_output,
     upsert_execution_input,
@@ -77,9 +78,10 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
         ):
             yield execution
     except Exception as e:
-        logger.exception(f"{prefix} failed with error: %s", e)
+        error_msg = f"{e.__class__.__name__}: {e}"
+        logger.exception(f"{prefix} failed with error. `%s`", error_msg)
         wait(execution_update(node_exec_id, ExecutionStatus.FAILED))
-        wait(upsert_execution_output(node_exec_id, "error", str(e)))
+        wait(upsert_execution_output(node_exec_id, "error", error_msg))
         raise e
@@ -151,13 +153,17 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
         A tuple of a boolean indicating if the data is valid, and a message if not.
         Return the executed block name if the data is valid.
     """
-    node_block: Block | None = await(get_block(node.block_id))
+    node_block: Block | None = await get_block(node.block_id)
     if not node_block:
         return False, f"Block for {node.block_id} not found."
-    input_fields = node_block.input_schema.get_fields()
-    if not input_fields.issubset(data):
-        return False, f"Input data missing: {input_fields - set(data)}"
+    input_fields_from_schema = node_block.input_schema.get_required_fields()
+    if not input_fields_from_schema.issubset(data):
+        return False, f"Input data missing: {input_fields_from_schema - set(data)}"
+    input_fields_from_nodes = {name for name, _ in node.input_nodes}
+    if not input_fields_from_nodes.issubset(data):
+        return False, f"Input data missing: {input_fields_from_nodes - set(data)}"
     if error := node_block.input_schema.validate_data(data):
         logger.error("Input value doesn't match schema: %s", error)
@@ -214,7 +220,7 @@ class ExecutionManager(AppService):
         # Currently, there is no constraint on the number of root nodes in the graph.
         for node in graph.starting_nodes:
-            input_data = {**node.input_default, **data}
+            input_data = merge_execution_input({**node.input_default, **data})
             valid, error = self.run_and_wait(validate_exec(node, input_data))
             if not valid:
                 raise Exception(error)

View File

@@ -12,7 +12,7 @@ async def create_test_graph() -> graph.Graph:
     """
     ParrotBlock
               \
-               ---- TextCombinerBlock ---- PrintingBlock
+               ---- TextFormatterBlock ---- PrintingBlock
              /
     ParrotBlock
     """
@@ -20,13 +20,16 @@ async def create_test_graph() -> graph.Graph:
         graph.Node(block_id=block.ParrotBlock.id),
         graph.Node(block_id=block.ParrotBlock.id),
         graph.Node(
-            block_id=block.TextCombinerBlock.id,
-            input_default={"format": "{text1},{text2}"},
+            block_id=block.TextFormatterBlock.id,
+            input_default={
+                "format": "{texts[0]},{texts[1]},{texts[2]}",
+                "texts_$_3": "!!!",
+            },
         ),
         graph.Node(block_id=block.PrintingBlock.id),
     ]
-    nodes[0].connect(nodes[2], "output", "text1")
-    nodes[1].connect(nodes[2], "output", "text2")
+    nodes[0].connect(nodes[2], "output", "texts_$_1")
+    nodes[1].connect(nodes[2], "output", "texts_$_2")
     nodes[2].connect(nodes[3], "combined_text", "text")
     test_graph = graph.Graph(
@@ -85,14 +88,14 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
     assert exec.input_data == {"input": text}
     assert exec.node_id == test_graph.nodes[1].id
-    # Executing TextCombinerBlock
+    # Executing TextFormatterBlock
     exec = executions[2]
     assert exec.status == execution.ExecutionStatus.COMPLETED
     assert exec.graph_exec_id == graph_exec_id
-    assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!"]}
+    assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]}
     assert exec.input_data == {
-        "text1": "Hello, World!",
-        "text2": "Hello, World!",
+        "texts_$_1": "Hello, World!",
+        "texts_$_2": "Hello, World!",
     }
     assert exec.node_id == test_graph.nodes[2].id
@@ -101,7 +104,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
     assert exec.status == execution.ExecutionStatus.COMPLETED
     assert exec.graph_exec_id == graph_exec_id
     assert exec.output_data == {"status": ["printed"]}
-    assert exec.input_data == {"text": "Hello, World!,Hello, World!"}
+    assert exec.input_data == {"text": "Hello, World!,Hello, World!,!!!"}
     assert exec.node_id == test_graph.nodes[3].id