feat(platform): Add OpenAI reasoning models (#8152)

Zamil Majdy 2024-09-24 18:11:15 -05:00 committed by GitHub
parent 6da8007ce0
commit 81d1be73cd
1 changed file with 25 additions and 7 deletions

@@ -30,6 +30,8 @@ class ModelMetadata(NamedTuple):
 class LlmModel(str, Enum):
     # OpenAI models
+    O1_PREVIEW = "o1-preview"
+    O1_MINI = "o1-mini"
     GPT4O_MINI = "gpt-4o-mini"
     GPT4O = "gpt-4o"
     GPT4_TURBO = "gpt-4-turbo"
@@ -57,6 +59,8 @@ class LlmModel(str, Enum):
 MODEL_METADATA = {
+    LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
+    LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
     LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
     LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
     LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
@@ -84,7 +88,10 @@ for model in LlmModel:
 class AIStructuredResponseGeneratorBlock(Block):
     class Input(BlockSchema):
         prompt: str
-        expected_format: dict[str, str]
+        expected_format: dict[str, str] = SchemaField(
+            description="Expected format of the response. If provided, the response will be validated against this format. "
+            "The keys should be the expected fields in the response, and the values should be the description of the field.",
+        )
         model: LlmModel = LlmModel.GPT4_TURBO
         api_key: BlockSecret = SecretField(value="")
         sys_prompt: str = ""
@@ -132,7 +139,18 @@ class AIStructuredResponseGeneratorBlock(Block):
         if provider == "openai":
             openai.api_key = api_key
-            response_format = {"type": "json_object"} if json_format else None
+            response_format = None
+            if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
+                sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
+                usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
+                prompt = [
+                    {"role": "user", "content": "\n".join(sys_messages)},
+                    {"role": "user", "content": "\n".join(usr_messages)},
+                ]
+            elif json_format:
+                response_format = {"type": "json_object"}
             response = openai.chat.completions.create(
                 model=model.value,
                 messages=prompt, # type: ignore
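
The hunk above routes o1-preview / o1-mini prompts differently: those models reject system-role messages, and the json_object response format is not requested for them, so system content is folded into a leading user message. A minimal standalone sketch of that flattening (the helper name flatten_for_o1 is hypothetical, not part of the commit):

    def flatten_for_o1(prompt: list[dict[str, str]]) -> list[dict[str, str]]:
        # Collect system text separately from the remaining (non-system) turns.
        sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
        usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
        # System text becomes the first user message; the rest follows as a second one.
        return [
            {"role": "user", "content": "\n".join(sys_messages)},
            {"role": "user", "content": "\n".join(usr_messages)},
        ]

    # Example:
    # flatten_for_o1([
    #     {"role": "system", "content": "Reply in JSON."},
    #     {"role": "user", "content": "Summarize the report."},
    # ])
    # -> [{"role": "user", "content": "Reply in JSON."},
    #     {"role": "user", "content": "Summarize the report."}]
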
@@ -207,11 +225,11 @@ class AIStructuredResponseGeneratorBlock(Block):
         format_prompt = ",\n ".join(expected_format)
         sys_prompt = trim_prompt(
             f"""
-            |Reply in json format:
-            |{{
-            | {format_prompt}
-            |}}
-            """
+            |Reply strictly only in the following JSON format:
+            |{{
+            | {format_prompt}
+            |}}
+            """
         )
         prompt.append({"role": "system", "content": sys_prompt})
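
For reference, a hedged illustration (not part of the diff) of the instruction this hunk produces, assuming each entry of expected_format has already been rendered into a '"key": "description"' string before the join:

    expected_format = ['"summary": "one-sentence summary"', '"score": "integer 1-10"']
    format_prompt = ",\n  ".join(expected_format)
    sys_prompt = (
        "Reply strictly only in the following JSON format:\n"
        "{\n"
        f"  {format_prompt}\n"
        "}"
    )
    print(sys_prompt)
    # Reply strictly only in the following JSON format:
    # {
    #   "summary": "one-sentence summary",
    #   "score": "integer 1-10"
    # }
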