fix ordering & propagation of node executions and I/O in `get_graph_execution`

pull/9051/head
Reinier van der Leer 2025-01-17 17:14:57 +01:00
parent f9462bcf2a
commit f547513207
No known key found for this signature in database
GPG Key ID: BEB9E26CB6F21336
4 changed files with 27 additions and 16 deletions

View File

@ -586,7 +586,17 @@ async def get_execution_meta(
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "userId": user_id},
include={"AgentNodeExecutions": {"include": {"AgentNode": True}}},
include={
"AgentNodeExecutions": {
"include": {"AgentNode": True, "Input": True, "Output": True},
"order_by": [
{"queuedTime": "asc"},
{ # Fallback: Incomplete execs has no queuedTime.
"addedTime": "asc"
},
],
},
},
)
return GraphExecution.from_db(execution) if execution else None

View File

@ -160,10 +160,10 @@ class AgentServer(backend.util.service.AppProcess):
return execution.status
@staticmethod
async def test_get_graph_run_node_execution_results(
async def test_get_graph_run_results(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_node_execution_results(
return await backend.server.routers.v1.get_graph_execution(
graph_id, graph_exec_id, user_id
)

View File

@ -73,9 +73,10 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().test_get_graph_run_node_execution_results(
graph_exec = await AgentServer().test_get_graph_run_results(
graph_id, graph_exec_id, user_id
)
return graph_exec.node_executions
time.sleep(1)
assert False, "Execution did not complete in time."

View File

@ -55,7 +55,7 @@ async def assert_sample_graph_executions(
graph_exec_id: str,
):
logger.info(f"Checking execution results for graph {test_graph.id}")
executions = await agent_server.test_get_graph_run_node_execution_results(
graph_run = await agent_server.test_get_graph_run_results(
test_graph.id,
graph_exec_id,
test_user.id,
@ -74,7 +74,7 @@ async def assert_sample_graph_executions(
]
# Executing StoreValueBlock
exec = executions[0]
exec = graph_run.node_executions[0]
logger.info(f"Checking first StoreValueBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
@ -87,7 +87,7 @@ async def assert_sample_graph_executions(
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing StoreValueBlock
exec = executions[1]
exec = graph_run.node_executions[1]
logger.info(f"Checking second StoreValueBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
@ -100,7 +100,7 @@ async def assert_sample_graph_executions(
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing FillTextTemplateBlock
exec = executions[2]
exec = graph_run.node_executions[2]
logger.info(f"Checking FillTextTemplateBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
@ -115,7 +115,7 @@ async def assert_sample_graph_executions(
assert exec.node_id == test_graph.nodes[2].id
# Executing PrintToConsoleBlock
exec = executions[3]
exec = graph_run.node_executions[3]
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
@ -198,14 +198,14 @@ async def test_input_pin_always_waited(server: SpinTestServer):
)
logger.info("Checking execution results")
executions = await server.agent_server.test_get_graph_run_node_execution_results(
graph_exec = await server.agent_server.test_get_graph_run_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
assert len(graph_exec.node_executions) == 3
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[2].status == execution.ExecutionStatus.COMPLETED
assert executions[2].output_data == {"output": ["value2"]}
assert graph_exec.node_executions[2].status == execution.ExecutionStatus.COMPLETED
assert graph_exec.node_executions[2].output_data == {"output": ["value2"]}
logger.info("Completed test_input_pin_always_waited")
@ -281,12 +281,12 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
server.agent_server, test_graph, test_user, {}, 8
)
logger.info("Checking execution results")
executions = await server.agent_server.test_get_graph_run_node_execution_results(
graph_exec = await server.agent_server.test_get_graph_run_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8
assert len(graph_exec.node_executions) == 8
# The last 3 executions will be a+b=4+5=9
for i, exec_data in enumerate(executions[-3:]):
for i, exec_data in enumerate(graph_exec.node_executions[-3:]):
logger.info(f"Checking execution {i+1} of last 3: {exec_data}")
assert exec_data.status == execution.ExecutionStatus.COMPLETED
assert exec_data.output_data == {"result": [9]}