feat(platform): Simplify Credentials UX (#8524)

- Change `provider` of default credentials to actual provider names (e.g. `anthropic`), remove `llm` provider
- Add `discriminator` and `discriminator_mapping` to `CredentialsField` that allows to filter credentials input to only allow  providers for matching models in `useCredentials` hook (thanks @ntindle for the idea!); e.g. user chooses `GPT4_TURBO` so then only OpenAI credentials are allowed
- Choose credentials automatically and hide credentials input on the node completely if there's only one possible option
- Move `getValue` and `parseKeys` to utils
- Add `ANTHROPIC`, `GROQ` and `OLLAMA` to providers in frontend `types.ts`
- Add `hidden` field to credentials that is used for default system keys to hide them in user profile
- Now `provider` field in `CredentialsField` can accept multiple providers as a list

-----------------
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
pull/8639/head
Krzysztof Czerwinski 2024-11-12 15:55:48 +00:00 committed by GitHub
parent ef7e50403e
commit e907ffda6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 168 additions and 82 deletions

View File

@ -46,21 +46,21 @@ replicate_credentials = APIKeyCredentials(
)
openai_credentials = APIKeyCredentials(
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
provider="llm",
provider="openai",
api_key=SecretStr(settings.secrets.openai_api_key),
title="Use Credits for OpenAI",
expires_at=None,
)
anthropic_credentials = APIKeyCredentials(
id="24e5d942-d9e3-4798-8151-90143ee55629",
provider="llm",
provider="anthropic",
api_key=SecretStr(settings.secrets.anthropic_api_key),
title="Use Credits for Anthropic",
expires_at=None,
)
groq_credentials = APIKeyCredentials(
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
provider="llm",
provider="groq",
api_key=SecretStr(settings.secrets.groq_api_key),
title="Use Credits for Groq",
expires_at=None,

View File

@ -30,11 +30,12 @@ logger = logging.getLogger(__name__)
# "ollama": BlockSecret(value=""),
# }
AICredentials = CredentialsMetaInput[Literal["llm"], Literal["api_key"]]
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama"]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="llm",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
expires_at=None,
@ -50,8 +51,12 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider="llm",
provider=["anthropic", "groq", "openai", "ollama"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)

View File

@ -152,10 +152,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
def CredentialsField(
provider: CP,
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
@ -167,9 +169,13 @@ def CredentialsField(
json_extra = {
k: v
for k, v in {
"credentials_provider": provider,
"credentials_provider": (
[provider] if isinstance(provider, str) else provider
),
"credentials_scopes": list(required_scopes) or None, # omit if empty
"credentials_types": list(supported_credential_types),
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}

View File

@ -4,17 +4,14 @@ import { useSupabase } from "@/components/SupabaseProvider";
import { Button } from "@/components/ui/button";
import useUser from "@/hooks/useUser";
import { useRouter } from "next/navigation";
import { useCallback, useContext } from "react";
import { useCallback, useContext, useMemo } from "react";
import { FaSpinner } from "react-icons/fa";
import { Separator } from "@/components/ui/separator";
import { useToast } from "@/components/ui/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { LogOutIcon, Trash2Icon } from "lucide-react";
import { providerIcons } from "@/components/integrations/credentials-input";
import {
CredentialsProviderName,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
Table,
TableBody,
@ -23,6 +20,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { CredentialsProviderName } from "@/lib/autogpt-server-api";
export default function PrivatePage() {
const { user, isLoading, error } = useUser();
@ -62,7 +60,22 @@ export default function PrivatePage() {
[providers, toast],
);
if (isLoading || !providers || !providers) {
//TODO: remove when the way system credentials are handled is updated
// This contains ids for built-in "Use Credits for X" credentials
const hiddenCredentials = useMemo(
() => [
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
"53c25cb8-e3ee-465c-a4d1-e75a4c899c2a", // OpenAI
"24e5d942-d9e3-4798-8151-90143ee55629", // Anthropic
"4ec22295-8f97-4dd1-b42b-2c6957a02545", // Groq
"7f7b0654-c36b-4565-8fa7-9a52575dfae2", // D-ID
],
[],
);
if (isLoading || !providers) {
return (
<div className="flex h-[80vh] items-center justify-center">
<FaSpinner className="mr-2 h-16 w-16 animate-spin" />
@ -76,15 +89,15 @@ export default function PrivatePage() {
}
const allCredentials = Object.values(providers).flatMap((provider) =>
[...provider.savedOAuthCredentials, ...provider.savedApiKeys].map(
(credentials) => ({
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
.filter((cred) => !hiddenCredentials.includes(cred.id))
.map((credentials) => ({
...credentials,
provider: provider.provider,
providerName: provider.providerName,
ProviderIcon: providerIcons[provider.provider],
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
}),
),
})),
);
return (

View File

@ -18,7 +18,13 @@ import {
BlockUIType,
BlockCost,
} from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import {
beautifyString,
cn,
getValue,
parseKeys,
setNestedProperty,
} from "@/lib/utils";
import { Button } from "@/components/ui/button";
import { Switch } from "@/components/ui/switch";
import { history } from "./history";
@ -36,8 +42,6 @@ import * as Separator from "@radix-ui/react-separator";
import * as ContextMenu from "@radix-ui/react-context-menu";
import { DotsVerticalIcon, TrashIcon, CopyIcon } from "@radix-ui/react-icons";
type ParsedKey = { key: string; index?: number };
export type ConnectionData = Array<{
edge_id: string;
source: string;
@ -178,7 +182,7 @@ export function CustomNode({
className=""
selfKey={noteKey}
schema={noteSchema as BlockIOStringSubSchema}
value={getValue(noteKey)}
value={getValue(noteKey, data.hardcodedValues)}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
error={data.errors?.[noteKey] ?? ""}
@ -228,7 +232,7 @@ export function CustomNode({
nodeId={id}
propKey={getInputPropKey(propKey)}
propSchema={propSchema}
currentValue={getValue(getInputPropKey(propKey))}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
@ -283,48 +287,6 @@ export function CustomNode({
setErrors({ ...errors });
};
// Helper function to parse keys with array indices
//TODO move to utils
const parseKeys = (key: string): ParsedKey[] => {
const splits = key.split(/_@_|_#_|_\$_|\./);
const keys: ParsedKey[] = [];
let currentKey: string | null = null;
splits.forEach((split) => {
const isInteger = /^\d+$/.test(split);
if (!isInteger) {
if (currentKey !== null) {
keys.push({ key: currentKey });
}
currentKey = split;
} else {
if (currentKey !== null) {
keys.push({ key: currentKey, index: parseInt(split, 10) });
currentKey = null;
} else {
throw new Error("Invalid key format: array index without a key");
}
}
});
if (currentKey !== null) {
keys.push({ key: currentKey });
}
return keys;
};
const getValue = (key: string) => {
const keys = parseKeys(key);
return keys.reduce((acc, k) => {
if (acc === undefined) return undefined;
if (k.index !== undefined) {
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
}
return acc[k.key];
}, data.hardcodedValues as any);
};
const isHandleConnected = (key: string) => {
return (
data.connections &&
@ -347,7 +309,7 @@ export function CustomNode({
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);

View File

@ -46,16 +46,18 @@ export const providerIcons: Record<
CredentialsProviderName,
React.FC<{ className?: string }>
> = {
anthropic: fallbackIcon,
github: FaGithub,
google: FaGoogle,
groq: fallbackIcon,
notion: NotionLogoIcon,
discord: FaDiscord,
d_id: fallbackIcon,
google_maps: FaGoogle,
jina: fallbackIcon,
ideogram: fallbackIcon,
llm: fallbackIcon,
medium: FaMedium,
ollama: fallbackIcon,
openai: fallbackIcon,
openweathermap: fallbackIcon,
pinecone: fallbackIcon,
@ -80,7 +82,7 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
export const CredentialsInput: FC<{
className?: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredentials: (newValue: CredentialsMetaInput) => void;
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
}> = ({ className, selectedCredentials, onSelectCredentials }) => {
const api = useMemo(() => new AutoGPTServerAPI(), []);
const credentials = useCredentials();
@ -91,14 +93,10 @@ export const CredentialsInput: FC<{
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);
if (!credentials) {
if (!credentials || credentials.isLoading) {
return null;
}
if (credentials.isLoading) {
return <div>Loading...</div>;
}
const {
schema,
provider,
@ -222,10 +220,21 @@ export const CredentialsInput: FC<{
</>
);
// Deselect credentials if they do not exist (e.g. provider was changed)
if (
selectedCredentials &&
!savedApiKeys
.concat(savedOAuthCredentials)
.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}
// No saved credentials yet
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<div className={cn("flex flex-row space-x-2", className)}>
{supportsOAuth2 && (
<Button onClick={handleOAuthLogin}>
@ -248,6 +257,25 @@ export const CredentialsInput: FC<{
);
}
const singleCredential =
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
? savedApiKeys[0]
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
? savedOAuthCredentials[0]
: null;
if (singleCredential) {
if (!selectedCredentials) {
onSelectCredentials({
id: singleCredential.id,
type: singleCredential.type,
provider,
title: singleCredential.title,
});
}
return null;
}
function handleValueChange(newValue: string) {
if (newValue === "sign-in") {
// Trigger OAuth2 sign in flow
@ -263,7 +291,7 @@ export const CredentialsInput: FC<{
onSelectCredentials({
id: selectedCreds.id,
type: selectedCreds.type,
provider: schema.credentials_provider,
provider: provider,
// title: customTitle, // TODO: add input for title
});
}
@ -272,6 +300,7 @@ export const CredentialsInput: FC<{
// Saved credentials exist
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder} />

View File

@ -20,16 +20,18 @@ const CREDENTIALS_PROVIDER_NAMES = Object.values(
// --8<-- [start:CredentialsProviderNames]
const providerDisplayNames: Record<CredentialsProviderName, string> = {
anthropic: "Anthropic",
discord: "Discord",
d_id: "D-ID",
github: "GitHub",
google: "Google",
google_maps: "Google Maps",
groq: "Groq",
ideogram: "Ideogram",
jina: "Jina",
medium: "Medium",
llm: "LLM",
notion: "Notion",
ollama: "Ollama",
openai: "OpenAI",
openweathermap: "OpenWeatherMap",
pinecone: "Pinecone",

View File

@ -608,7 +608,15 @@ const NodeStringInput: FC<{
className,
displayName,
}) => {
value ||= schema.default || "";
if (!value) {
value = schema.default || "";
// Force update hardcodedData so discriminators can update
// e.g. credentials update when provider changes
// this won't happen if the value is only set here to schema.default
if (schema.default) {
handleInputChange(selfKey, value);
}
}
return (
<div className={className}>
{schema.enum ? (

View File

@ -1,11 +1,15 @@
import { useContext } from "react";
import { CustomNodeData } from "@/components/CustomNode";
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
import {
BlockIOCredentialsSubSchema,
CredentialsProviderName,
} from "@/lib/autogpt-server-api";
import { Node, useNodeId, useNodesData } from "@xyflow/react";
import {
CredentialsProviderData,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { getValue } from "@/lib/utils";
export type CredentialsData =
| {
@ -34,15 +38,22 @@ export default function useCredentials(): CredentialsData | null {
const credentialsSchema = data.inputSchema.properties
.credentials as BlockIOCredentialsSubSchema;
const discriminatorValue: CredentialsProviderName | null =
(credentialsSchema.discriminator &&
credentialsSchema.discriminator_mapping![
getValue(credentialsSchema.discriminator, data.hardcodedValues)
]) ||
null;
const providerName =
discriminatorValue || credentialsSchema.credentials_provider;
const provider = allProviders ? allProviders[providerName] : null;
// If block input schema doesn't have credentials, return null
if (!credentialsSchema) {
return null;
}
const provider = allProviders
? allProviders[credentialsSchema?.credentials_provider]
: null;
const supportsApiKey =
credentialsSchema.credentials_types.includes("api_key");
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
@ -68,6 +79,7 @@ export default function useCredentials(): CredentialsData | null {
return {
...provider,
provider: providerName,
schema: credentialsSchema,
supportsApiKey,
supportsOAuth2,

View File

@ -98,16 +98,18 @@ export type CredentialsType = "api_key" | "oauth2";
// --8<-- [start:BlockIOCredentialsSubSchema]
export const PROVIDER_NAMES = {
ANTHROPIC: "anthropic",
D_ID: "d_id",
DISCORD: "discord",
GITHUB: "github",
GOOGLE: "google",
GOOGLE_MAPS: "google_maps",
GROQ: "groq",
IDEOGRAM: "ideogram",
JINA: "jina",
LLM: "llm",
MEDIUM: "medium",
NOTION: "notion",
OLLAMA: "ollama",
OPENAI: "openai",
OPENWEATHERMAP: "openweathermap",
PINECONE: "pinecone",
@ -124,6 +126,8 @@ export type BlockIOCredentialsSubSchema = BlockIOSubSchemaMeta & {
credentials_provider: CredentialsProviderName;
credentials_scopes?: string[];
credentials_types: Array<CredentialsType>;
discriminator?: string;
discriminator_mapping?: { [key: string]: CredentialsProviderName };
};
export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {

View File

@ -313,3 +313,48 @@ export function findNewlyAddedBlockCoordinates(
y: 0,
};
}
type ParsedKey = { key: string; index?: number };
export function parseKeys(key: string): ParsedKey[] {
const splits = key.split(/_@_|_#_|_\$_|\./);
const keys: ParsedKey[] = [];
let currentKey: string | null = null;
splits.forEach((split) => {
const isInteger = /^\d+$/.test(split);
if (!isInteger) {
if (currentKey !== null) {
keys.push({ key: currentKey });
}
currentKey = split;
} else {
if (currentKey !== null) {
keys.push({ key: currentKey, index: parseInt(split, 10) });
currentKey = null;
} else {
throw new Error("Invalid key format: array index without a key");
}
}
});
if (currentKey !== null) {
keys.push({ key: currentKey });
}
return keys;
}
/**
* Get the value of a nested key in an object, handles arrays and objects.
*/
export function getValue(key: string, value: any) {
const keys = parseKeys(key);
return keys.reduce((acc, k) => {
if (acc === undefined) return undefined;
if (k.index !== undefined) {
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
}
return acc[k.key];
}, value);
}