Remove import time loading of config from llm_utils (#3245)

pull/3194/head^2
James Collins 2023-04-25 12:10:12 -07:00 committed by GitHub
parent 1806fc683d
commit 6fbac455d4
1 changed file with 20 additions and 15 deletions
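The pattern applied throughout the diff below: the module-level CFG = Config() is deleted, and each function that needs settings builds its own cfg = Config() when it is called, so importing llm_utils no longer loads configuration as a side effect. A minimal sketch of the idea, using a hypothetical Settings class in place of Auto-GPT's Config:

import os
from typing import Optional


class Settings:
    """Hypothetical stand-in for autogpt.config.Config, used only for illustration."""

    def __init__(self) -> None:
        # Reading the environment here is the "loading" this commit defers
        # from import time to call time.
        self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
        self.temperature = float(os.getenv("TEMPERATURE", "0"))


# Before: the module loaded configuration as soon as it was imported.
# CFG = Settings()

# After: configuration is loaded only when a function actually runs.
def call_ai_function(model: Optional[str] = None) -> str:
    cfg = Settings()  # constructed at call time, not at import time
    if model is None:
        model = cfg.smart_llm_model
    return model

Auto-GPT's Config is implemented as a singleton, so the per-function Config() calls should all resolve to the same underlying instance; the gain is that nothing is constructed until some function actually runs, which keeps imports cheap and side-effect free.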


@@ -13,8 +13,6 @@ from autogpt.config import Config
 from autogpt.logs import logger
 from autogpt.types.openai import Message
-CFG = Config()
 def retry_openai_api(
     num_retries: int = 10,
@@ -86,8 +84,9 @@ def call_ai_function(
     Returns:
         str: The response from the function
     """
+    cfg = Config()
     if model is None:
-        model = CFG.smart_llm_model
+        model = cfg.smart_llm_model
     # For each arg, if any are None, convert to "None":
     args = [str(arg) if arg is not None else "None" for arg in args]
     # parse args to comma separated string
@@ -109,7 +108,7 @@ def call_ai_function(
 def create_chat_completion(
     messages: List[Message], # type: ignore
     model: Optional[str] = None,
-    temperature: float = CFG.temperature,
+    temperature: float = None,
     max_tokens: Optional[int] = None,
 ) -> str:
     """Create a chat completion using the OpenAI API
@@ -123,13 +122,17 @@ def create_chat_completion(
     Returns:
         str: The response from the chat completion
     """
+    cfg = Config()
+    if temperature is None:
+        temperature = cfg.temperature
+
     num_retries = 10
     warned_user = False
-    if CFG.debug_mode:
+    if cfg.debug_mode:
         print(
             f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
         )
-    for plugin in CFG.plugins:
+    for plugin in cfg.plugins:
         if plugin.can_handle_chat_completion(
             messages=messages,
             model=model,
@@ -148,9 +151,9 @@ def create_chat_completion(
     for attempt in range(num_retries):
         backoff = 2 ** (attempt + 2)
         try:
-            if CFG.use_azure:
+            if cfg.use_azure:
                 response = api_manager.create_chat_completion(
-                    deployment_id=CFG.get_azure_deployment_id_for_model(model),
+                    deployment_id=cfg.get_azure_deployment_id_for_model(model),
                     model=model,
                     messages=messages,
                     temperature=temperature,
@@ -165,7 +168,7 @@ def create_chat_completion(
                 )
             break
         except RateLimitError:
-            if CFG.debug_mode:
+            if cfg.debug_mode:
                 print(
                     f"{Fore.RED}Error: ", f"Reached rate limit, passing...{Fore.RESET}"
                 )
@@ -180,7 +183,7 @@ def create_chat_completion(
                 raise
             if attempt == num_retries - 1:
                 raise
-            if CFG.debug_mode:
+            if cfg.debug_mode:
                 print(
                     f"{Fore.RED}Error: ",
                     f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
@@ -194,12 +197,12 @@ def create_chat_completion(
             + f"Try running Auto-GPT again, and if the problem the persists try running it with `{Fore.CYAN}--debug{Fore.RESET}`.",
         )
         logger.double_check()
-        if CFG.debug_mode:
+        if cfg.debug_mode:
             raise RuntimeError(f"Failed to get response after {num_retries} retries")
         else:
             quit(1)
     resp = response.choices[0].message["content"]
-    for plugin in CFG.plugins:
+    for plugin in cfg.plugins:
         if not plugin.can_handle_on_response():
             continue
         resp = plugin.on_response(resp)
@@ -215,11 +218,12 @@ def get_ada_embedding(text: str) -> List[int]:
     Returns:
         List[int]: The embedding.
     """
+    cfg = Config()
     model = "text-embedding-ada-002"
     text = text.replace("\n", " ")
-    if CFG.use_azure:
-        kwargs = {"engine": CFG.get_azure_deployment_id_for_model(model)}
+    if cfg.use_azure:
+        kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
     else:
         kwargs = {"model": model}
@@ -247,8 +251,9 @@ def create_embedding(
     Returns:
         openai.Embedding: The embedding object.
     """
+    cfg = Config()
     return openai.Embedding.create(
         input=[text],
-        api_key=CFG.openai_api_key,
+        api_key=cfg.openai_api_key,
         **kwargs,
     )
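The create_chat_completion signature change follows from the same goal: a default of temperature=CFG.temperature is evaluated once at import time, so the function would keep whatever value the config held when llm_utils was first imported. The diff swaps it for a None sentinel that is resolved inside the function body on every call. A small sketch of that idiom, with a hypothetical FakeConfig standing in for the real Config:

from typing import Optional


class FakeConfig:
    """Hypothetical config object, used only for this illustration."""

    temperature = 0.7


def create_completion(prompt: str, temperature: Optional[float] = None) -> str:
    # Resolve the default lazily, so the config is read per call
    # rather than once when the module is imported.
    cfg = FakeConfig()
    if temperature is None:
        temperature = cfg.temperature
    return f"prompt={prompt!r}, temperature={temperature}"


print(create_completion("hello"))       # falls back to cfg.temperature
print(create_completion("hello", 0.0))  # an explicit argument still wins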