Properly handle escaped unicode characters passed to tools in Google Generative AI (#119117)
parent
f07e7ec543
commit
f605c10f42
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
from typing import Any, Literal
|
||||
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
|
@ -106,14 +107,14 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
|||
)
|
||||
|
||||
|
||||
def _adjust_value(value: Any) -> Any:
|
||||
"""Reverse unnecessary single quotes escaping."""
|
||||
def _escape_decode(value: Any) -> Any:
|
||||
"""Recursively call codecs.escape_decode on all values."""
|
||||
if isinstance(value, str):
|
||||
return value.replace("\\'", "'")
|
||||
return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined]
|
||||
if isinstance(value, list):
|
||||
return [_adjust_value(item) for item in value]
|
||||
return [_escape_decode(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _adjust_value(v) for k, v in value.items()}
|
||||
return {k: _escape_decode(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
|
@ -334,10 +335,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||
for function_call in function_calls:
|
||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = {
|
||||
key: _adjust_value(value)
|
||||
for key, value in tool_call["args"].items()
|
||||
}
|
||||
tool_args = _escape_decode(tool_call["args"])
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
try:
|
||||
|
|
|
@ -12,6 +12,9 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||
_escape_decode,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -504,3 +507,18 @@ async def test_conversation_agent(
|
|||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == "*"
|
||||
|
||||
|
||||
async def test_escape_decode() -> None:
|
||||
"""Test _escape_decode."""
|
||||
assert _escape_decode(
|
||||
{
|
||||
"param1": ["test_value", "param1\\'s value"],
|
||||
"param2": "param2\\'s value",
|
||||
"param3": {"param31": "Cheminée", "param32": "Chemin\\303\\251e"},
|
||||
}
|
||||
) == {
|
||||
"param1": ["test_value", "param1's value"],
|
||||
"param2": "param2's value",
|
||||
"param3": {"param31": "Cheminée", "param32": "Cheminée"},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue