fix(backend): Re-work the connection input consumption logic for Agent Executor Block (#8710)

Zamil Majdy 2024-11-21 18:05:41 +07:00 committed by GitHub
parent 6954f4eb0e
commit 8b4bb27077
3 changed files with 53 additions and 22 deletions


@@ -94,15 +94,7 @@ class BlockSchema(BaseModel):
     @classmethod
     def validate_data(cls, data: BlockInput) -> str | None:
-        """
-        Validate the data against the schema.
-        Returns the validation error message if the data does not match the schema.
-        """
-        try:
-            jsonschema.validate(data, cls.jsonschema())
-            return None
-        except jsonschema.ValidationError as e:
-            return str(e)
+        return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)

     @classmethod
     def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
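The refactor above is behavior-preserving: both the old inline try/except and the new json.validate_with_jsonschema helper (added at the bottom of this commit) return None when the data matches the schema, and the stringified jsonschema.ValidationError otherwise. A minimal standalone sketch of that contract, using a hypothetical schema:

    import jsonschema

    # Hypothetical schema for illustration only.
    schema = {
        "type": "object",
        "required": ["name"],
        "properties": {"name": {"type": "string"}},
    }

    def validate(data) -> str | None:
        # Mirrors the old inline logic and the new helper's contract.
        try:
            jsonschema.validate(data, schema)
            return None
        except jsonschema.ValidationError as e:
            return str(e)

    assert validate({"name": "x"}) is None
    assert "required" in validate({})  # "'name' is a required property"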


@@ -18,6 +18,7 @@ if TYPE_CHECKING:
 from autogpt_libs.utils.cache import thread_cached
+from backend.blocks.agent import AgentExecutorBlock
 from backend.data import redis
 from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
 from backend.data.execution import (
@@ -135,7 +136,6 @@ def execute_node(
         logger.error(f"Block {node.block_id} not found.")
         return

-    # Sanity check: validate the execution input.
     log_metadata = LogMetadata(
         user_id=user_id,
         graph_eid=graph_exec_id,
@@ -144,11 +144,20 @@ def execute_node(
         node_id=node_id,
         block_name=node_block.name,
     )
+
+    # Sanity check: validate the execution input.
     input_data, error = validate_exec(node, data.data, resolve_input=False)
     if input_data is None:
         log_metadata.error(f"Skip execution, input validation error: {error}")
         db_client.upsert_execution_output(node_exec_id, "error", error)
         update_execution(ExecutionStatus.FAILED)
         return

+    # Re-shape the input data for the agent block:
+    # AgentExecutorBlock keeps the node's input_data separate from its input_default.
+    if isinstance(node_block, AgentExecutorBlock):
+        input_data = {**node.input_default, "data": input_data}
+
     # Execute the node
     input_data_str = json.dumps(input_data)
     input_size = len(input_data_str)
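
The new re-shape step wraps the connection-supplied inputs under a single "data" key while keeping the node's stored defaults at the top level, which is the separated form AgentExecutorBlock.Input expects. A hypothetical illustration of the resulting dict (the graph_id/graph_version keys are placeholders, not taken from this diff):

    # Hypothetical node defaults for an AgentExecutorBlock node; the real keys
    # come from node.input_default and AgentExecutorBlock.Input.
    input_default = {"graph_id": "g-123", "graph_version": 1, "data": {}}
    # Values consumed from incoming connections, already validated above.
    input_data = {"topic": "cats"}

    input_data = {**input_default, "data": input_data}
    # {'graph_id': 'g-123', 'graph_version': 1, 'data': {'topic': 'cats'}}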
@ -376,31 +385,46 @@ def validate_exec(
if not node_block:
return None, f"Block for {node.block_id} not found."
error_prefix = f"Input data missing for {node_block.name}:"
if isinstance(node_block, AgentExecutorBlock):
# Validate the execution metadata for the agent executor block.
try:
exec_data = AgentExecutorBlock.Input(**node.input_default)
except Exception as e:
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
# Validation input
input_schema = exec_data.input_schema
required_fields = set(input_schema["required"])
input_default = exec_data.data
else:
# Convert non-matching data types to the expected input schema.
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Validation input
input_schema = node_block.input_schema.jsonschema()
required_fields = node_block.input_schema.get_required_fields()
input_default = node.input_default
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
data = {**node.input_default, **data}
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
input_fields_from_schema = node_block.input_schema.get_required_fields()
if not input_fields_from_schema.issubset(data):
return None, f"{error_prefix} {input_fields_from_schema - set(data)}"
# Convert non-matching data types to the expected input schema.
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
if not required_fields.issubset(data):
return None, f"{error_prefix} {required_fields - set(data)}"
# Last validation: Validate the input values against the schema.
if error := node_block.input_schema.validate_data(data):
error_message = f"Input data doesn't match {node_block.name}: {error}"
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
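
Both branches feed the same downstream checks; what changes is where the schema, required fields, and defaults come from: the sub-agent's stored interface for AgentExecutorBlock, the block's own Pydantic model otherwise. A small worked example of the shared required-fields check, on hypothetical values:

    required_fields = {"prompt", "style"}
    data = {"prompt": "hello"}  # merged input after applying defaults

    if not required_fields.issubset(data):  # a dict iterates over its keys
        missing = required_fields - set(data)
        print(f"Input data missing or mismatch for `AgentExecutorBlock`: {missing}")
        # -> Input data missing or mismatch for `AgentExecutorBlock`: {'style'}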


@@ -1,6 +1,7 @@
 import json
 from typing import Any, Type, TypeVar, overload

+import jsonschema
 from fastapi.encoders import jsonable_encoder

 from .type import type_match
@@ -30,3 +31,17 @@ def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
     if target_type:
         return type_match(parsed, target_type)
     return parsed
+
+
+def validate_with_jsonschema(
+    schema: dict[str, Any], data: dict[str, Any]
+) -> str | None:
+    """
+    Validate the data against the schema.
+    Returns the validation error message if the data does not match the schema.
+    """
+    try:
+        jsonschema.validate(data, schema)
+        return None
+    except jsonschema.ValidationError as e:
+        return str(e)
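
A usage sketch of the new helper, with a hypothetical schema and payloads (assuming the module above is importable); the contract is None on success, otherwise the jsonschema error message:

    schema = {
        "type": "object",
        "required": ["name"],
        "properties": {"name": {"type": "string"}},
    }

    assert validate_with_jsonschema(schema=schema, data={"name": "Agent"}) is None
    print(validate_with_jsonschema(schema=schema, data={}))
    # -> 'name' is a required property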