fix(agent/llm): Fix support for AzureOpenAI (#6927)

* Fix unmasking of `azure_endpoint` in `OpenAICredentials.get_api_access_kwargs()`
* Amend `ApiManager.get_models` to use `AzureOpenAI` client when `api_type` is set to `azure`

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
pull/6946/head
edwardsp 2024-02-29 17:35:06 +00:00 committed by GitHub
parent ce45c9b267
commit 50e5ea4e54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 5 deletions

View File

@ -243,7 +243,8 @@ class OpenAICredentials(ModelProviderCredentials):
}
if self.api_type == "azure":
kwargs["api_version"] = self.api_version
kwargs["azure_endpoint"] = self.azure_endpoint
assert self.azure_endpoint, "Azure endpoint not configured"
kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
return kwargs
def get_model_access_kwargs(self, model: str) -> dict[str, str]:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import List, Optional
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from openai.types import Model
from autogpt.core.resource.model_providers.openai import (
@ -107,9 +107,18 @@ class ApiManager(metaclass=Singleton):
list[Model]: List of available GPT models.
"""
if self.models is None:
all_models = (
OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data
)
if openai_credentials.api_type == "azure":
all_models = (
AzureOpenAI(**openai_credentials.get_api_access_kwargs())
.models.list()
.data
)
else:
all_models = (
OpenAI(**openai_credentials.get_api_access_kwargs())
.models.list()
.data
)
self.models = [model for model in all_models if "gpt" in model.id]
return self.models