Fix gemini api format conversion (#122403)
* Fix gemini api format conversion * add tests * fix tests * fix tests * fix coveragepull/122770/head
parent
7135a919e3
commit
56f51d3e35
|
@ -73,6 +73,14 @@ SUPPORTED_SCHEMA_KEYS = {
|
|||
|
||||
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Format the schema to protobuf."""
|
||||
if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
|
||||
for subschema in subschemas: # Gemini API does not support anyOf and allOf keys
|
||||
if "type" in subschema: # Fallback to first subschema with 'type' field
|
||||
return _format_schema(subschema)
|
||||
return _format_schema(
|
||||
subschemas[0]
|
||||
) # Or, if not found, to any of the subschemas
|
||||
|
||||
result = {}
|
||||
for key, val in schema.items():
|
||||
if key not in SUPPORTED_SCHEMA_KEYS:
|
||||
|
@ -81,7 +89,9 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
|||
key = "type_"
|
||||
val = val.upper()
|
||||
elif key == "format":
|
||||
if schema.get("type") == "string" and val != "enum":
|
||||
if (schema.get("type") == "string" and val != "enum") or (
|
||||
schema.get("type") not in ("number", "integer", "string")
|
||||
):
|
||||
continue
|
||||
key = "format_"
|
||||
elif key == "items":
|
||||
|
@ -89,6 +99,12 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
|||
elif key == "properties":
|
||||
val = {k: _format_schema(v) for k, v in val.items()}
|
||||
result[key] = val
|
||||
|
||||
if result.get("type_") == "OBJECT" and not result.get("properties"):
|
||||
# An object with undefined properties is not supported by Gemini API.
|
||||
# Fallback to JSON string. This will probably fail for most tools that want it,
|
||||
# but we don't have a better fallback strategy so far.
|
||||
result["properties"] = {"json": {"type_": "STRING"}}
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -442,6 +442,24 @@
|
|||
description: "Test function"
|
||||
parameters {
|
||||
type_: OBJECT
|
||||
properties {
|
||||
key: "param3"
|
||||
value {
|
||||
type_: OBJECT
|
||||
properties {
|
||||
key: "json"
|
||||
value {
|
||||
type_: STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
properties {
|
||||
key: "param2"
|
||||
value {
|
||||
type_: NUMBER
|
||||
}
|
||||
}
|
||||
properties {
|
||||
key: "param1"
|
||||
value {
|
||||
|
|
|
@ -185,7 +185,9 @@ async def test_function_call(
|
|||
{
|
||||
vol.Optional("param1", description="Test parameters"): [
|
||||
vol.All(str, vol.Lower)
|
||||
]
|
||||
],
|
||||
vol.Optional("param2"): vol.Any(float, int),
|
||||
vol.Optional("param3"): dict,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue