feat(rnd): Add support for dynamic input as list for AgentServer Block (#7268)
On AgentServer, To create a Block like StringFormatterBlock or LllmCallBlock, we need some way to dynamically link input pins and aggregate them into a single list input. This will give a better experience for the user to construct an input and link it from the output of the other nodes. The scope of this change is adding support for that in the least intrusive way. Proposal To differentiate the input list name and its singular entry we are using the $_<index> prefix. For example: For the input items: list[int], you can set a pin items with values like [1,2,3,4]. But you can also add input pins like items_$_0 or items_$_4 with values 1 or 2, which will be appended to the items input in alphabetical order. The execution engine will guarantee to wait for the execution until all the input pin value is produced, so input pin with list input will produce fix-sized list.pull/7286/head
parent
cdc658695f
commit
2bc22c5450
|
@ -94,6 +94,9 @@ class BlockSchema(BaseModel):
|
|||
def get_fields(self) -> set[str]:
|
||||
return set(self.jsonschema["properties"].keys())
|
||||
|
||||
def get_required_fields(self) -> set[str]:
|
||||
return set(self.jsonschema["required"])
|
||||
|
||||
|
||||
BlockOutput = Generator[tuple[str, Any], None, None]
|
||||
|
||||
|
@ -190,12 +193,15 @@ class ParrotBlock(Block):
|
|||
yield "output", input_data["input"]
|
||||
|
||||
|
||||
class TextCombinerBlock(Block):
|
||||
class TextFormatterBlock(Block):
|
||||
id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"text1": "string",
|
||||
"text2": "string",
|
||||
"texts": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
},
|
||||
"format": "string",
|
||||
}
|
||||
)
|
||||
|
@ -206,10 +212,7 @@ class TextCombinerBlock(Block):
|
|||
)
|
||||
|
||||
def run(self, input_data: BlockData) -> BlockOutput:
|
||||
yield "combined_text", input_data["format"].format(
|
||||
text1=input_data["text1"],
|
||||
text2=input_data["text2"],
|
||||
)
|
||||
yield "combined_text", input_data["format"].format(texts=input_data["texts"])
|
||||
|
||||
|
||||
class PrintingBlock(Block):
|
||||
|
|
|
@ -236,4 +236,27 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
|
|||
exec_input = json.loads(execution.AgentNode.constantInput)
|
||||
for input_data in execution.Input or []:
|
||||
exec_input[input_data.name] = json.loads(input_data.data)
|
||||
return exec_input
|
||||
|
||||
return merge_execution_input(exec_input)
|
||||
|
||||
|
||||
SPLIT = "_$_"
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
list_input = []
|
||||
for key, value in data.items():
|
||||
if SPLIT not in key:
|
||||
continue
|
||||
|
||||
name, index = key.split(SPLIT)
|
||||
if not index.isdigit():
|
||||
list_input.append((name, value, 0))
|
||||
else:
|
||||
list_input.append((name, value, int(index)))
|
||||
|
||||
for name, value, _ in sorted(list_input, key=lambda x: x[2]):
|
||||
data[name] = data.get(name, [])
|
||||
data[name].append(value)
|
||||
|
||||
return data
|
||||
|
|
|
@ -6,8 +6,9 @@ from typing import Any, Coroutine, Generator, TypeVar
|
|||
from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, get_block
|
||||
from autogpt_server.data.execution import (
|
||||
get_node_execution_input,
|
||||
create_graph_execution,
|
||||
get_node_execution_input,
|
||||
merge_execution_input,
|
||||
update_execution_status as execution_update,
|
||||
upsert_execution_output,
|
||||
upsert_execution_input,
|
||||
|
@ -77,9 +78,10 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS
|
|||
):
|
||||
yield execution
|
||||
except Exception as e:
|
||||
logger.exception(f"{prefix} failed with error: %s", e)
|
||||
error_msg = f"{e.__class__.__name__}: {e}"
|
||||
logger.exception(f"{prefix} failed with error. `%s`", error_msg)
|
||||
wait(execution_update(node_exec_id, ExecutionStatus.FAILED))
|
||||
wait(upsert_execution_output(node_exec_id, "error", str(e)))
|
||||
wait(upsert_execution_output(node_exec_id, "error", error_msg))
|
||||
raise e
|
||||
|
||||
|
||||
|
@ -151,13 +153,17 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
|
|||
A tuple of a boolean indicating if the data is valid, and a message if not.
|
||||
Return the executed block name if the data is valid.
|
||||
"""
|
||||
node_block: Block | None = await(get_block(node.block_id))
|
||||
node_block: Block | None = await get_block(node.block_id)
|
||||
if not node_block:
|
||||
return False, f"Block for {node.block_id} not found."
|
||||
|
||||
input_fields = node_block.input_schema.get_fields()
|
||||
if not input_fields.issubset(data):
|
||||
return False, f"Input data missing: {input_fields - set(data)}"
|
||||
input_fields_from_schema = node_block.input_schema.get_required_fields()
|
||||
if not input_fields_from_schema.issubset(data):
|
||||
return False, f"Input data missing: {input_fields_from_schema - set(data)}"
|
||||
|
||||
input_fields_from_nodes = {name for name, _ in node.input_nodes}
|
||||
if not input_fields_from_nodes.issubset(data):
|
||||
return False, f"Input data missing: {input_fields_from_nodes - set(data)}"
|
||||
|
||||
if error := node_block.input_schema.validate_data(data):
|
||||
logger.error("Input value doesn't match schema: %s", error)
|
||||
|
@ -214,7 +220,7 @@ class ExecutionManager(AppService):
|
|||
|
||||
# Currently, there is no constraint on the number of root nodes in the graph.
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {**node.input_default, **data}
|
||||
input_data = merge_execution_input({**node.input_default, **data})
|
||||
valid, error = self.run_and_wait(validate_exec(node, input_data))
|
||||
if not valid:
|
||||
raise Exception(error)
|
||||
|
|
|
@ -12,7 +12,7 @@ async def create_test_graph() -> graph.Graph:
|
|||
"""
|
||||
ParrotBlock
|
||||
\
|
||||
---- TextCombinerBlock ---- PrintingBlock
|
||||
---- TextFormatterBlock ---- PrintingBlock
|
||||
/
|
||||
ParrotBlock
|
||||
"""
|
||||
|
@ -20,13 +20,16 @@ async def create_test_graph() -> graph.Graph:
|
|||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(
|
||||
block_id=block.TextCombinerBlock.id,
|
||||
input_default={"format": "{text1},{text2}"},
|
||||
block_id=block.TextFormatterBlock.id,
|
||||
input_default={
|
||||
"format": "{texts[0]},{texts[1]},{texts[2]}",
|
||||
"texts_$_3": "!!!",
|
||||
},
|
||||
),
|
||||
graph.Node(block_id=block.PrintingBlock.id),
|
||||
]
|
||||
nodes[0].connect(nodes[2], "output", "text1")
|
||||
nodes[1].connect(nodes[2], "output", "text2")
|
||||
nodes[0].connect(nodes[2], "output", "texts_$_1")
|
||||
nodes[1].connect(nodes[2], "output", "texts_$_2")
|
||||
nodes[2].connect(nodes[3], "combined_text", "text")
|
||||
|
||||
test_graph = graph.Graph(
|
||||
|
@ -85,14 +88,14 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
|
|||
assert exec.input_data == {"input": text}
|
||||
assert exec.node_id == test_graph.nodes[1].id
|
||||
|
||||
# Executing TextCombinerBlock
|
||||
# Executing TextFormatterBlock
|
||||
exec = executions[2]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!"]}
|
||||
assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.input_data == {
|
||||
"text1": "Hello, World!",
|
||||
"text2": "Hello, World!",
|
||||
"texts_$_1": "Hello, World!",
|
||||
"texts_$_2": "Hello, World!",
|
||||
}
|
||||
assert exec.node_id == test_graph.nodes[2].id
|
||||
|
||||
|
@ -101,7 +104,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
|
|||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"status": ["printed"]}
|
||||
assert exec.input_data == {"text": "Hello, World!,Hello, World!"}
|
||||
assert exec.input_data == {"text": "Hello, World!,Hello, World!,!!!"}
|
||||
assert exec.node_id == test_graph.nodes[3].id
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue