From d001e7daeac61bb262b482a9ea0eae78820563e5 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <balloob@gmail.com>
Date: Sat, 18 May 2024 21:14:05 -0400
Subject: [PATCH] Add API class to LLM helper (#117707)

* Add API class to LLM helper

* Add more tests

* Rename intent to assist to broaden scope
---
 homeassistant/helpers/llm.py | 128 +++++++++++++++++++++++++----------
 tests/helpers/test_llm.py    |  39 +++++++++--
 2 files changed, 125 insertions(+), 42 deletions(-)

diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py
index 1d91c9e545d..db1b46f656a 100644
--- a/homeassistant/helpers/llm.py
+++ b/homeassistant/helpers/llm.py
@@ -2,10 +2,8 @@
 
 from __future__ import annotations
 
-from abc import abstractmethod
-from collections.abc import Iterable
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-import logging
 from typing import Any
 
 import voluptuous as vol
@@ -17,19 +15,53 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.util.json import JsonObjectType
 
 from . import intent
+from .singleton import singleton
 
-_LOGGER = logging.getLogger(__name__)
 
-IGNORE_INTENTS = [
-    intent.INTENT_NEVERMIND,
-    intent.INTENT_GET_STATE,
-    INTENT_GET_WEATHER,
-    INTENT_GET_TEMPERATURE,
-]
+@singleton("llm")
+@callback
+def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
+    """Get all the LLM APIs."""
+    return {
+        "assist": AssistAPI(
+            hass=hass,
+            id="assist",
+            name="Assist",
+            prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
+        ),
+    }
+
+
+@callback
+def async_register_api(hass: HomeAssistant, api: API) -> None:
+    """Register an API to be exposed to LLMs."""
+    apis = _async_get_apis(hass)
+
+    if api.id in apis:
+        raise HomeAssistantError(f"API {api.id} is already registered")
+
+    apis[api.id] = api
+
+
+@callback
+def async_get_api(hass: HomeAssistant, api_id: str) -> API:
+    """Get an API."""
+    apis = _async_get_apis(hass)
+
+    if api_id not in apis:
+        raise HomeAssistantError(f"API {api_id} not found")
+
+    return apis[api_id]
+
+
+@callback
+def async_get_apis(hass: HomeAssistant) -> list[API]:
+    """Get all the LLM APIs."""
+    return list(_async_get_apis(hass).values())
 
 
 @dataclass(slots=True)
-class ToolInput:
+class ToolInput(ABC):
     """Tool input to be processed."""
 
     tool_name: str
@@ -60,34 +92,40 @@ class Tool:
         return f"<{self.__class__.__name__} - {self.name}>"
 
 
-@callback
-def async_get_tools(hass: HomeAssistant) -> Iterable[Tool]:
-    """Return a list of LLM tools."""
-    for intent_handler in intent.async_get(hass):
-        if intent_handler.intent_type not in IGNORE_INTENTS:
-            yield IntentTool(intent_handler)
+@dataclass(slots=True, kw_only=True)
+class API(ABC):
+    """An API to expose to LLMs."""
 
+    hass: HomeAssistant
+    id: str
+    name: str
+    prompt_template: str
 
-@callback
-async def async_call_tool(hass: HomeAssistant, tool_input: ToolInput) -> JsonObjectType:
-    """Call a LLM tool, validate args and return the response."""
-    for tool in async_get_tools(hass):
-        if tool.name == tool_input.tool_name:
-            break
-    else:
-        raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
+    @abstractmethod
+    @callback
+    def async_get_tools(self) -> list[Tool]:
+        """Return a list of tools."""
+        raise NotImplementedError
 
-    _tool_input = ToolInput(
-        tool_name=tool.name,
-        tool_args=tool.parameters(tool_input.tool_args),
-        platform=tool_input.platform,
-        context=tool_input.context or Context(),
-        user_prompt=tool_input.user_prompt,
-        language=tool_input.language,
-        assistant=tool_input.assistant,
-    )
+    async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
+        """Call a LLM tool, validate args and return the response."""
+        for tool in self.async_get_tools():
+            if tool.name == tool_input.tool_name:
+                break
+        else:
+            raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
 
-    return await tool.async_call(hass, _tool_input)
+        _tool_input = ToolInput(
+            tool_name=tool.name,
+            tool_args=tool.parameters(tool_input.tool_args),
+            platform=tool_input.platform,
+            context=tool_input.context or Context(),
+            user_prompt=tool_input.user_prompt,
+            language=tool_input.language,
+            assistant=tool_input.assistant,
+        )
+
+        return await tool.async_call(self.hass, _tool_input)
 
 
 class IntentTool(Tool):
@@ -120,3 +158,23 @@ class IntentTool(Tool):
             tool_input.assistant,
         )
         return intent_response.as_dict()
+
+
+class AssistAPI(API):
+    """API exposing Assist API to LLMs."""
+
+    IGNORE_INTENTS = {
+        intent.INTENT_NEVERMIND,
+        intent.INTENT_GET_STATE,
+        INTENT_GET_WEATHER,
+        INTENT_GET_TEMPERATURE,
+    }
+
+    @callback
+    def async_get_tools(self) -> list[Tool]:
+        """Return a list of LLM tools."""
+        return [
+            IntentTool(intent_handler)
+            for intent_handler in intent.async_get(self.hass)
+            if intent_handler.intent_type not in self.IGNORE_INTENTS
+        ]
diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py
index 3cb2078967d..861a63ec3ef 100644
--- a/tests/helpers/test_llm.py
+++ b/tests/helpers/test_llm.py
@@ -10,11 +10,33 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import config_validation as cv, intent, llm
 
 
+async def test_get_api_no_existing(hass: HomeAssistant) -> None:
+    """Test getting an llm api where no config exists."""
+    with pytest.raises(HomeAssistantError):
+        llm.async_get_api(hass, "non-existing")
+
+
+async def test_register_api(hass: HomeAssistant) -> None:
+    """Test registering an llm api."""
+    api = llm.AssistAPI(
+        hass=hass,
+        id="test",
+        name="Test",
+        prompt_template="Test",
+    )
+    llm.async_register_api(hass, api)
+
+    assert llm.async_get_api(hass, "test") is api
+    assert api in llm.async_get_apis(hass)
+
+    with pytest.raises(HomeAssistantError):
+        llm.async_register_api(hass, api)
+
+
 async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
     """Test calling an llm tool where no config exists."""
     with pytest.raises(HomeAssistantError):
-        await llm.async_call_tool(
-            hass,
+        await llm.async_get_api(hass, "intent").async_call_tool(
             llm.ToolInput(
                 "test_tool",
                 {},
@@ -27,8 +49,8 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
         )
 
 
-async def test_intent_tool(hass: HomeAssistant) -> None:
-    """Test IntentTool class."""
+async def test_assist_api(hass: HomeAssistant) -> None:
+    """Test Assist API."""
     schema = {
         vol.Optional("area"): cv.string,
         vol.Optional("floor"): cv.string,
@@ -42,8 +64,11 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
 
     intent.async_register(hass, intent_handler)
 
-    assert len(list(llm.async_get_tools(hass))) == 1
-    tool = list(llm.async_get_tools(hass))[0]
+    assert len(llm.async_get_apis(hass)) == 1
+    api = llm.async_get_api(hass, "assist")
+    tools = api.async_get_tools()
+    assert len(tools) == 1
+    tool = tools[0]
     assert tool.name == "test_intent"
     assert tool.description == "Execute Home Assistant test_intent intent"
     assert tool.parameters == vol.Schema(intent_handler.slot_schema)
@@ -66,7 +91,7 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
     with patch(
         "homeassistant.helpers.intent.async_handle", return_value=intent_response
     ) as mock_intent_handle:
-        response = await llm.async_call_tool(hass, tool_input)
+        response = await api.async_call_tool(tool_input)
 
     mock_intent_handle.assert_awaited_once_with(
         hass,