fix(backend): Re-work the connection input consumption logic for Agent Executor Block (#8710)
parent
6954f4eb0e
commit
8b4bb27077
|
@ -94,15 +94,7 @@ class BlockSchema(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""
|
||||
Validate the data against the schema.
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(data, cls.jsonschema())
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
return str(e)
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
|
||||
@classmethod
|
||||
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||
|
|
|
@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.execution import (
|
||||
|
@ -135,7 +136,6 @@ def execute_node(
|
|||
logger.error(f"Block {node.block_id} not found.")
|
||||
return
|
||||
|
||||
# Sanity check: validate the execution input.
|
||||
log_metadata = LogMetadata(
|
||||
user_id=user_id,
|
||||
graph_eid=graph_exec_id,
|
||||
|
@ -144,11 +144,20 @@ def execute_node(
|
|||
node_id=node_id,
|
||||
block_name=node_block.name,
|
||||
)
|
||||
|
||||
# Sanity check: validate the execution input.
|
||||
input_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
if input_data is None:
|
||||
log_metadata.error(f"Skip execution, input validation error: {error}")
|
||||
db_client.upsert_execution_output(node_exec_id, "error", error)
|
||||
update_execution(ExecutionStatus.FAILED)
|
||||
return
|
||||
|
||||
# Re-shape the input data for agent block.
|
||||
# AgentExecutorBlock specially separate the node input_data & its input_default.
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
input_data = {**node.input_default, "data": input_data}
|
||||
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
input_size = len(input_data_str)
|
||||
|
@ -376,31 +385,46 @@ def validate_exec(
|
|||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
|
||||
error_prefix = f"Input data missing for {node_block.name}:"
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
# Validate the execution metadata for the agent executor block.
|
||||
try:
|
||||
exec_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
except Exception as e:
|
||||
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
|
||||
|
||||
# Validation input
|
||||
input_schema = exec_data.input_schema
|
||||
required_fields = set(input_schema["required"])
|
||||
input_default = exec_data.data
|
||||
else:
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in node_block.input_schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Validation input
|
||||
input_schema = node_block.input_schema.jsonschema()
|
||||
required_fields = node_block.input_schema.get_required_fields()
|
||||
input_default = node.input_default
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
input_fields_from_nodes = {link.sink_name for link in node.input_links}
|
||||
if not input_fields_from_nodes.issubset(data):
|
||||
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
data = {**node.input_default, **data}
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
input_fields_from_schema = node_block.input_schema.get_required_fields()
|
||||
if not input_fields_from_schema.issubset(data):
|
||||
return None, f"{error_prefix} {input_fields_from_schema - set(data)}"
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in node_block.input_schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
if not required_fields.issubset(data):
|
||||
return None, f"{error_prefix} {required_fields - set(data)}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := node_block.input_schema.validate_data(data):
|
||||
error_message = f"Input data doesn't match {node_block.name}: {error}"
|
||||
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
|
||||
import jsonschema
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from .type import type_match
|
||||
|
@ -30,3 +31,17 @@ def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any
|
|||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
return parsed
|
||||
|
||||
|
||||
def validate_with_jsonschema(
|
||||
schema: dict[str, Any], data: dict[str, Any]
|
||||
) -> str | None:
|
||||
"""
|
||||
Validate the data against the schema.
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(data, schema)
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
return str(e)
|
||||
|
|
Loading…
Reference in New Issue