fix(agent): Make `Agent.save_state` behave like save as (#7025)
* Make `Agent.save_state` behave like "save as" - Leave previously saved state untouched - Save agent state in new folder corresponding to new `agent_id` - Copy over workspace contents to new folder * Add `copy` method to `FileStorage` --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>pull/7082/head
parent
90f3c5e2d9
commit
e866a4ba04
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from autogpt.file_storage.base import FileStorage
|
from autogpt.file_storage.base import FileStorage
|
||||||
|
|
||||||
|
@ -13,11 +14,11 @@ class AgentFileManagerMixin:
|
||||||
"""Mixin that adds file manager (e.g. Agent state)
|
"""Mixin that adds file manager (e.g. Agent state)
|
||||||
and workspace manager (e.g. Agent output files) support."""
|
and workspace manager (e.g. Agent output files) support."""
|
||||||
|
|
||||||
files: FileStorage = None
|
files: FileStorage
|
||||||
"""Agent-related files, e.g. state, logs.
|
"""Agent-related files, e.g. state, logs.
|
||||||
Use `workspace` to access the agent's workspace files."""
|
Use `workspace` to access the agent's workspace files."""
|
||||||
|
|
||||||
workspace: FileStorage = None
|
workspace: FileStorage
|
||||||
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
||||||
Use `files` to access agent-related files, e.g. state, logs."""
|
Use `files` to access agent-related files, e.g. state, logs."""
|
||||||
|
|
||||||
|
@ -68,9 +69,24 @@ class AgentFileManagerMixin:
|
||||||
"""Get the agent's file operation logs as list of strings."""
|
"""Get the agent's file operation logs as list of strings."""
|
||||||
return self._file_logs_cache
|
return self._file_logs_cache
|
||||||
|
|
||||||
async def save_state(self) -> None:
|
async def save_state(self, save_as: Optional[str] = None) -> None:
|
||||||
"""Save the agent's state to the state file."""
|
"""Save the agent's state to the state file."""
|
||||||
state: BaseAgentSettings = getattr(self, "state")
|
state: BaseAgentSettings = getattr(self, "state")
|
||||||
|
if save_as:
|
||||||
|
temp_id = state.agent_id
|
||||||
|
state.agent_id = save_as
|
||||||
|
self._file_storage.make_dir(f"agents/{save_as}")
|
||||||
|
# Save state
|
||||||
|
await self._file_storage.write_file(
|
||||||
|
f"agents/{save_as}/{self.STATE_FILE}", state.json()
|
||||||
|
)
|
||||||
|
# Copy workspace
|
||||||
|
self._file_storage.copy(
|
||||||
|
f"agents/{temp_id}/workspace",
|
||||||
|
f"agents/{save_as}/workspace",
|
||||||
|
)
|
||||||
|
state.agent_id = temp_id
|
||||||
|
else:
|
||||||
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
||||||
|
|
||||||
def change_agent_id(self, new_id: str):
|
def change_agent_id(self, new_id: str):
|
||||||
|
|
|
@ -345,19 +345,13 @@ async def run_auto_gpt(
|
||||||
logger.info(f"Saving state of {agent_id}...")
|
logger.info(f"Saving state of {agent_id}...")
|
||||||
|
|
||||||
# Allow user to Save As other ID
|
# Allow user to Save As other ID
|
||||||
save_as_id = (
|
save_as_id = clean_input(
|
||||||
clean_input(
|
|
||||||
config,
|
config,
|
||||||
f"Press enter to save as '{agent_id}',"
|
f"Press enter to save as '{agent_id}',"
|
||||||
" or enter a different ID to save to:",
|
" or enter a different ID to save to:",
|
||||||
)
|
)
|
||||||
or agent_id
|
|
||||||
)
|
|
||||||
if save_as_id and save_as_id != agent_id:
|
|
||||||
agent.change_agent_id(save_as_id)
|
|
||||||
# TODO: allow many-to-one relations of agents and workspaces
|
# TODO: allow many-to-one relations of agents and workspaces
|
||||||
|
await agent.save_state(save_as_id if not save_as_id.isspace() else None)
|
||||||
await agent.save_state()
|
|
||||||
|
|
||||||
|
|
||||||
@coroutine
|
@coroutine
|
||||||
|
|
|
@ -127,6 +127,10 @@ class FileStorage(ABC):
|
||||||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||||
"""Rename a file or folder in the storage."""
|
"""Rename a file or folder in the storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||||
|
"""Copy a file or folder with all contents in the storage."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_dir(self, path: str | Path) -> None:
|
def make_dir(self, path: str | Path) -> None:
|
||||||
"""Create a directory in the storage if doesn't exist."""
|
"""Create a directory in the storage if doesn't exist."""
|
||||||
|
|
|
@ -182,6 +182,21 @@ class GCSFileStorage(FileStorage):
|
||||||
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
|
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
|
||||||
self._bucket.rename_blob(blob, new_name=new_name)
|
self._bucket.rename_blob(blob, new_name=new_name)
|
||||||
|
|
||||||
|
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||||
|
"""Copy a file or folder with all contents in the storage."""
|
||||||
|
source = self.get_path(source)
|
||||||
|
destination = self.get_path(destination)
|
||||||
|
# If the source is a file, copy it
|
||||||
|
if self._bucket.blob(str(source)).exists():
|
||||||
|
self._bucket.copy_blob(
|
||||||
|
self._bucket.blob(str(source)), self._bucket, str(destination)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
# Otherwise, copy all blobs with the prefix (folder)
|
||||||
|
for blob in self._bucket.list_blobs(prefix=f"{source}/"):
|
||||||
|
new_name = str(blob.name).replace(str(source), str(destination), 1)
|
||||||
|
self._bucket.copy_blob(blob, self._bucket, new_name)
|
||||||
|
|
||||||
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
|
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
|
||||||
"""Create a new GCSFileStorage with a subroot of the current storage."""
|
"""Create a new GCSFileStorage with a subroot of the current storage."""
|
||||||
file_storage = GCSFileStorage(
|
file_storage = GCSFileStorage(
|
||||||
|
|
|
@ -115,6 +115,20 @@ class LocalFileStorage(FileStorage):
|
||||||
new_path = self.get_path(new_path)
|
new_path = self.get_path(new_path)
|
||||||
old_path.rename(new_path)
|
old_path.rename(new_path)
|
||||||
|
|
||||||
|
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||||
|
"""Copy a file or folder with all contents in the storage."""
|
||||||
|
source = self.get_path(source)
|
||||||
|
destination = self.get_path(destination)
|
||||||
|
if source.is_file():
|
||||||
|
destination.write_bytes(source.read_bytes())
|
||||||
|
else:
|
||||||
|
destination.mkdir(exist_ok=True, parents=True)
|
||||||
|
for file in source.rglob("*"):
|
||||||
|
if file.is_file():
|
||||||
|
target = destination / file.relative_to(source)
|
||||||
|
target.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
target.write_bytes(file.read_bytes())
|
||||||
|
|
||||||
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
|
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
|
||||||
"""Create a new LocalFileStorage with a subroot of the current storage."""
|
"""Create a new LocalFileStorage with a subroot of the current storage."""
|
||||||
return LocalFileStorage(
|
return LocalFileStorage(
|
||||||
|
|
|
@ -222,6 +222,35 @@ class S3FileStorage(FileStorage):
|
||||||
else:
|
else:
|
||||||
raise # Re-raise for any other client errors
|
raise # Re-raise for any other client errors
|
||||||
|
|
||||||
|
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||||
|
"""Copy a file or folder with all contents in the storage."""
|
||||||
|
source = str(self.get_path(source))
|
||||||
|
destination = str(self.get_path(destination))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# If source is a file, copy it
|
||||||
|
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
|
||||||
|
self._s3.meta.client.copy_object(
|
||||||
|
CopySource={"Bucket": self._bucket_name, "Key": source},
|
||||||
|
Bucket=self._bucket_name,
|
||||||
|
Key=destination,
|
||||||
|
)
|
||||||
|
except botocore.exceptions.ClientError as e:
|
||||||
|
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||||
|
# If the object does not exist,
|
||||||
|
# it may be a folder
|
||||||
|
prefix = f"{source.rstrip('/')}/"
|
||||||
|
objs = list(self._bucket.objects.filter(Prefix=prefix))
|
||||||
|
for obj in objs:
|
||||||
|
new_key = destination + obj.key[len(source) :]
|
||||||
|
self._s3.meta.client.copy_object(
|
||||||
|
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
|
||||||
|
Bucket=self._bucket_name,
|
||||||
|
Key=new_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
|
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
|
||||||
"""Create a new S3FileStorage with a subroot of the current storage."""
|
"""Create a new S3FileStorage with a subroot of the current storage."""
|
||||||
file_storage = S3FileStorage(
|
file_storage = S3FileStorage(
|
||||||
|
|
|
@ -177,3 +177,24 @@ def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
|
||||||
assert cloned._bucket.name == gcs_storage_with_files._bucket.name
|
assert cloned._bucket.name == gcs_storage_with_files._bucket.name
|
||||||
assert cloned.exists("dir")
|
assert cloned.exists("dir")
|
||||||
assert cloned.exists("dir/test_file_4")
|
assert cloned.exists("dir/test_file_4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_file(storage: GCSFileStorage):
|
||||||
|
await storage.write_file("test_file.txt", "test content")
|
||||||
|
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||||
|
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||||
|
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_dir(storage: GCSFileStorage):
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.make_dir("dir/sub_dir")
|
||||||
|
await storage.write_file("dir/test_file.txt", "test content")
|
||||||
|
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||||
|
storage.copy("dir", "dir_copy")
|
||||||
|
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||||
|
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||||
|
|
|
@ -188,3 +188,24 @@ def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
|
||||||
def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
|
def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
storage.get_path(inaccessible_path)
|
storage.get_path(inaccessible_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_file(storage: LocalFileStorage):
|
||||||
|
await storage.write_file("test_file.txt", "test content")
|
||||||
|
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||||
|
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||||
|
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_dir(storage: LocalFileStorage):
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.make_dir("dir/sub_dir")
|
||||||
|
await storage.write_file("dir/test_file.txt", "test content")
|
||||||
|
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||||
|
storage.copy("dir", "dir_copy")
|
||||||
|
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||||
|
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||||
|
|
|
@ -172,3 +172,24 @@ def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
|
||||||
assert cloned._bucket.name == s3_storage_with_files._bucket.name
|
assert cloned._bucket.name == s3_storage_with_files._bucket.name
|
||||||
assert cloned.exists("dir")
|
assert cloned.exists("dir")
|
||||||
assert cloned.exists("dir/test_file_4")
|
assert cloned.exists("dir/test_file_4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_file(storage: S3FileStorage):
|
||||||
|
await storage.write_file("test_file.txt", "test content")
|
||||||
|
storage.copy("test_file.txt", "test_file_copy.txt")
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.copy("test_file.txt", "dir/test_file_copy.txt")
|
||||||
|
assert storage.read_file("test_file_copy.txt") == "test content"
|
||||||
|
assert storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_copy_dir(storage: S3FileStorage):
|
||||||
|
storage.make_dir("dir")
|
||||||
|
storage.make_dir("dir/sub_dir")
|
||||||
|
await storage.write_file("dir/test_file.txt", "test content")
|
||||||
|
await storage.write_file("dir/sub_dir/test_file.txt", "test content")
|
||||||
|
storage.copy("dir", "dir_copy")
|
||||||
|
assert storage.read_file("dir_copy/test_file.txt") == "test content"
|
||||||
|
assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||||
|
|
Loading…
Reference in New Issue