feat(platform): Simplify Credentials UX (#8524)
- Change `provider` of default credentials to actual provider names (e.g. `anthropic`), remove `llm` provider - Add `discriminator` and `discriminator_mapping` to `CredentialsField` that allows to filter credentials input to only allow providers for matching models in `useCredentials` hook (thanks @ntindle for the idea!); e.g. user chooses `GPT4_TURBO` so then only OpenAI credentials are allowed - Choose credentials automatically and hide credentials input on the node completely if there's only one possible option - Move `getValue` and `parseKeys` to utils - Add `ANTHROPIC`, `GROQ` and `OLLAMA` to providers in frontend `types.ts` - Add `hidden` field to credentials that is used for default system keys to hide them in user profile - Now `provider` field in `CredentialsField` can accept multiple providers as a list ----------------- Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co> Co-authored-by: Reinier van der Leer <pwuts@agpt.co>pull/8639/head
parent
ef7e50403e
commit
e907ffda6e
|
@ -46,21 +46,21 @@ replicate_credentials = APIKeyCredentials(
|
|||
)
|
||||
openai_credentials = APIKeyCredentials(
|
||||
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
|
||||
provider="llm",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
title="Use Credits for OpenAI",
|
||||
expires_at=None,
|
||||
)
|
||||
anthropic_credentials = APIKeyCredentials(
|
||||
id="24e5d942-d9e3-4798-8151-90143ee55629",
|
||||
provider="llm",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(settings.secrets.anthropic_api_key),
|
||||
title="Use Credits for Anthropic",
|
||||
expires_at=None,
|
||||
)
|
||||
groq_credentials = APIKeyCredentials(
|
||||
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
|
||||
provider="llm",
|
||||
provider="groq",
|
||||
api_key=SecretStr(settings.secrets.groq_api_key),
|
||||
title="Use Credits for Groq",
|
||||
expires_at=None,
|
||||
|
|
|
@ -30,11 +30,12 @@ logger = logging.getLogger(__name__)
|
|||
# "ollama": BlockSecret(value=""),
|
||||
# }
|
||||
|
||||
AICredentials = CredentialsMetaInput[Literal["llm"], Literal["api_key"]]
|
||||
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama"]
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
provider="llm",
|
||||
provider="openai",
|
||||
api_key=SecretStr("mock-openai-api-key"),
|
||||
title="Mock OpenAI API key",
|
||||
expires_at=None,
|
||||
|
@ -50,8 +51,12 @@ TEST_CREDENTIALS_INPUT = {
|
|||
def AICredentialsField() -> AICredentials:
|
||||
return CredentialsField(
|
||||
description="API key for the LLM provider.",
|
||||
provider="llm",
|
||||
provider=["anthropic", "groq", "openai", "ollama"],
|
||||
supported_credential_types={"api_key"},
|
||||
discriminator="model",
|
||||
discriminator_mapping={
|
||||
model.value: model.metadata.provider for model in LlmModel
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -152,10 +152,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||
|
||||
|
||||
def CredentialsField(
|
||||
provider: CP,
|
||||
provider: CP | list[CP],
|
||||
supported_credential_types: set[CT],
|
||||
required_scopes: set[str] = set(),
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator_mapping: Optional[dict[str, Any]] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
|
@ -167,9 +169,13 @@ def CredentialsField(
|
|||
json_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_provider": provider,
|
||||
"credentials_provider": (
|
||||
[provider] if isinstance(provider, str) else provider
|
||||
),
|
||||
"credentials_scopes": list(required_scopes) or None, # omit if empty
|
||||
"credentials_types": list(supported_credential_types),
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
|
|
@ -4,17 +4,14 @@ import { useSupabase } from "@/components/SupabaseProvider";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import useUser from "@/hooks/useUser";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useContext } from "react";
|
||||
import { useCallback, useContext, useMemo } from "react";
|
||||
import { FaSpinner } from "react-icons/fa";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { IconKey, IconUser } from "@/components/ui/icons";
|
||||
import { LogOutIcon, Trash2Icon } from "lucide-react";
|
||||
import { providerIcons } from "@/components/integrations/credentials-input";
|
||||
import {
|
||||
CredentialsProviderName,
|
||||
CredentialsProvidersContext,
|
||||
} from "@/components/integrations/credentials-provider";
|
||||
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
|
@ -23,6 +20,7 @@ import {
|
|||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { CredentialsProviderName } from "@/lib/autogpt-server-api";
|
||||
|
||||
export default function PrivatePage() {
|
||||
const { user, isLoading, error } = useUser();
|
||||
|
@ -62,7 +60,22 @@ export default function PrivatePage() {
|
|||
[providers, toast],
|
||||
);
|
||||
|
||||
if (isLoading || !providers || !providers) {
|
||||
//TODO: remove when the way system credentials are handled is updated
|
||||
// This contains ids for built-in "Use Credits for X" credentials
|
||||
const hiddenCredentials = useMemo(
|
||||
() => [
|
||||
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
|
||||
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
|
||||
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
|
||||
"53c25cb8-e3ee-465c-a4d1-e75a4c899c2a", // OpenAI
|
||||
"24e5d942-d9e3-4798-8151-90143ee55629", // Anthropic
|
||||
"4ec22295-8f97-4dd1-b42b-2c6957a02545", // Groq
|
||||
"7f7b0654-c36b-4565-8fa7-9a52575dfae2", // D-ID
|
||||
],
|
||||
[],
|
||||
);
|
||||
|
||||
if (isLoading || !providers) {
|
||||
return (
|
||||
<div className="flex h-[80vh] items-center justify-center">
|
||||
<FaSpinner className="mr-2 h-16 w-16 animate-spin" />
|
||||
|
@ -76,15 +89,15 @@ export default function PrivatePage() {
|
|||
}
|
||||
|
||||
const allCredentials = Object.values(providers).flatMap((provider) =>
|
||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys].map(
|
||||
(credentials) => ({
|
||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
|
||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||
.map((credentials) => ({
|
||||
...credentials,
|
||||
provider: provider.provider,
|
||||
providerName: provider.providerName,
|
||||
ProviderIcon: providerIcons[provider.provider],
|
||||
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
|
||||
}),
|
||||
),
|
||||
})),
|
||||
);
|
||||
|
||||
return (
|
||||
|
|
|
@ -18,7 +18,13 @@ import {
|
|||
BlockUIType,
|
||||
BlockCost,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
|
||||
import {
|
||||
beautifyString,
|
||||
cn,
|
||||
getValue,
|
||||
parseKeys,
|
||||
setNestedProperty,
|
||||
} from "@/lib/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { history } from "./history";
|
||||
|
@ -36,8 +42,6 @@ import * as Separator from "@radix-ui/react-separator";
|
|||
import * as ContextMenu from "@radix-ui/react-context-menu";
|
||||
import { DotsVerticalIcon, TrashIcon, CopyIcon } from "@radix-ui/react-icons";
|
||||
|
||||
type ParsedKey = { key: string; index?: number };
|
||||
|
||||
export type ConnectionData = Array<{
|
||||
edge_id: string;
|
||||
source: string;
|
||||
|
@ -178,7 +182,7 @@ export function CustomNode({
|
|||
className=""
|
||||
selfKey={noteKey}
|
||||
schema={noteSchema as BlockIOStringSubSchema}
|
||||
value={getValue(noteKey)}
|
||||
value={getValue(noteKey, data.hardcodedValues)}
|
||||
handleInputChange={handleInputChange}
|
||||
handleInputClick={handleInputClick}
|
||||
error={data.errors?.[noteKey] ?? ""}
|
||||
|
@ -228,7 +232,7 @@ export function CustomNode({
|
|||
nodeId={id}
|
||||
propKey={getInputPropKey(propKey)}
|
||||
propSchema={propSchema}
|
||||
currentValue={getValue(getInputPropKey(propKey))}
|
||||
currentValue={getValue(propKey, data.hardcodedValues)}
|
||||
connections={data.connections}
|
||||
handleInputChange={handleInputChange}
|
||||
handleInputClick={handleInputClick}
|
||||
|
@ -283,48 +287,6 @@ export function CustomNode({
|
|||
setErrors({ ...errors });
|
||||
};
|
||||
|
||||
// Helper function to parse keys with array indices
|
||||
//TODO move to utils
|
||||
const parseKeys = (key: string): ParsedKey[] => {
|
||||
const splits = key.split(/_@_|_#_|_\$_|\./);
|
||||
const keys: ParsedKey[] = [];
|
||||
let currentKey: string | null = null;
|
||||
|
||||
splits.forEach((split) => {
|
||||
const isInteger = /^\d+$/.test(split);
|
||||
if (!isInteger) {
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey });
|
||||
}
|
||||
currentKey = split;
|
||||
} else {
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey, index: parseInt(split, 10) });
|
||||
currentKey = null;
|
||||
} else {
|
||||
throw new Error("Invalid key format: array index without a key");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey });
|
||||
}
|
||||
|
||||
return keys;
|
||||
};
|
||||
|
||||
const getValue = (key: string) => {
|
||||
const keys = parseKeys(key);
|
||||
return keys.reduce((acc, k) => {
|
||||
if (acc === undefined) return undefined;
|
||||
if (k.index !== undefined) {
|
||||
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
|
||||
}
|
||||
return acc[k.key];
|
||||
}, data.hardcodedValues as any);
|
||||
};
|
||||
|
||||
const isHandleConnected = (key: string) => {
|
||||
return (
|
||||
data.connections &&
|
||||
|
@ -347,7 +309,7 @@ export function CustomNode({
|
|||
const handleInputClick = (key: string) => {
|
||||
console.debug(`Opening modal for key: ${key}`);
|
||||
setActiveKey(key);
|
||||
const value = getValue(key);
|
||||
const value = getValue(key, data.hardcodedValues);
|
||||
setInputModalValue(
|
||||
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
|
||||
);
|
||||
|
|
|
@ -46,16 +46,18 @@ export const providerIcons: Record<
|
|||
CredentialsProviderName,
|
||||
React.FC<{ className?: string }>
|
||||
> = {
|
||||
anthropic: fallbackIcon,
|
||||
github: FaGithub,
|
||||
google: FaGoogle,
|
||||
groq: fallbackIcon,
|
||||
notion: NotionLogoIcon,
|
||||
discord: FaDiscord,
|
||||
d_id: fallbackIcon,
|
||||
google_maps: FaGoogle,
|
||||
jina: fallbackIcon,
|
||||
ideogram: fallbackIcon,
|
||||
llm: fallbackIcon,
|
||||
medium: FaMedium,
|
||||
ollama: fallbackIcon,
|
||||
openai: fallbackIcon,
|
||||
openweathermap: fallbackIcon,
|
||||
pinecone: fallbackIcon,
|
||||
|
@ -80,7 +82,7 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
|
|||
export const CredentialsInput: FC<{
|
||||
className?: string;
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredentials: (newValue: CredentialsMetaInput) => void;
|
||||
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
|
||||
}> = ({ className, selectedCredentials, onSelectCredentials }) => {
|
||||
const api = useMemo(() => new AutoGPTServerAPI(), []);
|
||||
const credentials = useCredentials();
|
||||
|
@ -91,14 +93,10 @@ export const CredentialsInput: FC<{
|
|||
useState<AbortController | null>(null);
|
||||
const [oAuthError, setOAuthError] = useState<string | null>(null);
|
||||
|
||||
if (!credentials) {
|
||||
if (!credentials || credentials.isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (credentials.isLoading) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
const {
|
||||
schema,
|
||||
provider,
|
||||
|
@ -222,10 +220,21 @@ export const CredentialsInput: FC<{
|
|||
</>
|
||||
);
|
||||
|
||||
// Deselect credentials if they do not exist (e.g. provider was changed)
|
||||
if (
|
||||
selectedCredentials &&
|
||||
!savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.some((c) => c.id === selectedCredentials.id)
|
||||
) {
|
||||
onSelectCredentials(undefined);
|
||||
}
|
||||
|
||||
// No saved credentials yet
|
||||
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
|
||||
return (
|
||||
<>
|
||||
<span className="text-m green mb-0 text-gray-900">Credentials</span>
|
||||
<div className={cn("flex flex-row space-x-2", className)}>
|
||||
{supportsOAuth2 && (
|
||||
<Button onClick={handleOAuthLogin}>
|
||||
|
@ -248,6 +257,25 @@ export const CredentialsInput: FC<{
|
|||
);
|
||||
}
|
||||
|
||||
const singleCredential =
|
||||
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
|
||||
? savedApiKeys[0]
|
||||
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
|
||||
? savedOAuthCredentials[0]
|
||||
: null;
|
||||
|
||||
if (singleCredential) {
|
||||
if (!selectedCredentials) {
|
||||
onSelectCredentials({
|
||||
id: singleCredential.id,
|
||||
type: singleCredential.type,
|
||||
provider,
|
||||
title: singleCredential.title,
|
||||
});
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleValueChange(newValue: string) {
|
||||
if (newValue === "sign-in") {
|
||||
// Trigger OAuth2 sign in flow
|
||||
|
@ -263,7 +291,7 @@ export const CredentialsInput: FC<{
|
|||
onSelectCredentials({
|
||||
id: selectedCreds.id,
|
||||
type: selectedCreds.type,
|
||||
provider: schema.credentials_provider,
|
||||
provider: provider,
|
||||
// title: customTitle, // TODO: add input for title
|
||||
});
|
||||
}
|
||||
|
@ -272,6 +300,7 @@ export const CredentialsInput: FC<{
|
|||
// Saved credentials exist
|
||||
return (
|
||||
<>
|
||||
<span className="text-m green mb-0 text-gray-900">Credentials</span>
|
||||
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder} />
|
||||
|
|
|
@ -20,16 +20,18 @@ const CREDENTIALS_PROVIDER_NAMES = Object.values(
|
|||
|
||||
// --8<-- [start:CredentialsProviderNames]
|
||||
const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||
anthropic: "Anthropic",
|
||||
discord: "Discord",
|
||||
d_id: "D-ID",
|
||||
github: "GitHub",
|
||||
google: "Google",
|
||||
google_maps: "Google Maps",
|
||||
groq: "Groq",
|
||||
ideogram: "Ideogram",
|
||||
jina: "Jina",
|
||||
medium: "Medium",
|
||||
llm: "LLM",
|
||||
notion: "Notion",
|
||||
ollama: "Ollama",
|
||||
openai: "OpenAI",
|
||||
openweathermap: "OpenWeatherMap",
|
||||
pinecone: "Pinecone",
|
||||
|
|
|
@ -608,7 +608,15 @@ const NodeStringInput: FC<{
|
|||
className,
|
||||
displayName,
|
||||
}) => {
|
||||
value ||= schema.default || "";
|
||||
if (!value) {
|
||||
value = schema.default || "";
|
||||
// Force update hardcodedData so discriminators can update
|
||||
// e.g. credentials update when provider changes
|
||||
// this won't happen if the value is only set here to schema.default
|
||||
if (schema.default) {
|
||||
handleInputChange(selfKey, value);
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div className={className}>
|
||||
{schema.enum ? (
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import { useContext } from "react";
|
||||
import { CustomNodeData } from "@/components/CustomNode";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { Node, useNodeId, useNodesData } from "@xyflow/react";
|
||||
import {
|
||||
CredentialsProviderData,
|
||||
CredentialsProvidersContext,
|
||||
} from "@/components/integrations/credentials-provider";
|
||||
import { getValue } from "@/lib/utils";
|
||||
|
||||
export type CredentialsData =
|
||||
| {
|
||||
|
@ -34,15 +38,22 @@ export default function useCredentials(): CredentialsData | null {
|
|||
const credentialsSchema = data.inputSchema.properties
|
||||
.credentials as BlockIOCredentialsSubSchema;
|
||||
|
||||
const discriminatorValue: CredentialsProviderName | null =
|
||||
(credentialsSchema.discriminator &&
|
||||
credentialsSchema.discriminator_mapping![
|
||||
getValue(credentialsSchema.discriminator, data.hardcodedValues)
|
||||
]) ||
|
||||
null;
|
||||
|
||||
const providerName =
|
||||
discriminatorValue || credentialsSchema.credentials_provider;
|
||||
const provider = allProviders ? allProviders[providerName] : null;
|
||||
|
||||
// If block input schema doesn't have credentials, return null
|
||||
if (!credentialsSchema) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const provider = allProviders
|
||||
? allProviders[credentialsSchema?.credentials_provider]
|
||||
: null;
|
||||
|
||||
const supportsApiKey =
|
||||
credentialsSchema.credentials_types.includes("api_key");
|
||||
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
||||
|
@ -68,6 +79,7 @@ export default function useCredentials(): CredentialsData | null {
|
|||
|
||||
return {
|
||||
...provider,
|
||||
provider: providerName,
|
||||
schema: credentialsSchema,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
|
|
|
@ -98,16 +98,18 @@ export type CredentialsType = "api_key" | "oauth2";
|
|||
|
||||
// --8<-- [start:BlockIOCredentialsSubSchema]
|
||||
export const PROVIDER_NAMES = {
|
||||
ANTHROPIC: "anthropic",
|
||||
D_ID: "d_id",
|
||||
DISCORD: "discord",
|
||||
GITHUB: "github",
|
||||
GOOGLE: "google",
|
||||
GOOGLE_MAPS: "google_maps",
|
||||
GROQ: "groq",
|
||||
IDEOGRAM: "ideogram",
|
||||
JINA: "jina",
|
||||
LLM: "llm",
|
||||
MEDIUM: "medium",
|
||||
NOTION: "notion",
|
||||
OLLAMA: "ollama",
|
||||
OPENAI: "openai",
|
||||
OPENWEATHERMAP: "openweathermap",
|
||||
PINECONE: "pinecone",
|
||||
|
@ -124,6 +126,8 @@ export type BlockIOCredentialsSubSchema = BlockIOSubSchemaMeta & {
|
|||
credentials_provider: CredentialsProviderName;
|
||||
credentials_scopes?: string[];
|
||||
credentials_types: Array<CredentialsType>;
|
||||
discriminator?: string;
|
||||
discriminator_mapping?: { [key: string]: CredentialsProviderName };
|
||||
};
|
||||
|
||||
export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
|
||||
|
|
|
@ -313,3 +313,48 @@ export function findNewlyAddedBlockCoordinates(
|
|||
y: 0,
|
||||
};
|
||||
}
|
||||
|
||||
type ParsedKey = { key: string; index?: number };
|
||||
|
||||
export function parseKeys(key: string): ParsedKey[] {
|
||||
const splits = key.split(/_@_|_#_|_\$_|\./);
|
||||
const keys: ParsedKey[] = [];
|
||||
let currentKey: string | null = null;
|
||||
|
||||
splits.forEach((split) => {
|
||||
const isInteger = /^\d+$/.test(split);
|
||||
if (!isInteger) {
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey });
|
||||
}
|
||||
currentKey = split;
|
||||
} else {
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey, index: parseInt(split, 10) });
|
||||
currentKey = null;
|
||||
} else {
|
||||
throw new Error("Invalid key format: array index without a key");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (currentKey !== null) {
|
||||
keys.push({ key: currentKey });
|
||||
}
|
||||
|
||||
return keys;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the value of a nested key in an object, handles arrays and objects.
|
||||
*/
|
||||
export function getValue(key: string, value: any) {
|
||||
const keys = parseKeys(key);
|
||||
return keys.reduce((acc, k) => {
|
||||
if (acc === undefined) return undefined;
|
||||
if (k.index !== undefined) {
|
||||
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
|
||||
}
|
||||
return acc[k.key];
|
||||
}, value);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue