From 918538147cd582ae98e3fad89945f8b435e5716e Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 15 Nov 2024 20:18:02 +0100 Subject: [PATCH] fix(backend): Add migrations to fix credentials inputs with invalid provider "llm" (#8674) In #8524, the "llm" credentials provider was replaced. There are still entries with `"provider": "llm"` in the system though, and those break if not migrated. - SQL migration to fix the obvious ones where we know the provider from `credentials.id` - Non-SQL migration to fix the rest --- .../backend/backend/data/graph.py | 82 +++++++++++++++++++ .../backend/backend/server/rest_api.py | 2 + .../migration.sql | 13 +++ 3 files changed, 97 insertions(+) create mode 100644 autogpt_platform/backend/migrations/20241115170707_fix_llm_provider_credentials/migration.sql diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index d97d246ea..55add87f0 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -5,6 +5,7 @@ from collections import defaultdict from datetime import datetime, timezone from typing import Any, Literal, Type +import prisma from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink from prisma.types import AgentGraphWhereInput from pydantic.fields import computed_field @@ -523,3 +524,84 @@ async def __create_graph(tx, graph: Graph, user_id: str): for link in graph.links ] ) + + +# ------------------------ UTILITIES ------------------------ # + + +async def fix_llm_provider_credentials(): + """Fix node credentials with provider `llm`""" + from autogpt_libs.supabase_integration_credentials_store import ( + SupabaseIntegrationCredentialsStore, + ) + + from .redis import get_redis + from .user import get_user_integrations + + store = SupabaseIntegrationCredentialsStore(get_redis()) + + broken_nodes = await prisma.get_client().query_raw( + """ + SELECT "User".id user_id, + node.id node_id, + node."constantInput" node_preset_input + FROM platform."AgentNode" node + LEFT JOIN platform."AgentGraph" graph + ON node."agentGraphId" = graph.id + LEFT JOIN platform."User" "User" + ON graph."userId" = "User".id + WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm' + ORDER BY user_id; + """ + ) + logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes") + + user_id: str = "" + user_integrations = None + for node in broken_nodes: + if node["user_id"] != user_id: + # Save queries by only fetching once per user + user_id = node["user_id"] + user_integrations = await get_user_integrations(user_id) + elif not user_integrations: + raise RuntimeError(f"Impossible state while processing node {node}") + + node_id: str = node["node_id"] + node_preset_input: dict = json.loads(node["node_preset_input"]) + credentials_meta: dict = node_preset_input["credentials"] + + credentials = next( + ( + c + for c in user_integrations.credentials + if c.id == credentials_meta["id"] + ), + None, + ) + if not credentials: + continue + if credentials.type != "api_key": + logger.warning( + f"User {user_id} credentials {credentials.id} with provider 'llm' " + f"has invalid type '{credentials.type}'" + ) + continue + + api_key = credentials.api_key.get_secret_value() + if api_key.startswith("sk-ant-api03-"): + credentials.provider = credentials_meta["provider"] = "anthropic" + elif api_key.startswith("sk-"): + credentials.provider = credentials_meta["provider"] = "openai" + elif api_key.startswith("gsk_"): + credentials.provider = credentials_meta["provider"] = "groq" + else: + logger.warning( + f"Could not identify provider from key prefix {api_key[:13]}*****" + ) + continue + + store.update_creds(user_id, credentials) + await AgentNode.prisma().update( + where={"id": node_id}, + data={"constantInput": json.dumps(node_preset_input)}, + ) diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 27c6679cc..2436e4115 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -9,6 +9,7 @@ import uvicorn import backend.data.block import backend.data.db +import backend.data.graph import backend.data.user import backend.server.routers.v1 import backend.util.service @@ -23,6 +24,7 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.db.connect() await backend.data.block.initialize_blocks() await backend.data.user.migrate_and_encrypt_user_integrations() + await backend.data.graph.fix_llm_provider_credentials() yield await backend.data.db.disconnect() diff --git a/autogpt_platform/backend/migrations/20241115170707_fix_llm_provider_credentials/migration.sql b/autogpt_platform/backend/migrations/20241115170707_fix_llm_provider_credentials/migration.sql new file mode 100644 index 000000000..59b1d0b05 --- /dev/null +++ b/autogpt_platform/backend/migrations/20241115170707_fix_llm_provider_credentials/migration.sql @@ -0,0 +1,13 @@ +-- Correct credentials.provider field on all nodes with 'llm' provider credentials +UPDATE "AgentNode" +SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{credentials,provider}', + CASE + WHEN "constantInput"::jsonb->'credentials'->>'id' = '53c25cb8-e3ee-465c-a4d1-e75a4c899c2a' THEN '"openai"'::jsonb + WHEN "constantInput"::jsonb->'credentials'->>'id' = '24e5d942-d9e3-4798-8151-90143ee55629' THEN '"anthropic"'::jsonb + WHEN "constantInput"::jsonb->'credentials'->>'id' = '4ec22295-8f97-4dd1-b42b-2c6957a02545' THEN '"groq"'::jsonb + ELSE "constantInput"::jsonb->'credentials'->'provider' + END + )::text +WHERE "constantInput"::jsonb->'credentials'->>'provider' = 'llm';