fix(store) : Download agent from store if user is not logged in (#9121)
- resolves - #9120 ### Changes - Added a new endpoint to download agent files as JSON, allowing users to retrieve agent data by store listing version ID and version number. - Introduced a new `get_agent` function in the database module to fetch agent details and prepare the graph data for download. - Enhanced the frontend `AgentInfo` component to include a download button, which triggers the download of the agent file. - Integrated loading state and user feedback via toast notifications during the download process. - Updated the API client to support the new download functionality. ### Demo video https://github.com/user-attachments/assets/6744a753-297f-4ccc-abde-f56ca24ed2d5 ### Example Json ```json { "id": "14378095-4cc5-41ea-975e-bd0bce010bea", "version": 1, "is_active": true, "is_template": false, "name": "something", "description": "1", "nodes": [ { "id": "6914efa0-e4fa-4ce8-802c-d5577cf061b6", "block_id": "aeb08fc1-2fc1-4141-bc8e-f758f183a822", "input_default": {}, "metadata": { "position": { "x": 756, "y": 452.5 } }, "input_links": [], "output_links": [], "webhook_id": null, "graph_id": "14378095-4cc5-41ea-975e-bd0bce010bea", "graph_version": 1, "webhook": null } ], "links": [], "input_schema": { "type": "object", "properties": {}, "required": [] }, "output_schema": { "type": "object", "properties": {}, "required": [] } } ``` --------- Co-authored-by: SwiftyOS <craigswift13@gmail.com>pull/9208/head
parent
1375a0fdbc
commit
0872da1969
|
@ -424,6 +424,26 @@ class GraphModel(Graph):
|
|||
result[key] = value
|
||||
return result
|
||||
|
||||
def clean_graph(self):
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
|
||||
input_blocks = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if next(
|
||||
(
|
||||
b
|
||||
for b in blocks
|
||||
if b.id == node.block_id and b.block_type == BlockType.INPUT
|
||||
),
|
||||
None,
|
||||
)
|
||||
]
|
||||
|
||||
for node in self.nodes:
|
||||
if any(input_block.id == node.id for input_block in input_blocks):
|
||||
node.input_default["value"] = ""
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
import logging
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -786,3 +790,45 @@ async def get_my_agents(
|
|||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch my agents"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_agent(
|
||||
store_listing_version_id: str, version_id: Optional[int]
|
||||
) -> GraphModel:
|
||||
"""Get agent using the version ID and store listing version ID."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
agent = store_listing_version.Agent
|
||||
|
||||
graph = await backend.data.graph.get_graph(
|
||||
agent.id, agent.version, template=True
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent {agent.id} not found"
|
||||
)
|
||||
|
||||
graph.version = 1
|
||||
graph.is_template = False
|
||||
graph.is_active = True
|
||||
delattr(graph, "user_id")
|
||||
|
||||
return graph
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent"
|
||||
) from e
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
|
||||
|
@ -6,7 +8,9 @@ import autogpt_libs.auth.depends
|
|||
import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.image_gen
|
||||
|
@ -575,3 +579,66 @@ async def generate_image(
|
|||
status_code=500,
|
||||
content={"detail": "An error occurred while generating the image"},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
version: typing.Optional[int] = fastapi.Query(
|
||||
None, description="Specific version of the agent"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
agent_id (str): The ID of the agent to download.
|
||||
version (Optional[int]): Specific version of the agent to download.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
|
||||
graph_data = await backend.server.v2.store.db.get_agent(
|
||||
store_listing_version_id=store_listing_version_id, version_id=version
|
||||
)
|
||||
|
||||
graph_data.clean_graph()
|
||||
graph_date_dict = jsonable_encoder(graph_data)
|
||||
|
||||
def remove_credentials(obj):
|
||||
if obj and isinstance(obj, dict):
|
||||
if "credentials" in obj:
|
||||
del obj["credentials"]
|
||||
if "creds" in obj:
|
||||
del obj["creds"]
|
||||
|
||||
for value in obj.values():
|
||||
remove_credentials(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
remove_credentials(item)
|
||||
return obj
|
||||
|
||||
graph_date_dict = remove_credentials(graph_date_dict)
|
||||
|
||||
file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(json.dumps(graph_date_dict))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
|
|
@ -160,3 +160,45 @@ async def test_get_input_schema(server: SpinTestServer):
|
|||
output_schema = created_graph.output_schema
|
||||
output_schema["title"] = "ExpectedOutputSchema"
|
||||
assert output_schema == ExpectedOutputSchema.jsonschema()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_clean_graph(server: SpinTestServer):
|
||||
"""
|
||||
Test the clean_graph function that:
|
||||
1. Clears input block values
|
||||
2. Removes credentials from nodes
|
||||
"""
|
||||
# Create a graph with input blocks and credentials
|
||||
graph = Graph(
|
||||
id="test_clean_graph",
|
||||
name="Test Clean Graph",
|
||||
description="Test graph cleaning",
|
||||
nodes=[
|
||||
Node(
|
||||
id="input_node",
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={
|
||||
"name": "test_input",
|
||||
"value": "test value",
|
||||
"description": "Test input description",
|
||||
},
|
||||
),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
# Create graph and get model
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
# Clean the graph
|
||||
created_graph.clean_graph()
|
||||
|
||||
# # Verify input block value is cleared
|
||||
input_node = next(
|
||||
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
|
||||
)
|
||||
assert input_node.input_default["value"] == ""
|
||||
|
|
|
@ -6,6 +6,10 @@ import { Separator } from "@/components/ui/separator";
|
|||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Link from "next/link";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
|
||||
import useSupabase from "@/hooks/useSupabase";
|
||||
import { DownloadIcon, LoaderIcon } from "lucide-react";
|
||||
interface AgentInfoProps {
|
||||
name: string;
|
||||
creator: string;
|
||||
|
@ -32,8 +36,11 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
|||
storeListingVersionId,
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
|
||||
const api = React.useMemo(() => new BackendAPI(), []);
|
||||
const { user } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
|
||||
const [downloading, setDownloading] = React.useState(false);
|
||||
|
||||
const handleAddToLibrary = async () => {
|
||||
try {
|
||||
|
@ -45,6 +52,46 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
|||
}
|
||||
};
|
||||
|
||||
const handleDownloadToLibrary = async () => {
|
||||
const downloadAgent = async (): Promise<void> => {
|
||||
setDownloading(true);
|
||||
try {
|
||||
const file = await api.downloadStoreAgent(storeListingVersionId);
|
||||
|
||||
// Similar to Marketplace v1
|
||||
const jsonData = JSON.stringify(file, null, 2);
|
||||
// Create a Blob from the file content
|
||||
const blob = new Blob([jsonData], { type: "application/json" });
|
||||
|
||||
// Create a temporary URL for the Blob
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
|
||||
// Create a temporary anchor element
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `agent_${storeListingVersionId}.json`; // Set the filename
|
||||
|
||||
// Append the anchor to the body, click it, and remove it
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
|
||||
// Revoke the temporary URL
|
||||
window.URL.revokeObjectURL(url);
|
||||
|
||||
toast({
|
||||
title: "Download Complete",
|
||||
description: "Your agent has been successfully downloaded.",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Error downloading agent:`, error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
await downloadAgent();
|
||||
setDownloading(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-[396px] px-4 sm:px-6 lg:w-[396px] lg:px-0">
|
||||
{/* Title */}
|
||||
|
@ -72,15 +119,36 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
|||
|
||||
{/* Run Agent Button */}
|
||||
<div className="mb-4 w-full lg:mb-[60px]">
|
||||
<button
|
||||
onClick={handleAddToLibrary}
|
||||
className="inline-flex w-full items-center justify-center gap-2 rounded-[38px] bg-violet-600 px-4 py-3 transition-colors hover:bg-violet-700 sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4"
|
||||
>
|
||||
<IconPlay className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
|
||||
Add To Library
|
||||
</span>
|
||||
</button>
|
||||
{user ? (
|
||||
<button
|
||||
onClick={handleAddToLibrary}
|
||||
className="inline-flex w-full items-center justify-center gap-2 rounded-[38px] bg-violet-600 px-4 py-3 transition-colors hover:bg-violet-700 sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4"
|
||||
>
|
||||
<IconPlay className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
|
||||
Add To Library
|
||||
</span>
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
onClick={handleDownloadToLibrary}
|
||||
className={`inline-flex w-full items-center justify-center gap-2 rounded-[38px] px-4 py-3 transition-colors sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4 ${
|
||||
downloading
|
||||
? "bg-neutral-400"
|
||||
: "bg-violet-600 hover:bg-violet-700"
|
||||
}`}
|
||||
disabled={downloading}
|
||||
>
|
||||
{downloading ? (
|
||||
<LoaderIcon className="h-5 w-5 animate-spin text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
) : (
|
||||
<DownloadIcon className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
)}
|
||||
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
|
||||
{downloading ? "Downloading..." : "Download Agent as File"}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Rating and Runs */}
|
||||
|
|
|
@ -348,6 +348,17 @@ export default class BackendAPI {
|
|||
return this._get("/store/myagents", params);
|
||||
}
|
||||
|
||||
downloadStoreAgent(
|
||||
storeListingVersionId: string,
|
||||
version?: number,
|
||||
): Promise<BlobPart> {
|
||||
const url = version
|
||||
? `/store/download/agents/${storeListingVersionId}?version=${version}`
|
||||
: `/store/download/agents/${storeListingVersionId}`;
|
||||
|
||||
return this._get(url);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////
|
||||
/////////// V2 LIBRARY API //////////////
|
||||
/////////////////////////////////////////
|
||||
|
|
Loading…
Reference in New Issue