refactor(agent/openai): Upgrade OpenAI library to v1
- Update `openai` dependency from ^v0.27.10 to ^v1.7.2
- Update poetry.lock
- Update code for changed endpoints and new output types of OpenAI library
- Replace uses of `AssistantChatMessageDict` by `AssistantChatMessage`
- Update `PromptStrategy`, `BaseAgent`, and all of their subclasses accordingly
- Update `OpenAIProvider`, `OpenAICredentials`, azure.yaml.template, .env.template and test_config.py to work with new separate `AzureOpenAI` client
- Remove `_OpenAIRetryHandler` and implement retry mechanism with `tenacity`
- Rewrite pytest fixture `cached_openai_client` (renamed from `patched_api_requestor`) for OpenAI v1 library
parent 39fd1d6be1
commit f2595af362
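For orientation before the file-by-file diff: the upgrade replaces the removed module-level endpoints (`openai.ChatCompletion`, `openai.Embedding`) with an explicit v1 client object, and swaps the hand-rolled `_OpenAIRetryHandler` for a `tenacity` retry policy. The sketch below condenses that pattern from the `OpenAIProvider` changes further down; the `credentials` argument stands in for the updated `OpenAICredentials` object, and the snippet is illustrative rather than the exact implementation.

```python
import logging

import tenacity
from openai import APIStatusError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError


def make_client(credentials) -> AsyncOpenAI:
    """Build an OpenAI v1 client from the (assumed) credentials object."""
    kwargs = credentials.get_api_access_kwargs()  # api_key, base_url, organization, ...
    if credentials.api_type == "azure":
        # Azure additionally needs api_version and azure_endpoint in kwargs
        return AsyncAzureOpenAI(**kwargs)
    return AsyncOpenAI(**kwargs)


def retry_api_request(func, retries: int = 7, logger: logging.Logger | None = None):
    """Retry on rate limits and 502 responses with exponential backoff."""
    return tenacity.retry(
        retry=(
            tenacity.retry_if_exception_type(RateLimitError)
            | tenacity.retry_if_exception(
                lambda e: isinstance(e, APIStatusError) and e.status_code == 502
            )
        ),
        wait=tenacity.wait_exponential(),
        stop=tenacity.stop_after_attempt(retries),
        after=tenacity.after_log(logger or logging.getLogger(__name__), logging.DEBUG),
    )(func)
```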
@@ -76,9 +76,11 @@ OPENAI_API_KEY=your-openai-api-key
## USE_AZURE - Use Azure OpenAI or not (Default: False)
# USE_AZURE=False

## AZURE_CONFIG_FILE - The path to the azure.yaml file, relative to the AutoGPT root directory. (Default: azure.yaml)
## AZURE_CONFIG_FILE - The path to the azure.yaml file, relative to the folder containing this file. (Default: azure.yaml)
# AZURE_CONFIG_FILE=azure.yaml

# AZURE_OPENAI_AD_TOKEN=
# AZURE_OPENAI_ENDPOINT=

################################################################################
### LLM MODELS

@@ -10,7 +10,7 @@ from autogpt.core.prompting import (
)
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
ChatModelProvider,
CompletionModelFunction,
@@ -186,7 +186,7 @@ class AgentProfileGenerator(PromptStrategy):

def parse_response_content(
self,
response_content: AssistantChatMessageDict,
response_content: AssistantChatMessage,
) -> tuple[AIProfile, AIDirectives]:
"""Parse the actual text response from the objective model.

@@ -198,16 +198,21 @@ class AgentProfileGenerator(PromptStrategy):

"""
try:
arguments = json_loads(
response_content["tool_calls"][0]["function"]["arguments"]
if not response_content.tool_calls:
raise ValueError(
f"LLM did not call {self._create_agent_function.name} function; "
"agent profile creation failed"
)
arguments: object = json_loads(
response_content.tool_calls[0].function.arguments
)
ai_profile = AIProfile(
ai_name=arguments.get("name"),
ai_role=arguments.get("description"),
)
ai_directives = AIDirectives(
best_practices=arguments["directives"].get("best_practices"),
constraints=arguments["directives"].get("constraints"),
best_practices=arguments.get("directives", {}).get("best_practices"),
constraints=arguments.get("directives", {}).get("constraints"),
resources=[],
)
except KeyError:

@@ -15,7 +15,7 @@ from pydantic import Field
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
ChatModelProvider,
)
@@ -172,14 +172,12 @@ class Agent(
return prompt

def parse_and_process_response(
self, llm_response: AssistantChatMessageDict, *args, **kwargs
self, llm_response: AssistantChatMessage, *args, **kwargs
) -> Agent.ThoughtProcessOutput:
for plugin in self.config.plugins:
if not plugin.can_handle_post_planning():
continue
llm_response["content"] = plugin.post_planning(
llm_response.get("content", "")
)
llm_response.content = plugin.post_planning(llm_response.content or "")

(
command_name,

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.core.prompting.base import PromptStrategy
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
@@ -410,7 +410,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
@abstractmethod
def parse_and_process_response(
self,
llm_response: AssistantChatMessageDict,
llm_response: AssistantChatMessage,
prompt: ChatPrompt,
scratchpad: PromptScratchpad,
) -> ThoughtProcessOutput:

@@ -21,7 +21,7 @@ from autogpt.core.prompting import (
PromptStrategy,
)
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
@@ -386,12 +386,12 @@ class OneShotAgentPromptStrategy(PromptStrategy):

def parse_response_content(
self,
response: AssistantChatMessageDict,
response: AssistantChatMessage,
) -> Agent.ThoughtProcessOutput:
if "content" not in response:
if not response.content:
raise InvalidAgentResponseError("Assistant response has no text content")

assistant_reply_dict = extract_dict_from_response(response["content"])
assistant_reply_dict = extract_dict_from_response(response.content)

_, errors = self.response_schema.validate_object(
object=assistant_reply_dict,
@@ -417,14 +417,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):

def extract_command(
assistant_reply_json: dict,
assistant_reply: AssistantChatMessageDict,
assistant_reply: AssistantChatMessage,
use_openai_functions_api: bool,
) -> tuple[str, dict[str, str]]:
"""Parse the response and return the command name and arguments

Args:
assistant_reply_json (dict): The response object from the AI
assistant_reply (ChatModelResponse): The model response from the AI
assistant_reply (AssistantChatMessage): The model response from the AI
config (Config): The config object

Returns:
@@ -436,13 +436,11 @@ def extract_command(
Exception: If any other error occurs
"""
if use_openai_functions_api:
if not assistant_reply.get("tool_calls"):
if not assistant_reply.tool_calls:
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply["tool_calls"][0]["function"]["name"],
"args": json.loads(
assistant_reply["tool_calls"][0]["function"]["arguments"]
),
"name": assistant_reply.tool_calls[0].function.name,
"args": json.loads(assistant_reply.tool_calls[0].function.arguments),
}
try:
if not isinstance(assistant_reply_json, dict):

@@ -1,4 +1,3 @@
import copy
import logging
import os
import pathlib
@@ -34,6 +33,7 @@ from autogpt.commands.system import finish
from autogpt.commands.user_interaction import ask_user
from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatModelProvider
from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.file_workspace import (
FileWorkspace,
FileWorkspaceBackendName,
@@ -414,8 +414,8 @@ class AgentProtocolServer:
"""
Configures the LLM provider with headers to link outgoing requests to the task.
"""
task_llm_provider = copy.deepcopy(self.llm_provider)
_extra_request_headers = task_llm_provider._configuration.extra_request_headers
task_llm_provider_config = self.llm_provider._configuration.copy(deep=True)
_extra_request_headers = task_llm_provider_config.extra_request_headers

_extra_request_headers["AP-TaskID"] = task.task_id
if step_id:
@@ -423,7 +423,15 @@ class AgentProtocolServer:
if task.additional_input and (user_id := task.additional_input.get("user_id")):
_extra_request_headers["AutoGPT-UserID"] = user_id

return task_llm_provider
if isinstance(self.llm_provider, OpenAIProvider):
settings = self.llm_provider._settings.copy()
settings.configuration = task_llm_provider_config
return OpenAIProvider(
settings=settings,
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
)

return self.llm_provider


def task_agent_id(task_id: str | int) -> str:

@@ -190,9 +190,9 @@ def check_model(
) -> str:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
api_manager = ApiManager()
models = api_manager.get_models(**api_credentials.get_api_access_kwargs(model_name))
models = api_manager.get_models(api_credentials)

if any(model_name in m["id"] for m in models):
if any(model_name == m.id for m in models):
return model_name

logger.warning(

@@ -8,8 +8,8 @@ import uuid
from base64 import b64decode
from pathlib import Path

import openai
import requests
from openai import OpenAI
from PIL import Image

from autogpt.agents.agent import Agent
@@ -142,17 +142,18 @@ def generate_image_with_dalle(
)
size = closest

response = openai.Image.create(
response = OpenAI(
api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
).images.generate(
prompt=prompt,
n=1,
size=f"{size}x{size}",
response_format="b64_json",
api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value(),
)

logger.info(f"Image Generated for prompt:{prompt}")

image_data = b64decode(response["data"][0]["b64_json"])
image_data = b64decode(response.data[0].b64_json)

with open(output_file, mode="wb") as png:
png.write(image_data)

@@ -6,7 +6,7 @@ from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
@@ -178,7 +178,7 @@ class InitialPlan(PromptStrategy):

def parse_response_content(
self,
response_content: AssistantChatMessageDict,
response_content: AssistantChatMessage,
) -> dict:
"""Parse the actual text response from the objective model.

@@ -189,8 +189,13 @@ class InitialPlan(PromptStrategy):
The parsed response.
"""
try:
parsed_response = json_loads(
response_content["tool_calls"][0]["function"]["arguments"]
if not response_content.tool_calls:
raise ValueError(
f"LLM did not call {self._create_plan_function.name} function; "
"plan creation failed"
)
parsed_response: object = json_loads(
response_content.tool_calls[0].function.arguments
)
parsed_response["task_list"] = [
Task.parse_obj(task) for task in parsed_response["task_list"]

@@ -5,7 +5,7 @@ from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
@@ -124,7 +124,7 @@ class NameAndGoals(PromptStrategy):

def parse_response_content(
self,
response_content: AssistantChatMessageDict,
response_content: AssistantChatMessage,
) -> dict:
"""Parse the actual text response from the objective model.

@@ -136,8 +136,13 @@ class NameAndGoals(PromptStrategy):

"""
try:
if not response_content.tool_calls:
raise ValueError(
f"LLM did not call {self._create_agent_function} function; "
"agent profile creation failed"
)
parsed_response = json_loads(
response_content["tool_calls"][0]["function"]["arguments"]
response_content.tool_calls[0].function.arguments
)
except KeyError:
logger.debug(f"Failed to parse this response content: {response_content}")

@@ -6,7 +6,7 @@ from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessageDict,
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
@@ -171,7 +171,7 @@ class NextAbility(PromptStrategy):

def parse_response_content(
self,
response_content: AssistantChatMessageDict,
response_content: AssistantChatMessage,
) -> dict:
"""Parse the actual text response from the objective model.

@@ -183,9 +183,12 @@ class NextAbility(PromptStrategy):

"""
try:
function_name = response_content["tool_calls"][0]["function"]["name"]
if not response_content.tool_calls:
raise ValueError("LLM did not call any function")

function_name = response_content.tool_calls[0].function.name
function_arguments = json_loads(
response_content["tool_calls"][0]["function"]["arguments"]
response_content.tool_calls[0].function.arguments
)
parsed_response = {
"motivation": function_arguments.pop("motivation"),

@@ -1,7 +1,7 @@
import abc

from autogpt.core.configuration import SystemConfiguration
from autogpt.core.resource.model_providers import AssistantChatMessageDict
from autogpt.core.resource.model_providers import AssistantChatMessage

from .schema import ChatPrompt, LanguageModelClassification

@@ -19,5 +19,5 @@ class PromptStrategy(abc.ABC):
...

@abc.abstractmethod
def parse_response_content(self, response_content: AssistantChatMessageDict):
def parse_response_content(self, response_content: AssistantChatMessage):
...

@@ -1,21 +1,22 @@
import enum
import functools
import logging
import math
import os
import time
from pathlib import Path
from typing import Callable, Optional, ParamSpec, TypeVar
from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar

import openai
import tenacity
import tiktoken
import yaml
from openai.error import APIError, RateLimitError
from openai._exceptions import APIStatusError, RateLimitError
from openai.types import CreateEmbeddingResponse
from openai.types.chat import ChatCompletion
from pydantic import SecretStr

from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
AssistantChatMessage,
AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
ChatModelInfo,
@@ -166,7 +167,6 @@ OPEN_AI_MODELS = {

class OpenAIConfiguration(ModelProviderConfiguration):
fix_failed_parse_tries: int = UserConfigurable(3)
pass


class OpenAICredentials(ModelProviderCredentials):
@@ -187,32 +187,45 @@ class OpenAICredentials(ModelProviderCredentials):
),
)
api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
azure_endpoint: Optional[SecretStr] = None
azure_model_to_deploy_id_map: Optional[dict[str, str]] = None

def get_api_access_kwargs(self, model: str = "") -> dict[str, str]:
credentials = {k: v for k, v in self.unmasked().items() if type(v) is str}
def get_api_access_kwargs(self) -> dict[str, str]:
kwargs = {
k: (v.get_secret_value() if type(v) is SecretStr else v)
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
"organization": self.organization,
}.items()
if v is not None
}
if self.api_type == "azure":
kwargs["api_version"] = self.api_version
kwargs["azure_endpoint"] = self.azure_endpoint
return kwargs

def get_model_access_kwargs(self, model: str) -> dict[str, str]:
kwargs = {"model": model}
if self.api_type == "azure" and model:
azure_credentials = self._get_azure_access_kwargs(model)
credentials.update(azure_credentials)
return credentials
azure_kwargs = self._get_azure_access_kwargs(model)
kwargs.update(azure_kwargs)
return kwargs

def load_azure_config(self, config_file: Path) -> None:
with open(config_file) as file:
config_params = yaml.load(file, Loader=yaml.FullLoader) or {}

try:
assert (
azure_api_base := config_params.get("azure_api_base", "")
) != "", "Azure API base URL not set"
assert config_params.get(
"azure_model_map", {}
), "Azure model->deployment_id map is empty"
except AssertionError as e:
raise ValueError(*e.args)

self.api_base = SecretStr(azure_api_base)
self.api_type = config_params.get("azure_api_type", "azure")
self.api_version = config_params.get("azure_api_version", "")
self.azure_endpoint = config_params.get("azure_endpoint")
self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")

def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
@@ -225,10 +238,7 @@ class OpenAICredentials(ModelProviderCredentials):
raise ValueError(f"No Azure deployment ID configured for model '{model}'")
deployment_id = self.azure_model_to_deploy_id_map[model]

if model in OPEN_AI_EMBEDDING_MODELS:
return {"engine": deployment_id}
else:
return {"deployment_id": deployment_id}
return {"model": deployment_id}


class OpenAIModelProviderBudget(ModelProviderBudget):
@@ -273,21 +283,26 @@ class OpenAIProvider(
settings: OpenAISettings,
logger: logging.Logger,
):
self._settings = settings

assert settings.credentials, "Cannot create OpenAIProvider without credentials"
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget

if self._credentials.api_type == "azure":
from openai import AsyncAzureOpenAI

# API key and org (if configured) are passed, the rest of the required
# credentials is loaded from the environment by the AzureOpenAI client.
self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
else:
from openai import AsyncOpenAI

self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())

self._logger = logger

retry_handler = _OpenAIRetryHandler(
logger=self._logger,
num_retries=self._configuration.retries_per_request,
)

self._create_chat_completion = retry_handler(_create_chat_completion)
self._create_embedding = retry_handler(_create_embedding)

def get_token_limit(self, model_name: str) -> int:
"""Get the token limit for a given model."""
return OPEN_AI_MODELS[model_name].max_tokens
@@ -333,7 +348,7 @@ class OpenAIProvider(
try:
encoding = tiktoken.encoding_for_model(encoding_model)
except KeyError:
cls._logger.warning(
logging.getLogger(__class__.__name__).warning(
f"Model {model_name} not found. Defaulting to cl100k_base encoding."
)
encoding = tiktoken.get_encoding("cl100k_base")
@@ -352,7 +367,7 @@ class OpenAIProvider(
self,
model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:
@@ -370,23 +385,33 @@ class OpenAIProvider(
messages=model_prompt,
**completion_kwargs,
)
response_args = {
"model_info": OPEN_AI_CHAT_MODELS[model_name],
"prompt_tokens_used": response.usage.prompt_tokens,
"completion_tokens_used": response.usage.completion_tokens,
}

response_message = response.choices[0].message.to_dict_recursive()
if tool_calls_compat_mode:
response_message["tool_calls"] = _tool_calls_compat_extract_calls(
response_message["content"]
response_message = response.choices[0].message
if (
tool_calls_compat_mode
and response_message.content
and not response_message.tool_calls
):
tool_calls = list(
_tool_calls_compat_extract_calls(response_message.content)
)
elif response_message.tool_calls:
tool_calls = [
AssistantToolCall(**tc.dict()) for tc in response_message.tool_calls
]
else:
tool_calls = None

assistant_message = AssistantChatMessage(
content=response_message.content,
tool_calls=tool_calls,
)

# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
try:
attempts += 1
parsed_response = completion_parser(response_message)
parsed_response = completion_parser(assistant_message)
break
except Exception as e:
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
@@ -401,9 +426,13 @@ class OpenAIProvider(
raise

response = ChatModelResponse(
response=response_message,
response=assistant_message,
parsed_result=parsed_response,
**response_args,
model_info=OPEN_AI_CHAT_MODELS[model_name],
prompt_tokens_used=response.usage.prompt_tokens if response.usage else 0,
completion_tokens_used=(
response.usage.completion_tokens if response.usage else 0
),
)
self._budget.update_usage_and_cost(response)
return response
@@ -419,14 +448,11 @@ class OpenAIProvider(
embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs)
response = await self._create_embedding(text=text, **embedding_kwargs)

response_args = {
"model_info": OPEN_AI_EMBEDDING_MODELS[model_name],
"prompt_tokens_used": response.usage.prompt_tokens,
"completion_tokens_used": response.usage.completion_tokens,
}
response = EmbeddingModelResponse(
**response_args,
embedding=embedding_parser(response.embeddings[0]),
embedding=embedding_parser(response.data[0].embedding),
model_info=OPEN_AI_EMBEDDING_MODELS[model_name],
prompt_tokens_used=response.usage.prompt_tokens,
completion_tokens_used=0,
)
self._budget.update_usage_and_cost(response)
return response
@@ -447,34 +473,29 @@ class OpenAIProvider(
The kwargs for the chat API call.

"""
completion_kwargs = {
"model": model_name,
**kwargs,
**self._credentials.get_api_access_kwargs(model_name),
}
kwargs.update(self._credentials.get_model_access_kwargs(model_name))

if functions:
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
completion_kwargs["tools"] = [
kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
completion_kwargs["tool_choice"] = {
kwargs["tool_choice"] = {
"type": "function",
"function": {"name": functions[0].name},
}
else:
# Provide compatibility with older models
_functions_compat_fix_kwargs(functions, completion_kwargs)
_functions_compat_fix_kwargs(functions, kwargs)

if extra_headers := self._configuration.extra_request_headers:
if completion_kwargs.get("headers"):
completion_kwargs["headers"].update(extra_headers)
else:
completion_kwargs["headers"] = extra_headers.copy()
kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
extra_headers.copy()
)

return completion_kwargs
return kwargs

def _get_embedding_kwargs(
self,
@@ -491,122 +512,84 @@ class OpenAIProvider(
The kwargs for the embedding API call.

"""
embedding_kwargs = {
"model": model_name,
**kwargs,
**self._credentials.unmasked(),
}
kwargs.update(self._credentials.get_model_access_kwargs(model_name))

if extra_headers := self._configuration.extra_request_headers:
if embedding_kwargs.get("headers"):
embedding_kwargs["headers"].update(extra_headers)
else:
embedding_kwargs["headers"] = extra_headers.copy()
kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
extra_headers.copy()
)

return embedding_kwargs
return kwargs

def _create_chat_completion(
self, messages: list[ChatMessage], *_, **kwargs
) -> Coroutine[None, None, ChatCompletion]:
"""Create a chat completion using the OpenAI API with retry handling."""

@self._retry_api_request
async def _create_chat_completion_with_retry(
messages: list[ChatMessage], *_, **kwargs
) -> ChatCompletion:
raw_messages = [
message.dict(include={"role", "content", "tool_calls", "name"})
for message in messages
]
return await self._client.chat.completions.create(
messages=raw_messages, # type: ignore
**kwargs,
)

return _create_chat_completion_with_retry(messages, *_, **kwargs)

def _create_embedding(
self, text: str, *_, **kwargs
) -> Coroutine[None, None, CreateEmbeddingResponse]:
"""Create an embedding using the OpenAI API with retry handling."""

@self._retry_api_request
async def _create_embedding_with_retry(
text: str, *_, **kwargs
) -> CreateEmbeddingResponse:
return await self._client.embeddings.create(
input=[text],
**kwargs,
)

return _create_embedding_with_retry(text, *_, **kwargs)

def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
_log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)

def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
_log_retry_debug_message(retry_state)

if (
retry_state.attempt_number == 0
and retry_state.outcome
and isinstance(retry_state.outcome.exception(), RateLimitError)
):
self._logger.warning(
"Please double check that you have setup a PAID OpenAI API Account."
" You can read more here: "
"https://docs.agpt.co/setup/#getting-an-openai-api-key"
)

return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(RateLimitError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code == 502
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=_log_on_fail,
)(func)

def __repr__(self):
return "OpenAIProvider()"


async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding:
"""Embed text using the OpenAI API.

Args:
text str: The text to embed.
model str: The name of the model to use.

Returns:
str: The embedding.
"""
return await openai.Embedding.acreate(
input=[text],
**kwargs,
)


async def _create_chat_completion(
messages: list[ChatMessage], *_, **kwargs
) -> openai.Completion:
"""Create a chat completion using the OpenAI API.

Args:
messages: The prompt to use.

Returns:
The completion.
"""
raw_messages = [
message.dict(include={"role", "content", "tool_calls", "name"})
for message in messages
]
return await openai.ChatCompletion.acreate(
messages=raw_messages,
**kwargs,
)


class _OpenAIRetryHandler:
"""Retry Handler for OpenAI API call.

Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""

_retry_limit_msg = "Error: Reached rate limit, passing..."
_api_key_error_msg = (
"Please double check that you have setup a PAID OpenAI API Account. You can "
"read more here: https://docs.agpt.co/setup/#getting-an-openai-api-key"
)
_backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."

def __init__(
self,
logger: logging.Logger,
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True,
):
self._logger = logger
self._num_retries = num_retries
self._backoff_base = backoff_base
self._warn_user = warn_user

def _log_rate_limit_error(self) -> None:
self._logger.debug(self._retry_limit_msg)
if self._warn_user:
self._logger.warning(self._api_key_error_msg)
self._warn_user = False

def _backoff(self, attempt: int) -> None:
backoff = self._backoff_base ** (attempt + 2)
self._logger.debug(self._backoff_msg.format(backoff=backoff))
time.sleep(backoff)

def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(func)
async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
num_attempts = self._num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return await func(*args, **kwargs)

except RateLimitError:
if attempt == num_attempts:
raise
self._log_rate_limit_error()

except APIError as e:
if (e.http_status != 502) or (attempt == num_attempts):
raise

self._backoff(attempt)

return _wrapped


def format_function_specs_as_typescript_ns(
functions: list[CompletionModelFunction],
) -> str:
@@ -730,7 +713,7 @@ def _functions_compat_fix_kwargs(
]


def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]:
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import json
import re

@@ -747,4 +730,4 @@ def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDic
for t in tool_calls:
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK

return tool_calls
yield AssistantToolCall.parse_obj(t)

@@ -90,7 +90,7 @@ class AssistantToolCallDict(TypedDict):


class AssistantChatMessage(ChatMessage):
role: Literal["assistant"]
role: Literal["assistant"] = "assistant"
content: Optional[str]
tool_calls: Optional[list[AssistantToolCall]]

@@ -320,7 +320,7 @@ _T = TypeVar("_T")
class ChatModelResponse(ModelResponse, Generic[_T]):
"""Standard response struct for a response from a language model."""

response: AssistantChatMessageDict
response: AssistantChatMessage
parsed_result: _T = None


@@ -338,7 +338,7 @@ class ChatModelProvider(ModelProvider):
self,
model_prompt: list[ChatMessage],
model_name: str,
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:

@@ -2,7 +2,7 @@ import logging
import sys

from colorama import Fore, Style
from openai.util import logger as openai_logger
from openai._base_client import log as openai_logger

SIMPLE_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
DEBUG_LOG_FORMAT = (

@@ -3,10 +3,13 @@ from __future__ import annotations
import logging
from typing import List, Optional

import openai
from openai import Model
from openai import OpenAI
from openai.types import Model

from autogpt.core.resource.model_providers.openai import OPEN_AI_MODELS
from autogpt.core.resource.model_providers.openai import (
OPEN_AI_MODELS,
OpenAICredentials,
)
from autogpt.core.resource.model_providers.schema import ChatModelInfo
from autogpt.singleton import Singleton

@@ -96,16 +99,17 @@ class ApiManager(metaclass=Singleton):
"""
return self.total_budget

def get_models(self, **openai_credentials) -> List[Model]:
def get_models(self, openai_credentials: OpenAICredentials) -> List[Model]:
"""
Get list of available GPT models.

Returns:
list: List of available GPT models.

list[Model]: List of available GPT models.
"""
if self.models is None:
all_models = openai.Model.list(**openai_credentials)["data"]
self.models = [model for model in all_models if "gpt" in model["id"]]
all_models = (
OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data
)
self.models = [model for model in all_models if "gpt" in model.id]

return self.models

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Optional

from auto_gpt_plugin_template import AutoGPTPluginTemplate
from openai.util import logger as openai_logger
from openai._base_client import log as openai_logger

if TYPE_CHECKING:
from autogpt.config import Config
@@ -184,7 +184,7 @@ def configure_logging(
json_logger.propagate = False

# Disable debug logging from OpenAI library
openai_logger.setLevel(logging.INFO)
openai_logger.setLevel(logging.WARNING)


def configure_chat_plugins(config: Config) -> None:

@@ -1,6 +1,6 @@
azure_api_type: azure
azure_api_base: your-base-url-for-azure
azure_api_version: api-version-for-azure
azure_endpoint: your-azure-openai-endpoint
azure_model_map:
gpt-3.5-turbo: gpt35-deployment-id-for-azure
gpt-4: gpt4-deployment-id-for-azure

File diff suppressed because one or more lines are too long
@@ -43,7 +43,7 @@ hypercorn = "^0.14.4"
inflection = "*"
jsonschema = "*"
numpy = "*"
openai = "^0.27.10"
openai = "^1.7.2"
orjson = "^3.8.10"
Pillow = "*"
pinecone-client = "^2.2.1"
@@ -60,6 +60,7 @@ redis = "*"
requests = "*"
selenium = "^4.11.2"
spacy = "^3.0.0"
tenacity = "^8.2.2"
tiktoken = "^0.5.0"
webdriver-manager = "*"

@@ -70,7 +70,7 @@ def temp_plugins_config_file():
yield config_file


@pytest.fixture()
@pytest.fixture(scope="function")
def config(
temp_plugins_config_file: Path,
tmp_project_root: Path,

@@ -97,7 +97,7 @@ def test_json_memory_load_index(config: Config, memory_item: MemoryItem):

@pytest.mark.vcr
@pytest.mark.requires_openai_api_key
def test_json_memory_get_relevant(config: Config, patched_api_requestor: None) -> None:
def test_json_memory_get_relevant(config: Config, cached_openai_client: None) -> None:
index = JSONFileMemory(config)
mem1 = MemoryItem.from_text_file("Sample text", "sample.txt", config)
mem2 = MemoryItem.from_text_file(

@@ -18,7 +18,7 @@ def image_size(request):

@pytest.mark.requires_openai_api_key
@pytest.mark.vcr
def test_dalle(agent: Agent, workspace, image_size, patched_api_requestor):
def test_dalle(agent: Agent, workspace, image_size, cached_openai_client):
"""Test DALL-E image generation."""
generate_and_validate(
agent,

@@ -7,9 +7,7 @@ from autogpt.commands.web_selenium import BrowsingError, read_webpage
@pytest.mark.vcr
@pytest.mark.requires_openai_api_key
@pytest.mark.asyncio
async def test_browse_website_nonexistent_url(
agent: Agent, patched_api_requestor: None
):
async def test_browse_website_nonexistent_url(agent: Agent, cached_openai_client: None):
url = "https://auto-gpt-thinks-this-website-does-not-exist.com"
question = "How to execute a barrel roll"


@@ -1,5 +1,3 @@
from unittest.mock import patch

import pytest
from pytest_mock import MockerFixture

@@ -77,13 +75,3 @@ class TestApiManager:
assert api_manager.get_total_prompt_tokens() == prompt_tokens
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == (prompt_tokens * 0.0004) / 1000

@staticmethod
def test_get_models():
"""Test if getting models works correctly."""
with patch("openai.Model.list") as mock_list_models:
mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
result = api_manager.get_models()

assert result[0]["id"] == "gpt-3.5-turbo"
assert api_manager.models[0]["id"] == "gpt-3.5-turbo"

@@ -8,6 +8,8 @@ from unittest import mock
from unittest.mock import patch

import pytest
from openai.pagination import SyncPage
from openai.types import Model
from pydantic import SecretStr

from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
@@ -80,8 +82,10 @@ def test_set_smart_llm(config: Config) -> None:
config.smart_llm = smart_llm


@patch("openai.Model.list")
def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config) -> None:
@patch("openai.resources.models.Models.list")
def test_fallback_to_gpt3_if_gpt4_not_available(
mock_list_models: Any, config: Config
) -> None:
"""
Test if models update to gpt-3.5-turbo if gpt-4 is not available.
"""
@@ -91,7 +95,10 @@ def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config)
config.fast_llm = "gpt-4"
config.smart_llm = "gpt-4"

mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
mock_list_models.return_value = SyncPage(
data=[Model(id=GPT_3_MODEL, created=0, object="model", owned_by="AutoGPT")],
object="Models", # no idea what this should be, but irrelevant
)

apply_overrides_to_config(
config=config,
@@ -123,74 +130,80 @@ def test_missing_azure_config(config: Config) -> None:
assert config.openai_credentials.azure_model_to_deploy_id_map is None


def test_azure_config(config: Config) -> None:
@pytest.fixture
def config_with_azure(config: Config):
config_file = config.app_data_dir / "azure_config.yaml"
config_file.write_text(
f"""
azure_api_type: azure
azure_api_base: https://dummy.openai.azure.com
azure_api_version: 2023-06-01-preview
azure_endpoint: https://dummy.openai.azure.com
azure_model_map:
{config.fast_llm}: FAST-LLM_ID
{config.smart_llm}: SMART-LLM_ID
{config.embedding_model}: embedding-deployment-id-for-azure
"""
)

os.environ["USE_AZURE"] = "True"
os.environ["AZURE_CONFIG_FILE"] = str(config_file)
config = ConfigBuilder.build_config_from_env(project_root=config.project_root)

assert (credentials := config.openai_credentials) is not None
assert credentials.api_type == "azure"
assert credentials.api_base == SecretStr("https://dummy.openai.azure.com")
assert credentials.api_version == "2023-06-01-preview"
assert credentials.azure_model_to_deploy_id_map == {
config.fast_llm: "FAST-LLM_ID",
config.smart_llm: "SMART-LLM_ID",
config.embedding_model: "embedding-deployment-id-for-azure",
}

fast_llm = config.fast_llm
smart_llm = config.smart_llm
assert (
credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
== "FAST-LLM_ID"
config_with_azure = ConfigBuilder.build_config_from_env(
project_root=config.project_root
)
assert (
credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
== "SMART-LLM_ID"
)

# Emulate --gpt4only
config.fast_llm = smart_llm
assert (
credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
== "SMART-LLM_ID"
)
assert (
credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
== "SMART-LLM_ID"
)

# Emulate --gpt3only
config.fast_llm = config.smart_llm = fast_llm
assert (
credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
== "FAST-LLM_ID"
)
assert (
credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
== "FAST-LLM_ID"
)

yield config_with_azure
del os.environ["USE_AZURE"]
del os.environ["AZURE_CONFIG_FILE"]


def test_azure_config(config_with_azure: Config) -> None:
assert (credentials := config_with_azure.openai_credentials) is not None
assert credentials.api_type == "azure"
assert credentials.api_version == "2023-06-01-preview"
assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
assert credentials.azure_model_to_deploy_id_map == {
config_with_azure.fast_llm: "FAST-LLM_ID",
config_with_azure.smart_llm: "SMART-LLM_ID",
config_with_azure.embedding_model: "embedding-deployment-id-for-azure",
}

fast_llm = config_with_azure.fast_llm
smart_llm = config_with_azure.smart_llm
assert (
credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
== "FAST-LLM_ID"
)
assert (
credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
== "SMART-LLM_ID"
)

# Emulate --gpt4only
config_with_azure.fast_llm = smart_llm
assert (
credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
== "SMART-LLM_ID"
)
assert (
credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
== "SMART-LLM_ID"
)

# Emulate --gpt3only
config_with_azure.fast_llm = config_with_azure.smart_llm = fast_llm
assert (
credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
== "FAST-LLM_ID"
)
assert (
credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
== "FAST-LLM_ID"
)


def test_create_config_gpt4only(config: Config) -> None:
with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
mock_get_models.return_value = [{"id": GPT_4_MODEL}]
mock_get_models.return_value = [
Model(id=GPT_4_MODEL, created=0, object="model", owned_by="AutoGPT")
]
apply_overrides_to_config(
config=config,
gpt4only=True,

@@ -2,8 +2,11 @@ import logging
import os
from hashlib import sha256

import openai.api_requestor
import pytest
from openai import OpenAI
from openai._models import FinalRequestOptions
from openai._types import Omit
from openai._utils import is_given
from pytest_mock import MockerFixture

from .vcr_filter import (
@@ -52,30 +55,26 @@ def vcr_cassette_dir(request):
return os.path.join("tests/vcr_cassettes", test_name)


def patch_api_base(requestor: openai.api_requestor.APIRequestor):
new_api_base = f"{PROXY}/v1"
requestor.api_base = new_api_base
return requestor


@pytest.fixture
def patched_api_requestor(mocker: MockerFixture):
init_requestor = openai.api_requestor.APIRequestor.__init__
prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw
def cached_openai_client(mocker: MockerFixture) -> OpenAI:
client = OpenAI()
_prepare_options = client._prepare_options

def patched_init_requestor(requestor, *args, **kwargs):
init_requestor(requestor, *args, **kwargs)
patch_api_base(requestor)
def _patched_prepare_options(self, options: FinalRequestOptions):
_prepare_options(options)

def patched_prepare_request(self, *args, **kwargs):
url, headers, data = prepare_request(self, *args, **kwargs)
headers: dict[str, str | Omit] = (
{**options.headers} if is_given(options.headers) else {}
)
options.headers = headers
data: dict = options.json_data

if PROXY:
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())

logging.getLogger("patched_api_requestor").debug(
f"Outgoing API request: {headers}\n{data.decode() if data else None}"
logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}"
)

# Add hash header for cheap & fast matching on cassette playback
@@ -83,16 +82,12 @@ def patched_api_requestor(mocker: MockerFixture):
freeze_request_body(data), usedforsecurity=False
).hexdigest()

return url, headers, data

if PROXY:
mocker.patch.object(
openai.api_requestor.APIRequestor,
"__init__",
new=patched_init_requestor,
)
client.base_url = f"{PROXY}/v1"
mocker.patch.object(
openai.api_requestor.APIRequestor,
"_prepare_request_raw",
new=patched_prepare_request,
client,
"_prepare_options",
new=_patched_prepare_options,
)

return client

@@ -44,14 +44,9 @@ def replace_message_content(content: str, replacements: List[Dict[str, str]]) ->
return content


def freeze_request_body(json_body: str | bytes) -> bytes:
def freeze_request_body(body: dict) -> bytes:
"""Remove any dynamic items from the request body"""

try:
body = json.loads(json_body)
except ValueError:
return json_body if type(json_body) is bytes else json_body.encode()

if "messages" not in body:
return json.dumps(body, sort_keys=True).encode()

@@ -74,9 +69,11 @@ def freeze_request(request: Request) -> Request:

with contextlib.suppress(ValueError):
request.body = freeze_request_body(
request.body.getvalue()
if isinstance(request.body, BytesIO)
else request.body
json.loads(
request.body.getvalue()
if isinstance(request.body, BytesIO)
else request.body
)
)

return request