fix(backend): Add migrations to fix credentials inputs with invalid provider "llm" (#8674)
In #8524, the "llm" credentials provider was replaced. There are still entries with `"provider": "llm"` in the system though, and those break if not migrated. - SQL migration to fix the obvious ones where we know the provider from `credentials.id` - Non-SQL migration to fix the restrelease/platform-v0.3.0
parent
9a4ff9023d
commit
918538147c
|
@ -5,6 +5,7 @@ from collections import defaultdict
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Type
|
||||
|
||||
import prisma
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
@ -523,3 +524,84 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
|||
for link in graph.links
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------ UTILITIES ------------------------ #
|
||||
|
||||
|
||||
async def fix_llm_provider_credentials():
|
||||
"""Fix node credentials with provider `llm`"""
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
|
||||
from .redis import get_redis
|
||||
from .user import get_user_integrations
|
||||
|
||||
store = SupabaseIntegrationCredentialsStore(get_redis())
|
||||
|
||||
broken_nodes = await prisma.get_client().query_raw(
|
||||
"""
|
||||
SELECT "User".id user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
LEFT JOIN platform."User" "User"
|
||||
ON graph."userId" = "User".id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY user_id;
|
||||
"""
|
||||
)
|
||||
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
|
||||
|
||||
user_id: str = ""
|
||||
user_integrations = None
|
||||
for node in broken_nodes:
|
||||
if node["user_id"] != user_id:
|
||||
# Save queries by only fetching once per user
|
||||
user_id = node["user_id"]
|
||||
user_integrations = await get_user_integrations(user_id)
|
||||
elif not user_integrations:
|
||||
raise RuntimeError(f"Impossible state while processing node {node}")
|
||||
|
||||
node_id: str = node["node_id"]
|
||||
node_preset_input: dict = json.loads(node["node_preset_input"])
|
||||
credentials_meta: dict = node_preset_input["credentials"]
|
||||
|
||||
credentials = next(
|
||||
(
|
||||
c
|
||||
for c in user_integrations.credentials
|
||||
if c.id == credentials_meta["id"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not credentials:
|
||||
continue
|
||||
if credentials.type != "api_key":
|
||||
logger.warning(
|
||||
f"User {user_id} credentials {credentials.id} with provider 'llm' "
|
||||
f"has invalid type '{credentials.type}'"
|
||||
)
|
||||
continue
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
if api_key.startswith("sk-ant-api03-"):
|
||||
credentials.provider = credentials_meta["provider"] = "anthropic"
|
||||
elif api_key.startswith("sk-"):
|
||||
credentials.provider = credentials_meta["provider"] = "openai"
|
||||
elif api_key.startswith("gsk_"):
|
||||
credentials.provider = credentials_meta["provider"] = "groq"
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not identify provider from key prefix {api_key[:13]}*****"
|
||||
)
|
||||
continue
|
||||
|
||||
store.update_creds(user_id, credentials)
|
||||
await AgentNode.prisma().update(
|
||||
where={"id": node_id},
|
||||
data={"constantInput": json.dumps(node_preset_input)},
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ import uvicorn
|
|||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.routers.v1
|
||||
import backend.util.service
|
||||
|
@ -23,6 +24,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||
await backend.data.db.connect()
|
||||
await backend.data.block.initialize_blocks()
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
yield
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
-- Correct credentials.provider field on all nodes with 'llm' provider credentials
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{credentials,provider}',
|
||||
CASE
|
||||
WHEN "constantInput"::jsonb->'credentials'->>'id' = '53c25cb8-e3ee-465c-a4d1-e75a4c899c2a' THEN '"openai"'::jsonb
|
||||
WHEN "constantInput"::jsonb->'credentials'->>'id' = '24e5d942-d9e3-4798-8151-90143ee55629' THEN '"anthropic"'::jsonb
|
||||
WHEN "constantInput"::jsonb->'credentials'->>'id' = '4ec22295-8f97-4dd1-b42b-2c6957a02545' THEN '"groq"'::jsonb
|
||||
ELSE "constantInput"::jsonb->'credentials'->'provider'
|
||||
END
|
||||
)::text
|
||||
WHERE "constantInput"::jsonb->'credentials'->>'provider' = 'llm';
|
Loading…
Reference in New Issue