Add API class to LLM helper (#117707)
* Add API class to LLM helper * Add more tests * Rename intent to assist to broaden scopepull/115580/head
parent
bfc52b9fab
commit
d001e7daea
|
@ -2,10 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -17,19 +15,53 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import intent
|
||||
from .singleton import singleton
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
IGNORE_INTENTS = [
|
||||
intent.INTENT_NEVERMIND,
|
||||
intent.INTENT_GET_STATE,
|
||||
INTENT_GET_WEATHER,
|
||||
INTENT_GET_TEMPERATURE,
|
||||
]
|
||||
@singleton("llm")
|
||||
@callback
|
||||
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
|
||||
"""Get all the LLM APIs."""
|
||||
return {
|
||||
"assist": AssistAPI(
|
||||
hass=hass,
|
||||
id="assist",
|
||||
name="Assist",
|
||||
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_api(hass: HomeAssistant, api: API) -> None:
|
||||
"""Register an API to be exposed to LLMs."""
|
||||
apis = _async_get_apis(hass)
|
||||
|
||||
if api.id in apis:
|
||||
raise HomeAssistantError(f"API {api.id} is already registered")
|
||||
|
||||
apis[api.id] = api
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_api(hass: HomeAssistant, api_id: str) -> API:
|
||||
"""Get an API."""
|
||||
apis = _async_get_apis(hass)
|
||||
|
||||
if api_id not in apis:
|
||||
raise HomeAssistantError(f"API {api_id} not found")
|
||||
|
||||
return apis[api_id]
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_apis(hass: HomeAssistant) -> list[API]:
|
||||
"""Get all the LLM APIs."""
|
||||
return list(_async_get_apis(hass).values())
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolInput:
|
||||
class ToolInput(ABC):
|
||||
"""Tool input to be processed."""
|
||||
|
||||
tool_name: str
|
||||
|
@ -60,34 +92,40 @@ class Tool:
|
|||
return f"<{self.__class__.__name__} - {self.name}>"
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_tools(hass: HomeAssistant) -> Iterable[Tool]:
|
||||
"""Return a list of LLM tools."""
|
||||
for intent_handler in intent.async_get(hass):
|
||||
if intent_handler.intent_type not in IGNORE_INTENTS:
|
||||
yield IntentTool(intent_handler)
|
||||
@dataclass(slots=True, kw_only=True)
|
||||
class API(ABC):
|
||||
"""An API to expose to LLMs."""
|
||||
|
||||
hass: HomeAssistant
|
||||
id: str
|
||||
name: str
|
||||
prompt_template: str
|
||||
|
||||
@callback
|
||||
async def async_call_tool(hass: HomeAssistant, tool_input: ToolInput) -> JsonObjectType:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
for tool in async_get_tools(hass):
|
||||
if tool.name == tool_input.tool_name:
|
||||
break
|
||||
else:
|
||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||
@abstractmethod
|
||||
@callback
|
||||
def async_get_tools(self) -> list[Tool]:
|
||||
"""Return a list of tools."""
|
||||
raise NotImplementedError
|
||||
|
||||
_tool_input = ToolInput(
|
||||
tool_name=tool.name,
|
||||
tool_args=tool.parameters(tool_input.tool_args),
|
||||
platform=tool_input.platform,
|
||||
context=tool_input.context or Context(),
|
||||
user_prompt=tool_input.user_prompt,
|
||||
language=tool_input.language,
|
||||
assistant=tool_input.assistant,
|
||||
)
|
||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
for tool in self.async_get_tools():
|
||||
if tool.name == tool_input.tool_name:
|
||||
break
|
||||
else:
|
||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||
|
||||
return await tool.async_call(hass, _tool_input)
|
||||
_tool_input = ToolInput(
|
||||
tool_name=tool.name,
|
||||
tool_args=tool.parameters(tool_input.tool_args),
|
||||
platform=tool_input.platform,
|
||||
context=tool_input.context or Context(),
|
||||
user_prompt=tool_input.user_prompt,
|
||||
language=tool_input.language,
|
||||
assistant=tool_input.assistant,
|
||||
)
|
||||
|
||||
return await tool.async_call(self.hass, _tool_input)
|
||||
|
||||
|
||||
class IntentTool(Tool):
|
||||
|
@ -120,3 +158,23 @@ class IntentTool(Tool):
|
|||
tool_input.assistant,
|
||||
)
|
||||
return intent_response.as_dict()
|
||||
|
||||
|
||||
class AssistAPI(API):
|
||||
"""API exposing Assist API to LLMs."""
|
||||
|
||||
IGNORE_INTENTS = {
|
||||
intent.INTENT_NEVERMIND,
|
||||
intent.INTENT_GET_STATE,
|
||||
INTENT_GET_WEATHER,
|
||||
INTENT_GET_TEMPERATURE,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_get_tools(self) -> list[Tool]:
|
||||
"""Return a list of LLM tools."""
|
||||
return [
|
||||
IntentTool(intent_handler)
|
||||
for intent_handler in intent.async_get(self.hass)
|
||||
if intent_handler.intent_type not in self.IGNORE_INTENTS
|
||||
]
|
||||
|
|
|
@ -10,11 +10,33 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.helpers import config_validation as cv, intent, llm
|
||||
|
||||
|
||||
async def test_get_api_no_existing(hass: HomeAssistant) -> None:
|
||||
"""Test getting an llm api where no config exists."""
|
||||
with pytest.raises(HomeAssistantError):
|
||||
llm.async_get_api(hass, "non-existing")
|
||||
|
||||
|
||||
async def test_register_api(hass: HomeAssistant) -> None:
|
||||
"""Test registering an llm api."""
|
||||
api = llm.AssistAPI(
|
||||
hass=hass,
|
||||
id="test",
|
||||
name="Test",
|
||||
prompt_template="Test",
|
||||
)
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
assert llm.async_get_api(hass, "test") is api
|
||||
assert api in llm.async_get_apis(hass)
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
|
||||
async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
|
||||
"""Test calling an llm tool where no config exists."""
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await llm.async_call_tool(
|
||||
hass,
|
||||
await llm.async_get_api(hass, "intent").async_call_tool(
|
||||
llm.ToolInput(
|
||||
"test_tool",
|
||||
{},
|
||||
|
@ -27,8 +49,8 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def test_intent_tool(hass: HomeAssistant) -> None:
|
||||
"""Test IntentTool class."""
|
||||
async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
"""Test Assist API."""
|
||||
schema = {
|
||||
vol.Optional("area"): cv.string,
|
||||
vol.Optional("floor"): cv.string,
|
||||
|
@ -42,8 +64,11 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
|
|||
|
||||
intent.async_register(hass, intent_handler)
|
||||
|
||||
assert len(list(llm.async_get_tools(hass))) == 1
|
||||
tool = list(llm.async_get_tools(hass))[0]
|
||||
assert len(llm.async_get_apis(hass)) == 1
|
||||
api = llm.async_get_api(hass, "assist")
|
||||
tools = api.async_get_tools()
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_intent"
|
||||
assert tool.description == "Execute Home Assistant test_intent intent"
|
||||
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
|
||||
|
@ -66,7 +91,7 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
|
|||
with patch(
|
||||
"homeassistant.helpers.intent.async_handle", return_value=intent_response
|
||||
) as mock_intent_handle:
|
||||
response = await llm.async_call_tool(hass, tool_input)
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_intent_handle.assert_awaited_once_with(
|
||||
hass,
|
||||
|
|
Loading…
Reference in New Issue