Add API class to LLM helper (#117707)

* Add API class to LLM helper

* Add more tests

* Rename intent to assist to broaden scope
pull/115580/head
Paulus Schoutsen 2024-05-18 21:14:05 -04:00 committed by GitHub
parent bfc52b9fab
commit d001e7daea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 125 additions and 42 deletions

View File

@ -2,10 +2,8 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Iterable
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
from typing import Any
import voluptuous as vol
@ -17,19 +15,53 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import JsonObjectType
from . import intent
from .singleton import singleton
_LOGGER = logging.getLogger(__name__)
IGNORE_INTENTS = [
intent.INTENT_NEVERMIND,
intent.INTENT_GET_STATE,
INTENT_GET_WEATHER,
INTENT_GET_TEMPERATURE,
]
@singleton("llm")
@callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
"""Get all the LLM APIs."""
return {
"assist": AssistAPI(
hass=hass,
id="assist",
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
),
}
@callback
def async_register_api(hass: HomeAssistant, api: API) -> None:
"""Register an API to be exposed to LLMs."""
apis = _async_get_apis(hass)
if api.id in apis:
raise HomeAssistantError(f"API {api.id} is already registered")
apis[api.id] = api
@callback
def async_get_api(hass: HomeAssistant, api_id: str) -> API:
"""Get an API."""
apis = _async_get_apis(hass)
if api_id not in apis:
raise HomeAssistantError(f"API {api_id} not found")
return apis[api_id]
@callback
def async_get_apis(hass: HomeAssistant) -> list[API]:
"""Get all the LLM APIs."""
return list(_async_get_apis(hass).values())
@dataclass(slots=True)
class ToolInput:
class ToolInput(ABC):
"""Tool input to be processed."""
tool_name: str
@ -60,34 +92,40 @@ class Tool:
return f"<{self.__class__.__name__} - {self.name}>"
@callback
def async_get_tools(hass: HomeAssistant) -> Iterable[Tool]:
"""Return a list of LLM tools."""
for intent_handler in intent.async_get(hass):
if intent_handler.intent_type not in IGNORE_INTENTS:
yield IntentTool(intent_handler)
@dataclass(slots=True, kw_only=True)
class API(ABC):
"""An API to expose to LLMs."""
hass: HomeAssistant
id: str
name: str
prompt_template: str
@callback
async def async_call_tool(hass: HomeAssistant, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
for tool in async_get_tools(hass):
if tool.name == tool_input.tool_name:
break
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
@abstractmethod
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of tools."""
raise NotImplementedError
_tool_input = ToolInput(
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt,
language=tool_input.language,
assistant=tool_input.assistant,
)
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
for tool in self.async_get_tools():
if tool.name == tool_input.tool_name:
break
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
return await tool.async_call(hass, _tool_input)
_tool_input = ToolInput(
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt,
language=tool_input.language,
assistant=tool_input.assistant,
)
return await tool.async_call(self.hass, _tool_input)
class IntentTool(Tool):
@ -120,3 +158,23 @@ class IntentTool(Tool):
tool_input.assistant,
)
return intent_response.as_dict()
class AssistAPI(API):
"""API exposing Assist API to LLMs."""
IGNORE_INTENTS = {
intent.INTENT_NEVERMIND,
intent.INTENT_GET_STATE,
INTENT_GET_WEATHER,
INTENT_GET_TEMPERATURE,
}
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of LLM tools."""
return [
IntentTool(intent_handler)
for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS
]

View File

@ -10,11 +10,33 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, llm
async def test_get_api_no_existing(hass: HomeAssistant) -> None:
"""Test getting an llm api where no config exists."""
with pytest.raises(HomeAssistantError):
llm.async_get_api(hass, "non-existing")
async def test_register_api(hass: HomeAssistant) -> None:
"""Test registering an llm api."""
api = llm.AssistAPI(
hass=hass,
id="test",
name="Test",
prompt_template="Test",
)
llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api
assert api in llm.async_get_apis(hass)
with pytest.raises(HomeAssistantError):
llm.async_register_api(hass, api)
async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
"""Test calling an llm tool where no config exists."""
with pytest.raises(HomeAssistantError):
await llm.async_call_tool(
hass,
await llm.async_get_api(hass, "intent").async_call_tool(
llm.ToolInput(
"test_tool",
{},
@ -27,8 +49,8 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
)
async def test_intent_tool(hass: HomeAssistant) -> None:
"""Test IntentTool class."""
async def test_assist_api(hass: HomeAssistant) -> None:
"""Test Assist API."""
schema = {
vol.Optional("area"): cv.string,
vol.Optional("floor"): cv.string,
@ -42,8 +64,11 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
intent.async_register(hass, intent_handler)
assert len(list(llm.async_get_tools(hass))) == 1
tool = list(llm.async_get_tools(hass))[0]
assert len(llm.async_get_apis(hass)) == 1
api = llm.async_get_api(hass, "assist")
tools = api.async_get_tools()
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_intent"
assert tool.description == "Execute Home Assistant test_intent intent"
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
@ -66,7 +91,7 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
with patch(
"homeassistant.helpers.intent.async_handle", return_value=intent_response
) as mock_intent_handle:
response = await llm.async_call_tool(hass, tool_input)
response = await api.async_call_tool(tool_input)
mock_intent_handle.assert_awaited_once_with(
hass,