From 85540cea3fb6584373863a4e73ac7b56fab85d12 Mon Sep 17 00:00:00 2001
From: Denis Shulyaka <Shulyaka@gmail.com>
Date: Mon, 27 Jan 2025 22:21:27 +0300
Subject: [PATCH] Add LLM ActionTool (#136591)

Add ActionTool
---
 homeassistant/helpers/llm.py | 173 ++++++++++++++++++++---------------
 tests/helpers/test_llm.py    |   8 +-
 2 files changed, 103 insertions(+), 78 deletions(-)

diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py
index ea376923f9d..cc397c5d428 100644
--- a/homeassistant/helpers/llm.py
+++ b/homeassistant/helpers/llm.py
@@ -49,9 +49,9 @@ from . import (
 )
 from .singleton import singleton
 
-SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
-    "llm_script_parameters_cache"
-)
+ACTION_PARAMETERS_CACHE: HassKey[
+    dict[str, dict[str, tuple[str | None, vol.Schema]]]
+] = HassKey("llm_action_parameters_cache")
 
 
 LLM_API_ASSIST = "assist"
@@ -624,104 +624,105 @@ def _selector_serializer(schema: Any) -> Any:  # noqa: C901
     return {"type": "string"}
 
 
-def _get_cached_script_parameters(
-    hass: HomeAssistant, entity_id: str
+def _get_cached_action_parameters(
+    hass: HomeAssistant, domain: str, action: str
 ) -> tuple[str | None, vol.Schema]:
-    """Get script description and schema."""
-    entity_registry = er.async_get(hass)
-
+    """Get action description and schema."""
     description = None
     parameters = vol.Schema({})
-    entity_entry = entity_registry.async_get(entity_id)
-    if entity_entry and entity_entry.unique_id:
-        parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
 
-        if parameters_cache is None:
-            parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
+    parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
 
-            @callback
-            def clear_cache(event: Event) -> None:
-                """Clear script parameter cache on script reload or delete."""
-                if (
-                    event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
-                    and event.data[ATTR_SERVICE] in parameters_cache
-                ):
-                    parameters_cache.pop(event.data[ATTR_SERVICE])
+    if parameters_cache is None:
+        parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
 
-            cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
+        @callback
+        def clear_cache(event: Event) -> None:
+            """Clear action parameter cache on action removal."""
+            if (
+                event.data[ATTR_DOMAIN] in parameters_cache
+                and event.data[ATTR_SERVICE]
+                in parameters_cache[event.data[ATTR_DOMAIN]]
+            ):
+                parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
 
-            @callback
-            def on_homeassistant_close(event: Event) -> None:
-                """Cleanup."""
-                cancel()
+        cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
 
-            hass.bus.async_listen_once(
-                EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
-            )
+        @callback
+        def on_homeassistant_close(event: Event) -> None:
+            """Cleanup."""
+            cancel()
 
-        if entity_entry.unique_id in parameters_cache:
-            return parameters_cache[entity_entry.unique_id]
+        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
 
-        if service_desc := service.async_get_cached_service_description(
-            hass, SCRIPT_DOMAIN, entity_entry.unique_id
-        ):
-            description = service_desc.get("description")
-            schema: dict[vol.Marker, Any] = {}
-            fields = service_desc.get("fields", {})
+    if domain in parameters_cache and action in parameters_cache[domain]:
+        return parameters_cache[domain][action]
 
-            for field, config in fields.items():
-                field_description = config.get("description")
-                if not field_description:
-                    field_description = config.get("name")
-                key: vol.Marker
-                if config.get("required"):
-                    key = vol.Required(field, description=field_description)
-                else:
-                    key = vol.Optional(field, description=field_description)
-                if "selector" in config:
-                    schema[key] = selector.selector(config["selector"])
-                else:
-                    schema[key] = cv.string
+    if action_desc := service.async_get_cached_service_description(
+        hass, domain, action
+    ):
+        description = action_desc.get("description")
+        schema: dict[vol.Marker, Any] = {}
+        fields = action_desc.get("fields", {})
 
-            parameters = vol.Schema(schema)
+        for field, config in fields.items():
+            field_description = config.get("description")
+            if not field_description:
+                field_description = config.get("name")
+            key: vol.Marker
+            if config.get("required"):
+                key = vol.Required(field, description=field_description)
+            else:
+                key = vol.Optional(field, description=field_description)
+            if "selector" in config:
+                schema[key] = selector.selector(config["selector"])
+            else:
+                schema[key] = cv.string
 
-            aliases: list[str] = []
-            if entity_entry.name:
-                aliases.append(entity_entry.name)
-            if entity_entry.aliases:
-                aliases.extend(entity_entry.aliases)
-            if aliases:
-                if description:
-                    description = description + ". Aliases: " + str(list(aliases))
-                else:
-                    description = "Aliases: " + str(list(aliases))
+        parameters = vol.Schema(schema)
 
-            parameters_cache[entity_entry.unique_id] = (description, parameters)
+        if domain == SCRIPT_DOMAIN:
+            entity_registry = er.async_get(hass)
+            if (
+                entity_id := entity_registry.async_get_entity_id(domain, domain, action)
+            ) and (entity_entry := entity_registry.async_get(entity_id)):
+                aliases: list[str] = []
+                if entity_entry.name:
+                    aliases.append(entity_entry.name)
+                if entity_entry.aliases:
+                    aliases.extend(entity_entry.aliases)
+                if aliases:
+                    if description:
+                        description = description + ". Aliases: " + str(list(aliases))
+                    else:
+                        description = "Aliases: " + str(list(aliases))
+
+        parameters_cache.setdefault(domain, {})[action] = (description, parameters)
 
     return description, parameters
 
 
-class ScriptTool(Tool):
-    """LLM Tool representing a Script."""
+class ActionTool(Tool):
+    """LLM Tool representing an action."""
 
     def __init__(
         self,
         hass: HomeAssistant,
-        script_entity_id: str,
+        domain: str,
+        action: str,
     ) -> None:
         """Init the class."""
-        self._object_id = self.name = split_entity_id(script_entity_id)[1]
-        if self.name[0].isdigit():
-            self.name = "_" + self.name
-
-        self.description, self.parameters = _get_cached_script_parameters(
-            hass, script_entity_id
+        self._domain = domain
+        self._action = action
+        self.name = f"{domain}.{action}"
+        self.description, self.parameters = _get_cached_action_parameters(
+            hass, domain, action
         )
 
     async def async_call(
         self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
     ) -> JsonObjectType:
-        """Run the script."""
+        """Call the action."""
 
         for field, validator in self.parameters.schema.items():
             if field not in tool_input.tool_args:
@@ -753,8 +754,8 @@ class ScriptTool(Tool):
                     tool_input.tool_args[field] = floor
 
         result = await hass.services.async_call(
-            SCRIPT_DOMAIN,
-            self._object_id,
+            self._domain,
+            self._action,
             tool_input.tool_args,
             context=llm_context.context,
             blocking=True,
@@ -764,6 +765,30 @@ class ScriptTool(Tool):
         return {"success": True, "result": result}
 
 
+class ScriptTool(ActionTool):
+    """LLM Tool representing a Script."""
+
+    def __init__(
+        self,
+        hass: HomeAssistant,
+        script_entity_id: str,
+    ) -> None:
+        """Init the class."""
+        script_name = split_entity_id(script_entity_id)[1]
+
+        action = script_name
+        entity_registry = er.async_get(hass)
+        entity_entry = entity_registry.async_get(script_entity_id)
+        if entity_entry and entity_entry.unique_id:
+            action = entity_entry.unique_id
+
+        super().__init__(hass, SCRIPT_DOMAIN, action)
+
+        self.name = script_name
+        if self.name[0].isdigit():
+            self.name = "_" + self.name
+
+
 class CalendarGetEventsTool(Tool):
     """LLM Tool allowing querying a calendar."""
 
diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py
index 57e151ba8eb..e288026b67b 100644
--- a/tests/helpers/test_llm.py
+++ b/tests/helpers/test_llm.py
@@ -745,7 +745,7 @@ async def test_script_tool(
     area = area_registry.async_create("Living room")
     floor = floor_registry.async_create("2")
 
-    assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
+    assert llm.ACTION_PARAMETERS_CACHE not in hass.data
 
     api = await llm.async_get_api(hass, "assist", llm_context)
 
@@ -769,7 +769,7 @@ async def test_script_tool(
     }
     assert tool.parameters.schema == schema
 
-    assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
+    assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
         "test_script": (
             "This is a test script. Aliases: ['script name', 'script alias']",
             vol.Schema(schema),
@@ -866,7 +866,7 @@ async def test_script_tool(
     ):
         await hass.services.async_call("script", "reload", blocking=True)
 
-    assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
+    assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {}
 
     api = await llm.async_get_api(hass, "assist", llm_context)
 
@@ -882,7 +882,7 @@ async def test_script_tool(
     schema = {vol.Required("beer", description="Number of beers"): cv.string}
     assert tool.parameters.schema == schema
 
-    assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
+    assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
         "test_script": (
             "This is a new test script. Aliases: ['script name', 'script alias']",
             vol.Schema(schema),