diff --git a/rnd/autogpt_server/autogpt_server/data/block.py b/rnd/autogpt_server/autogpt_server/data/block.py index d0f69fdc0..61354d0a0 100644 --- a/rnd/autogpt_server/autogpt_server/data/block.py +++ b/rnd/autogpt_server/autogpt_server/data/block.py @@ -94,6 +94,9 @@ class BlockSchema(BaseModel): def get_fields(self) -> set[str]: return set(self.jsonschema["properties"].keys()) + def get_required_fields(self) -> set[str]: + return set(self.jsonschema["required"]) + BlockOutput = Generator[tuple[str, Any], None, None] @@ -190,12 +193,15 @@ class ParrotBlock(Block): yield "output", input_data["input"] -class TextCombinerBlock(Block): +class TextFormatterBlock(Block): id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore { - "text1": "string", - "text2": "string", + "texts": { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + }, "format": "string", } ) @@ -206,10 +212,7 @@ class TextCombinerBlock(Block): ) def run(self, input_data: BlockData) -> BlockOutput: - yield "combined_text", input_data["format"].format( - text1=input_data["text1"], - text2=input_data["text2"], - ) + yield "combined_text", input_data["format"].format(texts=input_data["texts"]) class PrintingBlock(Block): diff --git a/rnd/autogpt_server/autogpt_server/data/execution.py b/rnd/autogpt_server/autogpt_server/data/execution.py index 2fb6a7a03..025267312 100644 --- a/rnd/autogpt_server/autogpt_server/data/execution.py +++ b/rnd/autogpt_server/autogpt_server/data/execution.py @@ -236,4 +236,27 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]: exec_input = json.loads(execution.AgentNode.constantInput) for input_data in execution.Input or []: exec_input[input_data.name] = json.loads(input_data.data) - return exec_input + + return merge_execution_input(exec_input) + + +SPLIT = "_$_" + + +def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]: + list_input = [] + for key, value in data.items(): + if SPLIT not in key: + continue + + name, index = key.split(SPLIT) + if not index.isdigit(): + list_input.append((name, value, 0)) + else: + list_input.append((name, value, int(index))) + + for name, value, _ in sorted(list_input, key=lambda x: x[2]): + data[name] = data.get(name, []) + data[name].append(value) + + return data diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py index 7ebb37569..2d323ee3d 100644 --- a/rnd/autogpt_server/autogpt_server/executor/manager.py +++ b/rnd/autogpt_server/autogpt_server/executor/manager.py @@ -6,8 +6,9 @@ from typing import Any, Coroutine, Generator, TypeVar from autogpt_server.data import db from autogpt_server.data.block import Block, get_block from autogpt_server.data.execution import ( - get_node_execution_input, create_graph_execution, + get_node_execution_input, + merge_execution_input, update_execution_status as execution_update, upsert_execution_output, upsert_execution_input, @@ -77,9 +78,10 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionS ): yield execution except Exception as e: - logger.exception(f"{prefix} failed with error: %s", e) + error_msg = f"{e.__class__.__name__}: {e}" + logger.exception(f"{prefix} failed with error. `%s`", error_msg) wait(execution_update(node_exec_id, ExecutionStatus.FAILED)) - wait(upsert_execution_output(node_exec_id, "error", str(e))) + wait(upsert_execution_output(node_exec_id, "error", error_msg)) raise e @@ -151,13 +153,17 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]: A tuple of a boolean indicating if the data is valid, and a message if not. Return the executed block name if the data is valid. """ - node_block: Block | None = await(get_block(node.block_id)) + node_block: Block | None = await get_block(node.block_id) if not node_block: return False, f"Block for {node.block_id} not found." - input_fields = node_block.input_schema.get_fields() - if not input_fields.issubset(data): - return False, f"Input data missing: {input_fields - set(data)}" + input_fields_from_schema = node_block.input_schema.get_required_fields() + if not input_fields_from_schema.issubset(data): + return False, f"Input data missing: {input_fields_from_schema - set(data)}" + + input_fields_from_nodes = {name for name, _ in node.input_nodes} + if not input_fields_from_nodes.issubset(data): + return False, f"Input data missing: {input_fields_from_nodes - set(data)}" if error := node_block.input_schema.validate_data(data): logger.error("Input value doesn't match schema: %s", error) @@ -214,7 +220,7 @@ class ExecutionManager(AppService): # Currently, there is no constraint on the number of root nodes in the graph. for node in graph.starting_nodes: - input_data = {**node.input_default, **data} + input_data = merge_execution_input({**node.input_default, **data}) valid, error = self.run_and_wait(validate_exec(node, input_data)) if not valid: raise Exception(error) diff --git a/rnd/autogpt_server/test/executor/test_manager.py b/rnd/autogpt_server/test/executor/test_manager.py index c521c5d98..54d3c3d16 100644 --- a/rnd/autogpt_server/test/executor/test_manager.py +++ b/rnd/autogpt_server/test/executor/test_manager.py @@ -12,7 +12,7 @@ async def create_test_graph() -> graph.Graph: """ ParrotBlock \ - ---- TextCombinerBlock ---- PrintingBlock + ---- TextFormatterBlock ---- PrintingBlock / ParrotBlock """ @@ -20,13 +20,16 @@ async def create_test_graph() -> graph.Graph: graph.Node(block_id=block.ParrotBlock.id), graph.Node(block_id=block.ParrotBlock.id), graph.Node( - block_id=block.TextCombinerBlock.id, - input_default={"format": "{text1},{text2}"}, + block_id=block.TextFormatterBlock.id, + input_default={ + "format": "{texts[0]},{texts[1]},{texts[2]}", + "texts_$_3": "!!!", + }, ), graph.Node(block_id=block.PrintingBlock.id), ] - nodes[0].connect(nodes[2], "output", "text1") - nodes[1].connect(nodes[2], "output", "text2") + nodes[0].connect(nodes[2], "output", "texts_$_1") + nodes[1].connect(nodes[2], "output", "texts_$_2") nodes[2].connect(nodes[3], "combined_text", "text") test_graph = graph.Graph( @@ -85,14 +88,14 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) assert exec.input_data == {"input": text} assert exec.node_id == test_graph.nodes[1].id - # Executing TextCombinerBlock + # Executing TextFormatterBlock exec = executions[2] assert exec.status == execution.ExecutionStatus.COMPLETED assert exec.graph_exec_id == graph_exec_id - assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!"]} + assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]} assert exec.input_data == { - "text1": "Hello, World!", - "text2": "Hello, World!", + "texts_$_1": "Hello, World!", + "texts_$_2": "Hello, World!", } assert exec.node_id == test_graph.nodes[2].id @@ -101,7 +104,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) assert exec.status == execution.ExecutionStatus.COMPLETED assert exec.graph_exec_id == graph_exec_id assert exec.output_data == {"status": ["printed"]} - assert exec.input_data == {"text": "Hello, World!,Hello, World!"} + assert exec.input_data == {"text": "Hello, World!,Hello, World!,!!!"} assert exec.node_id == test_graph.nodes[3].id