fix(agent): Make `Agent.save_state` behave like save as (#7025)
* Make `Agent.save_state` behave like "save as" - Leave previously saved state untouched - Save agent state in new folder corresponding to new `agent_id` - Copy over workspace contents to new folder * Add `copy` method to `FileStorage` --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>pull/7082/head
parent
90f3c5e2d9
commit
e866a4ba04
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
|
||||
|
@ -13,11 +14,11 @@ class AgentFileManagerMixin:
|
|||
"""Mixin that adds file manager (e.g. Agent state)
|
||||
and workspace manager (e.g. Agent output files) support."""
|
||||
|
||||
files: FileStorage = None
|
||||
files: FileStorage
|
||||
"""Agent-related files, e.g. state, logs.
|
||||
Use `workspace` to access the agent's workspace files."""
|
||||
|
||||
workspace: FileStorage = None
|
||||
workspace: FileStorage
|
||||
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
||||
Use `files` to access agent-related files, e.g. state, logs."""
|
||||
|
||||
|
@ -68,10 +69,25 @@ class AgentFileManagerMixin:
|
|||
"""Get the agent's file operation logs as list of strings."""
|
||||
return self._file_logs_cache
|
||||
|
||||
async def save_state(self) -> None:
|
||||
async def save_state(self, save_as: Optional[str] = None) -> None:
|
||||
"""Save the agent's state to the state file."""
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
||||
if save_as:
|
||||
temp_id = state.agent_id
|
||||
state.agent_id = save_as
|
||||
self._file_storage.make_dir(f"agents/{save_as}")
|
||||
# Save state
|
||||
await self._file_storage.write_file(
|
||||
f"agents/{save_as}/{self.STATE_FILE}", state.json()
|
||||
)
|
||||
# Copy workspace
|
||||
self._file_storage.copy(
|
||||
f"agents/{temp_id}/workspace",
|
||||
f"agents/{save_as}/workspace",
|
||||
)
|
||||
state.agent_id = temp_id
|
||||
else:
|
||||
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
||||
|
||||
def change_agent_id(self, new_id: str):
|
||||
"""Change the agent's ID and update the file storage accordingly."""
|
||||
|
|
|
@ -345,19 +345,13 @@ async def run_auto_gpt(
|
|||
logger.info(f"Saving state of {agent_id}...")
|
||||
|
||||
# Allow user to Save As other ID
|
||||
save_as_id = (
|
||||
clean_input(
|
||||
config,
|
||||
f"Press enter to save as '{agent_id}',"
|
||||
" or enter a different ID to save to:",
|
||||
)
|
||||
or agent_id
|
||||
save_as_id = clean_input(
|
||||
config,
|
||||
f"Press enter to save as '{agent_id}',"
|
||||
" or enter a different ID to save to:",
|
||||
)
|
||||
if save_as_id and save_as_id != agent_id:
|
||||
agent.change_agent_id(save_as_id)
|
||||
# TODO: allow many-to-one relations of agents and workspaces
|
||||
|
||||
await agent.save_state()
|
||||
# TODO: allow many-to-one relations of agents and workspaces
|
||||
await agent.save_state(save_as_id if not save_as_id.isspace() else None)
|
||||
|
||||
|
||||
@coroutine
|
||||
|
|
|
@ -127,6 +127,10 @@ class FileStorage(ABC):
|
|||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||
"""Rename a file or folder in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
|
|
|
@ -182,6 +182,21 @@ class GCSFileStorage(FileStorage):
|
|||
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
|
||||
self._bucket.rename_blob(blob, new_name=new_name)
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = self.get_path(source)
|
||||
destination = self.get_path(destination)
|
||||
# If the source is a file, copy it
|
||||
if self._bucket.blob(str(source)).exists():
|
||||
self._bucket.copy_blob(
|
||||
self._bucket.blob(str(source)), self._bucket, str(destination)
|
||||
)
|
||||
return
|
||||
# Otherwise, copy all blobs with the prefix (folder)
|
||||
for blob in self._bucket.list_blobs(prefix=f"{source}/"):
|
||||
new_name = str(blob.name).replace(str(source), str(destination), 1)
|
||||
self._bucket.copy_blob(blob, self._bucket, new_name)
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
|
||||
"""Create a new GCSFileStorage with a subroot of the current storage."""
|
||||
file_storage = GCSFileStorage(
|
||||
|
|
|
@ -115,6 +115,20 @@ class LocalFileStorage(FileStorage):
|
|||
new_path = self.get_path(new_path)
|
||||
old_path.rename(new_path)
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = self.get_path(source)
|
||||
destination = self.get_path(destination)
|
||||
if source.is_file():
|
||||
destination.write_bytes(source.read_bytes())
|
||||
else:
|
||||
destination.mkdir(exist_ok=True, parents=True)
|
||||
for file in source.rglob("*"):
|
||||
if file.is_file():
|
||||
target = destination / file.relative_to(source)
|
||||
target.parent.mkdir(exist_ok=True, parents=True)
|
||||
target.write_bytes(file.read_bytes())
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
|
||||
"""Create a new LocalFileStorage with a subroot of the current storage."""
|
||||
return LocalFileStorage(
|
||||
|
|
|
@ -222,6 +222,35 @@ class S3FileStorage(FileStorage):
|
|||
else:
|
||||
raise # Re-raise for any other client errors
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = str(self.get_path(source))
|
||||
destination = str(self.get_path(destination))
|
||||
|
||||
try:
|
||||
# If source is a file, copy it
|
||||
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": source},
|
||||
Bucket=self._bucket_name,
|
||||
Key=destination,
|
||||
)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
# If the object does not exist,
|
||||
# it may be a folder
|
||||
prefix = f"{source.rstrip('/')}/"
|
||||
objs = list(self._bucket.objects.filter(Prefix=prefix))
|
||||
for obj in objs:
|
||||
new_key = destination + obj.key[len(source) :]
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
|
||||
Bucket=self._bucket_name,
|
||||
Key=new_key,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
|
||||
"""Create a new S3FileStorage with a subroot of the current storage."""
|
||||
file_storage = S3FileStorage(
|
||||
|
|
|
@ -177,3 +177,24 @@ def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
|
|||
assert cloned._bucket.name == gcs_storage_with_files._bucket.name
|
||||
assert cloned.exists("dir")
|
||||
assert cloned.exists("dir/test_file_4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_file(storage: GCSFileStorage):
|
||||
await storage.write_file("test_file.txt", "test content")
|
||||
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||
storage.make_dir("dir")
|
||||
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_dir(storage: GCSFileStorage):
|
||||
storage.make_dir("dir")
|
||||
storage.make_dir("dir/sub_dir")
|
||||
await storage.write_file("dir/test_file.txt", "test content")
|
||||
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||
storage.copy("dir", "dir_copy")
|
||||
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
|
|
|
@ -188,3 +188,24 @@ def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
|
|||
def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_path(inaccessible_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_file(storage: LocalFileStorage):
|
||||
await storage.write_file("test_file.txt", "test content")
|
||||
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||
storage.make_dir("dir")
|
||||
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_dir(storage: LocalFileStorage):
|
||||
storage.make_dir("dir")
|
||||
storage.make_dir("dir/sub_dir")
|
||||
await storage.write_file("dir/test_file.txt", "test content")
|
||||
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||
storage.copy("dir", "dir_copy")
|
||||
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
|
|
|
@ -172,3 +172,24 @@ def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
|
|||
assert cloned._bucket.name == s3_storage_with_files._bucket.name
|
||||
assert cloned.exists("dir")
|
||||
assert cloned.exists("dir/test_file_4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_file(storage: S3FileStorage):
|
||||
await storage.write_file("test_file.txt", "test content")
|
||||
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||
storage.make_dir("dir")
|
||||
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_dir(storage: S3FileStorage):
|
||||
storage.make_dir("dir")
|
||||
storage.make_dir("dir/sub_dir")
|
||||
await storage.write_file("dir/test_file.txt", "test content")
|
||||
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||
storage.copy("dir", "dir_copy")
|
||||
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
|
|
Loading…
Reference in New Issue