chore(agent/llm): Update OpenAI model info

- Add `text-embedding-3-small` and `text-embedding-3-large` as `EMBEDDING_v3_S` and `EMBEDDING_v3_L` respectively
- Add `gpt-3.5-turbo-0125` as `GPT3_v4`
- Add `gpt-4-1106-vision-preview` as `GPT4_v3_VISION`
- Add GPT-4V models to info map
- Change chat model info mapping to derive info for aliases (e.g. `gpt-3.5-turbo`) from specific versions instead of the other way around
pull/6839/head
Reinier van der Leer 2024-02-12 12:59:58 +01:00
parent 14c9773890
commit 7bf9ba5502
No known key found for this signature in database
GPG Key ID: CDC1180FDAE06193
1 changed files with 57 additions and 21 deletions

View File

@ -46,12 +46,15 @@ OpenAIChatParser = Callable[[str], dict]
class OpenAIModelName(str, enum.Enum): class OpenAIModelName(str, enum.Enum):
ADA = "text-embedding-ada-002" EMBEDDING_v2 = "text-embedding-ada-002"
EMBEDDING_v3_S = "text-embedding-3-small"
EMBEDDING_v3_L = "text-embedding-3-large"
GPT3_v1 = "gpt-3.5-turbo-0301" GPT3_v1 = "gpt-3.5-turbo-0301"
GPT3_v2 = "gpt-3.5-turbo-0613" GPT3_v2 = "gpt-3.5-turbo-0613"
GPT3_v2_16k = "gpt-3.5-turbo-16k-0613" GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
GPT3_v3 = "gpt-3.5-turbo-1106" GPT3_v3 = "gpt-3.5-turbo-1106"
GPT3_v4 = "gpt-3.5-turbo-0125"
GPT3_ROLLING = "gpt-3.5-turbo" GPT3_ROLLING = "gpt-3.5-turbo"
GPT3_ROLLING_16k = "gpt-3.5-turbo-16k" GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
GPT3 = GPT3_ROLLING GPT3 = GPT3_ROLLING
@ -62,6 +65,7 @@ class OpenAIModelName(str, enum.Enum):
GPT4_v2 = "gpt-4-0613" GPT4_v2 = "gpt-4-0613"
GPT4_v2_32k = "gpt-4-32k-0613" GPT4_v2_32k = "gpt-4-32k-0613"
GPT4_v3 = "gpt-4-1106-preview" GPT4_v3 = "gpt-4-1106-preview"
GPT4_v3_VISION = "gpt-4-1106-vision-preview"
GPT4_v4 = "gpt-4-0125-preview" GPT4_v4 = "gpt-4-0125-preview"
GPT4_ROLLING = "gpt-4" GPT4_ROLLING = "gpt-4"
GPT4_ROLLING_32k = "gpt-4-32k" GPT4_ROLLING_32k = "gpt-4-32k"
@ -72,14 +76,33 @@ class OpenAIModelName(str, enum.Enum):
OPEN_AI_EMBEDDING_MODELS = { OPEN_AI_EMBEDDING_MODELS = {
OpenAIModelName.ADA: EmbeddingModelInfo( info.name: info
name=OpenAIModelName.ADA, for info in [
service=ModelProviderService.EMBEDDING, EmbeddingModelInfo(
provider_name=ModelProviderName.OPENAI, name=OpenAIModelName.EMBEDDING_v2,
prompt_token_cost=0.0001 / 1000, service=ModelProviderService.EMBEDDING,
max_tokens=8191, provider_name=ModelProviderName.OPENAI,
embedding_dimensions=1536, prompt_token_cost=0.0001 / 1000,
), max_tokens=8191,
embedding_dimensions=1536,
),
EmbeddingModelInfo(
name=OpenAIModelName.EMBEDDING_v3_S,
service=ModelProviderService.EMBEDDING,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.00002 / 1000,
max_tokens=8191,
embedding_dimensions=1536,
),
EmbeddingModelInfo(
name=OpenAIModelName.EMBEDDING_v3_L,
service=ModelProviderService.EMBEDDING,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.00013 / 1000,
max_tokens=8191,
embedding_dimensions=3072,
),
]
} }
@ -87,7 +110,7 @@ OPEN_AI_CHAT_MODELS = {
info.name: info info.name: info
for info in [ for info in [
ChatModelInfo( ChatModelInfo(
name=OpenAIModelName.GPT3, name=OpenAIModelName.GPT3_v1,
service=ModelProviderService.CHAT, service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI, provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0015 / 1000, prompt_token_cost=0.0015 / 1000,
@ -96,7 +119,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True, has_function_call_api=True,
), ),
ChatModelInfo( ChatModelInfo(
name=OpenAIModelName.GPT3_16k, name=OpenAIModelName.GPT3_v2_16k,
service=ModelProviderService.CHAT, service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI, provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.003 / 1000, prompt_token_cost=0.003 / 1000,
@ -114,7 +137,16 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True, has_function_call_api=True,
), ),
ChatModelInfo( ChatModelInfo(
name=OpenAIModelName.GPT4, name=OpenAIModelName.GPT3_v4,
service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0005 / 1000,
completion_token_cost=0.0015 / 1000,
max_tokens=16384,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT4_v1,
service=ModelProviderService.CHAT, service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI, provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.03 / 1000, prompt_token_cost=0.03 / 1000,
@ -123,7 +155,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True, has_function_call_api=True,
), ),
ChatModelInfo( ChatModelInfo(
name=OpenAIModelName.GPT4_32k, name=OpenAIModelName.GPT4_v1_32k,
service=ModelProviderService.CHAT, service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI, provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.06 / 1000, prompt_token_cost=0.06 / 1000,
@ -144,19 +176,23 @@ OPEN_AI_CHAT_MODELS = {
} }
# Copy entries for models with equivalent specs # Copy entries for models with equivalent specs
chat_model_mapping = { chat_model_mapping = {
OpenAIModelName.GPT3: [OpenAIModelName.GPT3_v1, OpenAIModelName.GPT3_v2], OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2, OpenAIModelName.GPT3_ROLLING],
OpenAIModelName.GPT3_16k: [OpenAIModelName.GPT3_v2_16k], OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k],
OpenAIModelName.GPT4: [OpenAIModelName.GPT4_v1, OpenAIModelName.GPT4_v2], OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING],
OpenAIModelName.GPT4_32k: [ OpenAIModelName.GPT4_v1_32k: [
OpenAIModelName.GPT4_v1_32k,
OpenAIModelName.GPT4_v2_32k, OpenAIModelName.GPT4_v2_32k,
OpenAIModelName.GPT4_32k,
],
OpenAIModelName.GPT4_TURBO: [
OpenAIModelName.GPT4_v3,
OpenAIModelName.GPT4_v3_VISION,
OpenAIModelName.GPT4_v4,
OpenAIModelName.GPT4_VISION,
], ],
OpenAIModelName.GPT4_TURBO: [OpenAIModelName.GPT4_v3, OpenAIModelName.GPT4_v4],
} }
for base, copies in chat_model_mapping.items(): for base, copies in chat_model_mapping.items():
for copy in copies: for copy in copies:
copy_info = ChatModelInfo(**OPEN_AI_CHAT_MODELS[base].__dict__) copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy})
copy_info.name = copy
OPEN_AI_CHAT_MODELS[copy] = copy_info OPEN_AI_CHAT_MODELS[copy] = copy_info
if copy.endswith(("-0301", "-0314")): if copy.endswith(("-0301", "-0314")):
copy_info.has_function_call_api = False copy_info.has_function_call_api = False