Forge: add output fields to DB StepModel (#5759)

* Forge: add output fields to `StepModel` and `AgentDB.update_step`

* Forge: fix `AgentDB.get_step` parameter types
pull/5789/head
Reinier van der Leer 2023-10-17 10:17:16 -07:00 committed by GitHub
parent cba90e20e9
commit 7f3ca0b76a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 6 deletions

View File

@ -54,6 +54,7 @@ class StepModel(Base):
name = Column(String)
input = Column(String)
status = Column(String)
output = Column(String)
is_last = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
@ -61,6 +62,7 @@ class StepModel(Base):
)
additional_input = Column(JSON)
additional_output = Column(JSON)
artifacts = relationship("ArtifactModel", back_populates="step")
@ -111,9 +113,11 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
name=step_model.name,
input=step_model.input,
status=status,
output=step_model.output,
artifacts=step_artifacts,
is_last=step_model.is_last == 1,
additional_input=step_model.additional_input,
additional_output=step_model.additional_output,
)
@ -280,7 +284,7 @@ class AgentDB:
LOG.error(f"Unexpected error while getting task: {e}")
raise
async def get_step(self, task_id: int, step_id: int) -> Step:
async def get_step(self, task_id: str, step_id: str) -> Step:
if self.debug_enabled:
LOG.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
try:
@ -311,8 +315,10 @@ class AgentDB:
self,
task_id: str,
step_id: str,
status: str,
additional_input: Optional[Dict[str, Any]] = {},
status: Optional[str] = None,
output: Optional[str] = None,
additional_input: Optional[Dict[str, Any]] = None,
additional_output: Optional[Dict[str, Any]] = None,
) -> Step:
if self.debug_enabled:
LOG.debug(f"Updating step with task_id: {task_id} and step_id: {step_id}")
@ -323,8 +329,14 @@ class AgentDB:
.filter_by(task_id=task_id, step_id=step_id)
.first()
):
step.status = status
step.additional_input = additional_input
if status is not None:
step.status = status
if additional_input is not None:
step.additional_input = additional_input
if output is not None:
step.output = output
if additional_output is not None:
step.additional_output = additional_output
session.commit()
return await self.get_step(task_id, step_id)
else:

View File

@ -156,7 +156,7 @@ class Step(StepRequestBody):
description="Output of the task step.",
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
)
additional_output: Optional[StepOutput] = {}
additional_output: Optional[StepOutput] = Field(default_factory=dict)
artifacts: Optional[List[Artifact]] = Field(
[], description="A list of artifacts that the step has produced."
)