From 227cf41612f577f1d1406cb97022c4ce3144027a Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 14 Jun 2024 20:56:03 +0200 Subject: [PATCH] fix(agent, forge): Make LLM API key check provider-agnostic (#7220) * Rename `assert_config_has_openai_api_key` to `assert_config_has_required_llm_api_keys` * Make OpenAI credential check conditional (only if an OpenAI model is selected in the config) * Implement checks for Groq and Anthropic credentials * Use API calls for Groq and OpenAI credential checks to make sure the keys are valid --- autogpt/autogpt/app/main.py | 12 +-- forge/forge/config/__init__.py | 4 +- forge/forge/config/config.py | 132 +++++++++++++++++++++------------ 3 files changed, 93 insertions(+), 55 deletions(-) diff --git a/autogpt/autogpt/app/main.py b/autogpt/autogpt/app/main.py index 0642afb48..1b413a2c5 100644 --- a/autogpt/autogpt/app/main.py +++ b/autogpt/autogpt/app/main.py @@ -21,7 +21,11 @@ from forge.components.code_executor.code_executor import ( ) from forge.config.ai_directives import AIDirectives from forge.config.ai_profile import AIProfile -from forge.config.config import Config, ConfigBuilder, assert_config_has_openai_api_key +from forge.config.config import ( + Config, + ConfigBuilder, + assert_config_has_required_llm_api_keys, +) from forge.file_storage import FileStorageBackendName, get_storage from forge.llm.providers import MultiProvider from forge.logging.config import configure_logging @@ -98,8 +102,7 @@ async def run_auto_gpt( tts_config=config.tts_config, ) - # TODO: fill in llm values here - assert_config_has_openai_api_key(config) + await assert_config_has_required_llm_api_keys(config) await apply_overrides_to_config( config=config, @@ -380,8 +383,7 @@ async def run_auto_gpt_server( tts_config=config.tts_config, ) - # TODO: fill in llm values here - assert_config_has_openai_api_key(config) + await assert_config_has_required_llm_api_keys(config) await apply_overrides_to_config( config=config, diff --git a/forge/forge/config/__init__.py b/forge/forge/config/__init__.py index cfa66121f..82dae4fd8 100644 --- a/forge/forge/config/__init__.py +++ b/forge/forge/config/__init__.py @@ -3,10 +3,10 @@ This module contains configuration models and helpers for AutoGPT Forge. """ from .ai_directives import AIDirectives from .ai_profile import AIProfile -from .config import Config, ConfigBuilder, assert_config_has_openai_api_key +from .config import Config, ConfigBuilder, assert_config_has_required_llm_api_keys __all__ = [ - "assert_config_has_openai_api_key", + "assert_config_has_required_llm_api_keys", "AIProfile", "AIDirectives", "Config", diff --git a/forge/forge/config/config.py b/forge/forge/config/config.py index 27b57cb5a..17c43b389 100644 --- a/forge/forge/config/config.py +++ b/forge/forge/config/config.py @@ -7,8 +7,6 @@ import re from pathlib import Path from typing import Any, Optional, Union -import click -from colorama import Fore from pydantic import SecretStr, validator import forge @@ -208,55 +206,93 @@ class ConfigBuilder(Configurable[Config]): return config -def assert_config_has_openai_api_key(config: Config) -> None: - """Check if the OpenAI API key is set in config.py or as an environment variable.""" - key_pattern = r"^sk-(proj-)?\w{48}" - openai_api_key = ( - config.openai_credentials.api_key.get_secret_value() - if config.openai_credentials - else "" - ) +async def assert_config_has_required_llm_api_keys(config: Config) -> None: + """ + Check if API keys (if required) are set for the configured SMART_LLM and FAST_LLM. + """ + from pydantic import ValidationError - # If there's no credentials or empty API key, prompt the user to set it - if not openai_api_key: - logger.error( - "Please set your OpenAI API key in .env or as an environment variable." - ) - logger.info( - "You can get your key from https://platform.openai.com/account/api-keys" - ) - openai_api_key = click.prompt( - "Please enter your OpenAI API key if you have it", - default="", - show_default=False, - ) - openai_api_key = openai_api_key.strip() - if re.search(key_pattern, openai_api_key): - os.environ["OPENAI_API_KEY"] = openai_api_key - if config.openai_credentials: - config.openai_credentials.api_key = SecretStr(openai_api_key) - else: - config.openai_credentials = OpenAICredentials( - api_key=SecretStr(openai_api_key) + from forge.llm.providers.anthropic import AnthropicModelName + from forge.llm.providers.groq import GroqModelName + + if set((config.smart_llm, config.fast_llm)).intersection(AnthropicModelName): + from forge.llm.providers.anthropic import AnthropicCredentials + + try: + credentials = AnthropicCredentials.from_env() + except ValidationError as e: + if "api_key" in str(e): + logger.error( + "Set your Anthropic API key in .env or as an environment variable" ) - print("OpenAI API key successfully set!") - print( - f"{Fore.YELLOW}NOTE: The API key you've set is only temporary. " - f"For longer sessions, please set it in the .env file{Fore.RESET}" + logger.info( + "For further instructions: " + "https://docs.agpt.co/autogpt/setup/#anthropic" + ) + + raise ValueError("Anthropic is unavailable: can't load credentials") from e + + key_pattern = r"^sk-ant-api03-[\w\-]{95}" + + # If key is set, but it looks invalid + if not re.search(key_pattern, credentials.api_key.get_secret_value()): + logger.warning( + "Possibly invalid Anthropic API key! " + f"Configured Anthropic API key does not match pattern '{key_pattern}'. " + "If this is a valid key, please report this warning to the maintainers." ) - else: - print(f"{Fore.RED}Invalid OpenAI API key{Fore.RESET}") - exit(1) - # If key is set, but it looks invalid - elif not re.search(key_pattern, openai_api_key): - logger.error( - "Invalid OpenAI API key! " - "Please set your OpenAI API key in .env or as an environment variable." - ) - logger.info( - "You can get your key from https://platform.openai.com/account/api-keys" - ) - exit(1) + + if set((config.smart_llm, config.fast_llm)).intersection(GroqModelName): + from groq import AuthenticationError + + from forge.llm.providers.groq import GroqProvider + + try: + groq = GroqProvider() + await groq.get_available_models() + except ValidationError as e: + if "api_key" not in str(e): + raise + + logger.error("Set your Groq API key in .env or as an environment variable") + logger.info( + "For further instructions: https://docs.agpt.co/autogpt/setup/#groq" + ) + raise ValueError("Groq is unavailable: can't load credentials") + except AuthenticationError as e: + logger.error("The Groq API key is invalid!") + logger.info( + "For instructions to get and set a new API key: " + "https://docs.agpt.co/autogpt/setup/#groq" + ) + raise ValueError("Groq is unavailable: invalid API key") from e + + if set((config.smart_llm, config.fast_llm)).intersection(OpenAIModelName): + from openai import AuthenticationError + + from forge.llm.providers.openai import OpenAIProvider + + try: + openai = OpenAIProvider() + await openai.get_available_models() + except ValidationError as e: + if "api_key" not in str(e): + raise + + logger.error( + "Set your OpenAI API key in .env or as an environment variable" + ) + logger.info( + "For further instructions: https://docs.agpt.co/autogpt/setup/#openai" + ) + raise ValueError("OpenAI is unavailable: can't load credentials") + except AuthenticationError as e: + logger.error("The OpenAI API key is invalid!") + logger.info( + "For instructions to get and set a new API key: " + "https://docs.agpt.co/autogpt/setup/#openai" + ) + raise ValueError("OpenAI is unavailable: invalid API key") from e def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]: