feat(platform): Add OpenAI reasoning models (#8152)
parent
6da8007ce0
commit
81d1be73cd
|
@ -30,6 +30,8 @@ class ModelMetadata(NamedTuple):
|
|||
|
||||
class LlmModel(str, Enum):
|
||||
# OpenAI models
|
||||
O1_PREVIEW = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
|
@ -57,6 +59,8 @@ class LlmModel(str, Enum):
|
|||
|
||||
|
||||
MODEL_METADATA = {
|
||||
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
|
||||
|
@ -84,7 +88,10 @@ for model in LlmModel:
|
|||
class AIStructuredResponseGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str
|
||||
expected_format: dict[str, str]
|
||||
expected_format: dict[str, str] = SchemaField(
|
||||
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||
)
|
||||
model: LlmModel = LlmModel.GPT4_TURBO
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
sys_prompt: str = ""
|
||||
|
@ -132,7 +139,18 @@ class AIStructuredResponseGeneratorBlock(Block):
|
|||
|
||||
if provider == "openai":
|
||||
openai.api_key = api_key
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response_format = None
|
||||
|
||||
if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = openai.chat.completions.create(
|
||||
model=model.value,
|
||||
messages=prompt, # type: ignore
|
||||
|
@ -207,11 +225,11 @@ class AIStructuredResponseGeneratorBlock(Block):
|
|||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply in json format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
|
|
Loading…
Reference in New Issue