fix(agent): Make `Agent.save_state` behave like save as (#7025)

* Make `Agent.save_state` behave like "save as"
   - Leave previously saved state untouched
   - Save agent state in new folder corresponding to new `agent_id`
   - Copy over workspace contents to new folder

* Add `copy` method to `FileStorage`

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
pull/7082/head
Krzysztof Czerwinski 2024-04-12 12:41:02 +02:00 committed by GitHub
parent 90f3c5e2d9
commit e866a4ba04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 151 additions and 16 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import Optional

from autogpt.file_storage.base import FileStorage
@@ -13,11 +14,11 @@ class AgentFileManagerMixin:
     """Mixin that adds file manager (e.g. Agent state)
     and workspace manager (e.g. Agent output files) support."""

-    files: FileStorage = None
+    files: FileStorage
     """Agent-related files, e.g. state, logs.
     Use `workspace` to access the agent's workspace files."""
-    workspace: FileStorage = None
+    workspace: FileStorage
     """Workspace that the agent has access to, e.g. for reading/writing files.
     Use `files` to access agent-related files, e.g. state, logs."""
@@ -68,9 +69,24 @@ class AgentFileManagerMixin:
        """Get the agent's file operation logs as list of strings."""
        return self._file_logs_cache
async def save_state(self, save_as: Optional[str] = None) -> None:
    """Save the agent's state to the state file.

    If `save_as` is given, behaves like "save as": the previously saved
    state is left untouched, the state is written under the new agent ID,
    and the workspace contents are copied over to the new agent's folder.

    Args:
        save_as: Optional new agent ID to save the state under.
    """
    state: BaseAgentSettings = getattr(self, "state")
    if save_as:
        temp_id = state.agent_id
        state.agent_id = save_as
        self._file_storage.make_dir(f"agents/{save_as}")
        # Save state under the new agent ID
        await self._file_storage.write_file(
            f"agents/{save_as}/{self.STATE_FILE}", state.json()
        )
        # Copy workspace contents over to the new agent's folder
        self._file_storage.copy(
            f"agents/{temp_id}/workspace",
            f"agents/{save_as}/workspace",
        )
        # Restore the in-memory agent ID so the live agent is unaffected
        state.agent_id = temp_id
    else:
        await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
def change_agent_id(self, new_id: str):

View File

@@ -345,19 +345,13 @@ async def run_auto_gpt(
        logger.info(f"Saving state of {agent_id}...")

        # Allow user to Save As other ID
-       save_as_id = (
-           clean_input(
-               config,
-               f"Press enter to save as '{agent_id}',"
-               " or enter a different ID to save to:",
-           )
-           or agent_id
-       )
-       if save_as_id and save_as_id != agent_id:
-           agent.change_agent_id(save_as_id)
+       save_as_id = clean_input(
+           config,
+           f"Press enter to save as '{agent_id}',"
+           " or enter a different ID to save to:",
+       )
        # TODO: allow many-to-one relations of agents and workspaces
-       await agent.save_state()
+       await agent.save_state(save_as_id if not save_as_id.isspace() else None)

@coroutine

View File

@@ -127,6 +127,10 @@ class FileStorage(ABC):
    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
@abstractmethod
def copy(self, source: str | Path, destination: str | Path) -> None:
    """Copy a file or folder with all contents in the storage.

    Args:
        source: Path of the file or folder to copy.
        destination: Path to copy to.
    """
    @abstractmethod
    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if doesn't exist."""

View File

@@ -182,6 +182,21 @@ class GCSFileStorage(FileStorage):
            new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
            self._bucket.rename_blob(blob, new_name=new_name)
def copy(self, source: str | Path, destination: str | Path) -> None:
    """Copy a file or folder with all contents in the storage.

    Args:
        source: Path of the file or folder to copy.
        destination: Path to copy to.
    """
    src = self.get_path(source)
    dst = self.get_path(destination)
    src_blob = self._bucket.blob(str(src))
    # A direct blob hit means `source` is a single file: copy it and stop.
    if src_blob.exists():
        self._bucket.copy_blob(src_blob, self._bucket, str(dst))
        return
    # Otherwise treat `source` as a folder: copy every blob under its prefix,
    # rewriting only the first occurrence of the source prefix in each name.
    for blob in self._bucket.list_blobs(prefix=f"{src}/"):
        renamed = str(blob.name).replace(str(src), str(dst), 1)
        self._bucket.copy_blob(blob, self._bucket, renamed)
    def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
        """Create a new GCSFileStorage with a subroot of the current storage."""
        file_storage = GCSFileStorage(

View File

@@ -115,6 +115,20 @@ class LocalFileStorage(FileStorage):
        new_path = self.get_path(new_path)
        old_path.rename(new_path)
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = self.get_path(source)
destination = self.get_path(destination)
if source.is_file():
destination.write_bytes(source.read_bytes())
else:
destination.mkdir(exist_ok=True, parents=True)
for file in source.rglob("*"):
if file.is_file():
target = destination / file.relative_to(source)
target.parent.mkdir(exist_ok=True, parents=True)
target.write_bytes(file.read_bytes())
    def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
        """Create a new LocalFileStorage with a subroot of the current storage."""
        return LocalFileStorage(

View File

@@ -222,6 +222,35 @@ class S3FileStorage(FileStorage):
        else:
            raise  # Re-raise for any other client errors
def copy(self, source: str | Path, destination: str | Path) -> None:
    """Copy a file or folder with all contents in the storage.

    Args:
        source: Path of the file or folder to copy.
        destination: Path to copy to.
    """
    source = str(self.get_path(source))
    destination = str(self.get_path(destination))
    try:
        # If source is a file, copy it
        # (head_object raises ClientError when no object has this exact key)
        self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
        self._s3.meta.client.copy_object(
            CopySource={"Bucket": self._bucket_name, "Key": source},
            Bucket=self._bucket_name,
            Key=destination,
        )
    except botocore.exceptions.ClientError as e:
        if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
            # If the object does not exist,
            # it may be a folder
            # (S3 has no real folders: copy every object under the prefix)
            prefix = f"{source.rstrip('/')}/"
            objs = list(self._bucket.objects.filter(Prefix=prefix))
            for obj in objs:
                # Keep the part of the key after the source prefix
                new_key = destination + obj.key[len(source) :]
                self._s3.meta.client.copy_object(
                    CopySource={"Bucket": self._bucket_name, "Key": obj.key},
                    Bucket=self._bucket_name,
                    Key=new_key,
                )
        else:
            # Anything other than "not found" is a genuine error
            raise
    def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
        """Create a new S3FileStorage with a subroot of the current storage."""
        file_storage = S3FileStorage(

View File

@@ -177,3 +177,24 @@ def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
    assert cloned._bucket.name == gcs_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
@pytest.mark.asyncio
async def test_copy_file(storage: GCSFileStorage):
    """`copy` duplicates a single file, both at the root and into a subfolder."""
    await storage.write_file("test_file.txt", "test content")
    storage.copy("test_file.txt", "test_file_copy.txt")
    storage.make_dir("dir")
    storage.copy("test_file.txt", "dir/test_file_copy.txt")
    # Both copies carry the original content
    assert storage.read_file("test_file_copy.txt") == "test content"
    assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: GCSFileStorage):
    """`copy` on a folder recursively copies nested files and subfolders."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", "test content")
    await storage.write_file("dir/sub_dir/test_file.txt", "test content")
    storage.copy("dir", "dir_copy")
    assert storage.read_file("dir_copy/test_file.txt") == "test content"
    assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"

View File

@@ -188,3 +188,24 @@ def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
    with pytest.raises(ValueError):
        storage.get_path(inaccessible_path)
@pytest.mark.asyncio
async def test_copy_file(storage: LocalFileStorage):
    """`copy` duplicates a single file, both at the root and into a subfolder."""
    await storage.write_file("test_file.txt", "test content")
    storage.copy("test_file.txt", "test_file_copy.txt")
    storage.make_dir("dir")
    storage.copy("test_file.txt", "dir/test_file_copy.txt")
    # Both copies carry the original content
    assert storage.read_file("test_file_copy.txt") == "test content"
    assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: LocalFileStorage):
    """`copy` on a folder recursively copies nested files and subfolders."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", "test content")
    await storage.write_file("dir/sub_dir/test_file.txt", "test content")
    storage.copy("dir", "dir_copy")
    assert storage.read_file("dir_copy/test_file.txt") == "test content"
    assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"

View File

@@ -172,3 +172,24 @@ def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
    assert cloned._bucket.name == s3_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
@pytest.mark.asyncio
async def test_copy_file(storage: S3FileStorage):
    """`copy` duplicates a single file, both at the root and into a subfolder."""
    await storage.write_file("test_file.txt", "test content")
    storage.copy("test_file.txt", "test_file_copy.txt")
    storage.make_dir("dir")
    storage.copy("test_file.txt", "dir/test_file_copy.txt")
    # Both copies carry the original content
    assert storage.read_file("test_file_copy.txt") == "test content"
    assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: S3FileStorage):
    """`copy` on a folder recursively copies nested files and subfolders."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", "test content")
    await storage.write_file("dir/sub_dir/test_file.txt", "test content")
    storage.copy("dir", "dir_copy")
    assert storage.read_file("dir_copy/test_file.txt") == "test content"
    assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"