fix(agent): Make `Agent.save_state` behave like save as (#7025)

* Make `Agent.save_state` behave like "save as"
   - Leave previously saved state untouched
   - Save agent state in new folder corresponding to new `agent_id`
   - Copy over workspace contents to new folder

* Add `copy` method to `FileStorage`

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
pull/7082/head
Krzysztof Czerwinski 2024-04-12 12:41:02 +02:00 committed by GitHub
parent 90f3c5e2d9
commit e866a4ba04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 151 additions and 16 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
from typing import Optional
from autogpt.file_storage.base import FileStorage
@ -13,11 +14,11 @@ class AgentFileManagerMixin:
"""Mixin that adds file manager (e.g. Agent state)
and workspace manager (e.g. Agent output files) support."""
files: FileStorage = None
files: FileStorage
"""Agent-related files, e.g. state, logs.
Use `workspace` to access the agent's workspace files."""
workspace: FileStorage = None
workspace: FileStorage
"""Workspace that the agent has access to, e.g. for reading/writing files.
Use `files` to access agent-related files, e.g. state, logs."""
@ -68,9 +69,24 @@ class AgentFileManagerMixin:
"""Get the agent's file operation logs as list of strings."""
return self._file_logs_cache
async def save_state(self) -> None:
async def save_state(self, save_as: Optional[str] = None) -> None:
"""Save the agent's state to the state file."""
state: BaseAgentSettings = getattr(self, "state")
if save_as:
temp_id = state.agent_id
state.agent_id = save_as
self._file_storage.make_dir(f"agents/{save_as}")
# Save state
await self._file_storage.write_file(
f"agents/{save_as}/{self.STATE_FILE}", state.json()
)
# Copy workspace
self._file_storage.copy(
f"agents/{temp_id}/workspace",
f"agents/{save_as}/workspace",
)
state.agent_id = temp_id
else:
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
def change_agent_id(self, new_id: str):

View File

@ -345,19 +345,13 @@ async def run_auto_gpt(
logger.info(f"Saving state of {agent_id}...")
# Allow user to Save As other ID
save_as_id = (
clean_input(
save_as_id = clean_input(
config,
f"Press enter to save as '{agent_id}',"
" or enter a different ID to save to:",
)
or agent_id
)
if save_as_id and save_as_id != agent_id:
agent.change_agent_id(save_as_id)
# TODO: allow many-to-one relations of agents and workspaces
await agent.save_state()
await agent.save_state(save_as_id if not save_as_id.isspace() else None)
@coroutine

View File

@ -127,6 +127,10 @@ class FileStorage(ABC):
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
"""Rename a file or folder in the storage."""
@abstractmethod
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
@abstractmethod
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if doesn't exist."""

View File

@ -182,6 +182,21 @@ class GCSFileStorage(FileStorage):
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
self._bucket.rename_blob(blob, new_name=new_name)
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = self.get_path(source)
destination = self.get_path(destination)
# If the source is a file, copy it
if self._bucket.blob(str(source)).exists():
self._bucket.copy_blob(
self._bucket.blob(str(source)), self._bucket, str(destination)
)
return
# Otherwise, copy all blobs with the prefix (folder)
for blob in self._bucket.list_blobs(prefix=f"{source}/"):
new_name = str(blob.name).replace(str(source), str(destination), 1)
self._bucket.copy_blob(blob, self._bucket, new_name)
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
"""Create a new GCSFileStorage with a subroot of the current storage."""
file_storage = GCSFileStorage(

View File

@ -115,6 +115,20 @@ class LocalFileStorage(FileStorage):
new_path = self.get_path(new_path)
old_path.rename(new_path)
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = self.get_path(source)
destination = self.get_path(destination)
if source.is_file():
destination.write_bytes(source.read_bytes())
else:
destination.mkdir(exist_ok=True, parents=True)
for file in source.rglob("*"):
if file.is_file():
target = destination / file.relative_to(source)
target.parent.mkdir(exist_ok=True, parents=True)
target.write_bytes(file.read_bytes())
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
"""Create a new LocalFileStorage with a subroot of the current storage."""
return LocalFileStorage(

View File

@ -222,6 +222,35 @@ class S3FileStorage(FileStorage):
else:
raise # Re-raise for any other client errors
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = str(self.get_path(source))
destination = str(self.get_path(destination))
try:
# If source is a file, copy it
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": source},
Bucket=self._bucket_name,
Key=destination,
)
except botocore.exceptions.ClientError as e:
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
# If the object does not exist,
# it may be a folder
prefix = f"{source.rstrip('/')}/"
objs = list(self._bucket.objects.filter(Prefix=prefix))
for obj in objs:
new_key = destination + obj.key[len(source) :]
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
Bucket=self._bucket_name,
Key=new_key,
)
else:
raise
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
"""Create a new S3FileStorage with a subroot of the current storage."""
file_storage = S3FileStorage(

View File

@ -177,3 +177,24 @@ def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
assert cloned._bucket.name == gcs_storage_with_files._bucket.name
assert cloned.exists("dir")
assert cloned.exists("dir/test_file_4")
@pytest.mark.asyncio
async def test_copy_file(storage: GCSFileStorage):
await storage.write_file("test_file.txt", "test content")
storage.copy("test_file.txt", "test_file_copy.txt")
storage.make_dir("dir")
storage.copy("test_file.txt", "dir/test_file_copy.txt")
assert storage.read_file("test_file_copy.txt") == "test content"
assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: GCSFileStorage):
storage.make_dir("dir")
storage.make_dir("dir/sub_dir")
await storage.write_file("dir/test_file.txt", "test content")
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
storage.copy("dir", "dir_copy")
assert storage.read_file("dir_copy/test_file.txt") == "test content"
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"

View File

@ -188,3 +188,24 @@ def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
with pytest.raises(ValueError):
storage.get_path(inaccessible_path)
@pytest.mark.asyncio
async def test_copy_file(storage: LocalFileStorage):
await storage.write_file("test_file.txt", "test content")
storage.copy("test_file.txt", "test_file_copy.txt")
storage.make_dir("dir")
storage.copy("test_file.txt", "dir/test_file_copy.txt")
assert storage.read_file("test_file_copy.txt") == "test content"
assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: LocalFileStorage):
storage.make_dir("dir")
storage.make_dir("dir/sub_dir")
await storage.write_file("dir/test_file.txt", "test content")
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
storage.copy("dir", "dir_copy")
assert storage.read_file("dir_copy/test_file.txt") == "test content"
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"

View File

@ -172,3 +172,24 @@ def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
assert cloned._bucket.name == s3_storage_with_files._bucket.name
assert cloned.exists("dir")
assert cloned.exists("dir/test_file_4")
@pytest.mark.asyncio
async def test_copy_file(storage: S3FileStorage):
await storage.write_file("test_file.txt", "test content")
storage.copy("test_file.txt", "test_file_copy.txt")
storage.make_dir("dir")
storage.copy("test_file.txt", "dir/test_file_copy.txt")
assert storage.read_file("test_file_copy.txt") == "test content"
assert storage.read_file("dir/test_file_copy.txt") == "test content"
@pytest.mark.asyncio
async def test_copy_dir(storage: S3FileStorage):
storage.make_dir("dir")
storage.make_dir("dir/sub_dir")
await storage.write_file("dir/test_file.txt", "test content")
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
storage.copy("dir", "dir_copy")
assert storage.read_file("dir_copy/test_file.txt") == "test content"
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"