diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 1d91c9e545d..db1b46f656a 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -2,10 +2,8 @@ from __future__ import annotations -from abc import abstractmethod -from collections.abc import Iterable +from abc import ABC, abstractmethod from dataclasses import dataclass -import logging from typing import Any import voluptuous as vol @@ -17,19 +15,53 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.util.json import JsonObjectType from . import intent +from .singleton import singleton -_LOGGER = logging.getLogger(__name__) -IGNORE_INTENTS = [ - intent.INTENT_NEVERMIND, - intent.INTENT_GET_STATE, - INTENT_GET_WEATHER, - INTENT_GET_TEMPERATURE, -] +@singleton("llm") +@callback +def _async_get_apis(hass: HomeAssistant) -> dict[str, API]: + """Get all the LLM APIs.""" + return { + "assist": AssistAPI( + hass=hass, + id="assist", + name="Assist", + prompt_template="Call the intent tools to control the system. Just pass the name to the intent.", + ), + } + + +@callback +def async_register_api(hass: HomeAssistant, api: API) -> None: + """Register an API to be exposed to LLMs.""" + apis = _async_get_apis(hass) + + if api.id in apis: + raise HomeAssistantError(f"API {api.id} is already registered") + + apis[api.id] = api + + +@callback +def async_get_api(hass: HomeAssistant, api_id: str) -> API: + """Get an API.""" + apis = _async_get_apis(hass) + + if api_id not in apis: + raise HomeAssistantError(f"API {api_id} not found") + + return apis[api_id] + + +@callback +def async_get_apis(hass: HomeAssistant) -> list[API]: + """Get all the LLM APIs.""" + return list(_async_get_apis(hass).values()) @dataclass(slots=True) -class ToolInput: +class ToolInput(ABC): """Tool input to be processed.""" tool_name: str @@ -60,34 +92,40 @@ class Tool: return f"<{self.__class__.__name__} - {self.name}>" -@callback -def async_get_tools(hass: HomeAssistant) -> Iterable[Tool]: - """Return a list of LLM tools.""" - for intent_handler in intent.async_get(hass): - if intent_handler.intent_type not in IGNORE_INTENTS: - yield IntentTool(intent_handler) +@dataclass(slots=True, kw_only=True) +class API(ABC): + """An API to expose to LLMs.""" + hass: HomeAssistant + id: str + name: str + prompt_template: str -@callback -async def async_call_tool(hass: HomeAssistant, tool_input: ToolInput) -> JsonObjectType: - """Call a LLM tool, validate args and return the response.""" - for tool in async_get_tools(hass): - if tool.name == tool_input.tool_name: - break - else: - raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') + @abstractmethod + @callback + def async_get_tools(self) -> list[Tool]: + """Return a list of tools.""" + raise NotImplementedError - _tool_input = ToolInput( - tool_name=tool.name, - tool_args=tool.parameters(tool_input.tool_args), - platform=tool_input.platform, - context=tool_input.context or Context(), - user_prompt=tool_input.user_prompt, - language=tool_input.language, - assistant=tool_input.assistant, - ) + async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: + """Call a LLM tool, validate args and return the response.""" + for tool in self.async_get_tools(): + if tool.name == tool_input.tool_name: + break + else: + raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') - return await tool.async_call(hass, _tool_input) + _tool_input = ToolInput( + tool_name=tool.name, + tool_args=tool.parameters(tool_input.tool_args), + platform=tool_input.platform, + context=tool_input.context or Context(), + user_prompt=tool_input.user_prompt, + language=tool_input.language, + assistant=tool_input.assistant, + ) + + return await tool.async_call(self.hass, _tool_input) class IntentTool(Tool): @@ -120,3 +158,23 @@ class IntentTool(Tool): tool_input.assistant, ) return intent_response.as_dict() + + +class AssistAPI(API): + """API exposing Assist API to LLMs.""" + + IGNORE_INTENTS = { + intent.INTENT_NEVERMIND, + intent.INTENT_GET_STATE, + INTENT_GET_WEATHER, + INTENT_GET_TEMPERATURE, + } + + @callback + def async_get_tools(self) -> list[Tool]: + """Return a list of LLM tools.""" + return [ + IntentTool(intent_handler) + for intent_handler in intent.async_get(self.hass) + if intent_handler.intent_type not in self.IGNORE_INTENTS + ] diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 3cb2078967d..861a63ec3ef 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -10,11 +10,33 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, intent, llm +async def test_get_api_no_existing(hass: HomeAssistant) -> None: + """Test getting an llm api where no config exists.""" + with pytest.raises(HomeAssistantError): + llm.async_get_api(hass, "non-existing") + + +async def test_register_api(hass: HomeAssistant) -> None: + """Test registering an llm api.""" + api = llm.AssistAPI( + hass=hass, + id="test", + name="Test", + prompt_template="Test", + ) + llm.async_register_api(hass, api) + + assert llm.async_get_api(hass, "test") is api + assert api in llm.async_get_apis(hass) + + with pytest.raises(HomeAssistantError): + llm.async_register_api(hass, api) + + async def test_call_tool_no_existing(hass: HomeAssistant) -> None: """Test calling an llm tool where no config exists.""" with pytest.raises(HomeAssistantError): - await llm.async_call_tool( - hass, + await llm.async_get_api(hass, "intent").async_call_tool( llm.ToolInput( "test_tool", {}, @@ -27,8 +49,8 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None: ) -async def test_intent_tool(hass: HomeAssistant) -> None: - """Test IntentTool class.""" +async def test_assist_api(hass: HomeAssistant) -> None: + """Test Assist API.""" schema = { vol.Optional("area"): cv.string, vol.Optional("floor"): cv.string, @@ -42,8 +64,11 @@ async def test_intent_tool(hass: HomeAssistant) -> None: intent.async_register(hass, intent_handler) - assert len(list(llm.async_get_tools(hass))) == 1 - tool = list(llm.async_get_tools(hass))[0] + assert len(llm.async_get_apis(hass)) == 1 + api = llm.async_get_api(hass, "assist") + tools = api.async_get_tools() + assert len(tools) == 1 + tool = tools[0] assert tool.name == "test_intent" assert tool.description == "Execute Home Assistant test_intent intent" assert tool.parameters == vol.Schema(intent_handler.slot_schema) @@ -66,7 +91,7 @@ async def test_intent_tool(hass: HomeAssistant) -> None: with patch( "homeassistant.helpers.intent.async_handle", return_value=intent_response ) as mock_intent_handle: - response = await llm.async_call_tool(hass, tool_input) + response = await api.async_call_tool(tool_input) mock_intent_handle.assert_awaited_once_with( hass,