Fix gemini api format conversion (#122403)

* Fix gemini api format conversion

* add tests

* fix tests

* fix tests

* fix coverage
pull/122770/head
Denis Shulyaka 2024-07-23 03:56:13 +03:00 committed by Franck Nijhof
parent 7135a919e3
commit 56f51d3e35
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
3 changed files with 38 additions and 2 deletions

View File

@ -73,6 +73,14 @@ SUPPORTED_SCHEMA_KEYS = {
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Format the schema to protobuf."""
if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
for subschema in subschemas: # Gemini API does not support anyOf and allOf keys
if "type" in subschema: # Fallback to first subschema with 'type' field
return _format_schema(subschema)
return _format_schema(
subschemas[0]
) # Or, if not found, to any of the subschemas
result = {}
for key, val in schema.items():
if key not in SUPPORTED_SCHEMA_KEYS:
@ -81,7 +89,9 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
key = "type_"
val = val.upper()
elif key == "format":
if schema.get("type") == "string" and val != "enum":
if (schema.get("type") == "string" and val != "enum") or (
schema.get("type") not in ("number", "integer", "string")
):
continue
key = "format_"
elif key == "items":
@ -89,6 +99,12 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
if result.get("type_") == "OBJECT" and not result.get("properties"):
# An object with undefined properties is not supported by Gemini API.
# Fallback to JSON string. This will probably fail for most tools that want it,
# but we don't have a better fallback strategy so far.
result["properties"] = {"json": {"type_": "STRING"}}
return result

View File

@ -442,6 +442,24 @@
description: "Test function"
parameters {
type_: OBJECT
properties {
key: "param3"
value {
type_: OBJECT
properties {
key: "json"
value {
type_: STRING
}
}
}
}
properties {
key: "param2"
value {
type_: NUMBER
}
}
properties {
key: "param1"
value {

View File

@ -185,7 +185,9 @@ async def test_function_call(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
]
],
vol.Optional("param2"): vol.Any(float, int),
vol.Optional("param3"): dict,
}
)