From f605c10f42fd6d521d6a64e75d531111b25812cd Mon Sep 17 00:00:00 2001 From: tronikos Date: Sat, 8 Jun 2024 00:02:00 -0700 Subject: [PATCH] Properly handle escaped unicode characters passed to tools in Google Generative AI (#119117) --- .../conversation.py | 16 +++++++--------- .../test_conversation.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 6c2bd64a7b5..65c0dc7fd93 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -2,6 +2,7 @@ from __future__ import annotations +import codecs from typing import Any, Literal from google.api_core.exceptions import GoogleAPICallError @@ -106,14 +107,14 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]: ) -def _adjust_value(value: Any) -> Any: - """Reverse unnecessary single quotes escaping.""" +def _escape_decode(value: Any) -> Any: + """Recursively call codecs.escape_decode on all values.""" if isinstance(value, str): - return value.replace("\\'", "'") + return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined] if isinstance(value, list): - return [_adjust_value(item) for item in value] + return [_escape_decode(item) for item in value] if isinstance(value, dict): - return {k: _adjust_value(v) for k, v in value.items()} + return {k: _escape_decode(v) for k, v in value.items()} return value @@ -334,10 +335,7 @@ class GoogleGenerativeAIConversationEntity( for function_call in function_calls: tool_call = MessageToDict(function_call._pb) # noqa: SLF001 tool_name = tool_call["name"] - tool_args = { - key: _adjust_value(value) - for key, value in tool_call["args"].items() - } + tool_args = _escape_decode(tool_call["args"]) LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args) tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args) try: diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 901216d262f..e84efffe7df 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -12,6 +12,9 @@ import voluptuous as vol from homeassistant.components import conversation from homeassistant.components.conversation import trace +from homeassistant.components.google_generative_ai_conversation.conversation import ( + _escape_decode, +) from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -504,3 +507,18 @@ async def test_conversation_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" + + +async def test_escape_decode() -> None: + """Test _escape_decode.""" + assert _escape_decode( + { + "param1": ["test_value", "param1\\'s value"], + "param2": "param2\\'s value", + "param3": {"param31": "Cheminée", "param32": "Chemin\\303\\251e"}, + } + ) == { + "param1": ["test_value", "param1's value"], + "param2": "param2's value", + "param3": {"param31": "Cheminée", "param32": "Cheminée"}, + }