Properly handle escaped unicode characters passed to tools in Google Generative AI (#119117)

pull/118435/head
tronikos 2024-06-08 00:02:00 -07:00 committed by GitHub
parent f07e7ec543
commit f605c10f42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 9 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import codecs
from typing import Any, Literal
from google.api_core.exceptions import GoogleAPICallError
@ -106,14 +107,14 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
)
def _adjust_value(value: Any) -> Any:
"""Reverse unnecessary single quotes escaping."""
def _escape_decode(value: Any) -> Any:
"""Recursively call codecs.escape_decode on all values."""
if isinstance(value, str):
return value.replace("\\'", "'")
return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined]
if isinstance(value, list):
return [_adjust_value(item) for item in value]
return [_escape_decode(item) for item in value]
if isinstance(value, dict):
return {k: _adjust_value(v) for k, v in value.items()}
return {k: _escape_decode(v) for k, v in value.items()}
return value
@ -334,10 +335,7 @@ class GoogleGenerativeAIConversationEntity(
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = {
key: _adjust_value(value)
for key, value in tool_call["args"].items()
}
tool_args = _escape_decode(tool_call["args"])
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
try:

View File

@ -12,6 +12,9 @@ import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.components.google_generative_ai_conversation.conversation import (
_escape_decode,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -504,3 +507,18 @@ async def test_conversation_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"
async def test_escape_decode() -> None:
"""Test _escape_decode."""
assert _escape_decode(
{
"param1": ["test_value", "param1\\'s value"],
"param2": "param2\\'s value",
"param3": {"param31": "Cheminée", "param32": "Chemin\\303\\251e"},
}
) == {
"param1": ["test_value", "param1's value"],
"param2": "param2's value",
"param3": {"param31": "Cheminée", "param32": "Cheminée"},
}