fix ordering & propagation of node executions and I/O in `get_graph_execution`
parent
f9462bcf2a
commit
f547513207
|
@ -586,7 +586,17 @@ async def get_execution_meta(
|
|||
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "userId": user_id},
|
||||
include={"AgentNodeExecutions": {"include": {"AgentNode": True}}},
|
||||
include={
|
||||
"AgentNodeExecutions": {
|
||||
"include": {"AgentNode": True, "Input": True, "Output": True},
|
||||
"order_by": [
|
||||
{"queuedTime": "asc"},
|
||||
{ # Fallback: Incomplete execs has no queuedTime.
|
||||
"addedTime": "asc"
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
|
|
@ -160,10 +160,10 @@ class AgentServer(backend.util.service.AppProcess):
|
|||
return execution.status
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_node_execution_results(
|
||||
async def test_get_graph_run_results(
|
||||
graph_id: str, graph_exec_id: str, user_id: str
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph_run_node_execution_results(
|
||||
return await backend.server.routers.v1.get_graph_execution(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
|
|
|
@ -73,9 +73,10 @@ async def wait_execution(
|
|||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().test_get_graph_run_node_execution_results(
|
||||
graph_exec = await AgentServer().test_get_graph_run_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return graph_exec.node_executions
|
||||
time.sleep(1)
|
||||
|
||||
assert False, "Execution did not complete in time."
|
||||
|
|
|
@ -55,7 +55,7 @@ async def assert_sample_graph_executions(
|
|||
graph_exec_id: str,
|
||||
):
|
||||
logger.info(f"Checking execution results for graph {test_graph.id}")
|
||||
executions = await agent_server.test_get_graph_run_node_execution_results(
|
||||
graph_run = await agent_server.test_get_graph_run_results(
|
||||
test_graph.id,
|
||||
graph_exec_id,
|
||||
test_user.id,
|
||||
|
@ -74,7 +74,7 @@ async def assert_sample_graph_executions(
|
|||
]
|
||||
|
||||
# Executing StoreValueBlock
|
||||
exec = executions[0]
|
||||
exec = graph_run.node_executions[0]
|
||||
logger.info(f"Checking first StoreValueBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
|
@ -87,7 +87,7 @@ async def assert_sample_graph_executions(
|
|||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing StoreValueBlock
|
||||
exec = executions[1]
|
||||
exec = graph_run.node_executions[1]
|
||||
logger.info(f"Checking second StoreValueBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
|
@ -100,7 +100,7 @@ async def assert_sample_graph_executions(
|
|||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing FillTextTemplateBlock
|
||||
exec = executions[2]
|
||||
exec = graph_run.node_executions[2]
|
||||
logger.info(f"Checking FillTextTemplateBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
|
@ -115,7 +115,7 @@ async def assert_sample_graph_executions(
|
|||
assert exec.node_id == test_graph.nodes[2].id
|
||||
|
||||
# Executing PrintToConsoleBlock
|
||||
exec = executions[3]
|
||||
exec = graph_run.node_executions[3]
|
||||
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
|
@ -198,14 +198,14 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
|||
)
|
||||
|
||||
logger.info("Checking execution results")
|
||||
executions = await server.agent_server.test_get_graph_run_node_execution_results(
|
||||
graph_exec = await server.agent_server.test_get_graph_run_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
assert len(graph_exec.node_executions) == 3
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
assert executions[2].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[2].output_data == {"output": ["value2"]}
|
||||
assert graph_exec.node_executions[2].status == execution.ExecutionStatus.COMPLETED
|
||||
assert graph_exec.node_executions[2].output_data == {"output": ["value2"]}
|
||||
logger.info("Completed test_input_pin_always_waited")
|
||||
|
||||
|
||||
|
@ -281,12 +281,12 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
|||
server.agent_server, test_graph, test_user, {}, 8
|
||||
)
|
||||
logger.info("Checking execution results")
|
||||
executions = await server.agent_server.test_get_graph_run_node_execution_results(
|
||||
graph_exec = await server.agent_server.test_get_graph_run_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
assert len(graph_exec.node_executions) == 8
|
||||
# The last 3 executions will be a+b=4+5=9
|
||||
for i, exec_data in enumerate(executions[-3:]):
|
||||
for i, exec_data in enumerate(graph_exec.node_executions[-3:]):
|
||||
logger.info(f"Checking execution {i+1} of last 3: {exec_data}")
|
||||
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec_data.output_data == {"result": [9]}
|
||||
|
|
Loading…
Reference in New Issue