2024-05-15 23:16:47 +00:00
|
|
|
"""Module to coordinate llm tools."""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-05-19 01:14:05 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2024-05-27 04:27:08 +00:00
|
|
|
from dataclasses import asdict, dataclass, replace
|
|
|
|
from enum import Enum
|
2024-05-15 23:16:47 +00:00
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
import voluptuous as vol
|
|
|
|
|
|
|
|
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
2024-05-25 18:16:51 +00:00
|
|
|
from homeassistant.components.conversation.trace import (
|
|
|
|
ConversationTraceEventType,
|
|
|
|
async_conversation_trace_append,
|
|
|
|
)
|
2024-05-27 04:27:08 +00:00
|
|
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
2024-05-15 23:16:47 +00:00
|
|
|
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
|
|
|
from homeassistant.core import Context, HomeAssistant, callback
|
|
|
|
from homeassistant.exceptions import HomeAssistantError
|
2024-05-27 04:27:08 +00:00
|
|
|
from homeassistant.util import yaml
|
2024-05-15 23:16:47 +00:00
|
|
|
from homeassistant.util.json import JsonObjectType
|
|
|
|
|
2024-05-27 04:27:08 +00:00
|
|
|
from . import (
|
|
|
|
area_registry as ar,
|
|
|
|
device_registry as dr,
|
|
|
|
entity_registry as er,
|
|
|
|
floor_registry as fr,
|
|
|
|
intent,
|
|
|
|
)
|
2024-05-19 01:14:05 +00:00
|
|
|
from .singleton import singleton
|
2024-05-15 23:16:47 +00:00
|
|
|
|
2024-05-20 02:11:25 +00:00
|
|
|
# Identifier under which the built-in Assist API is registered.
LLM_API_ASSIST = "assist"

# Default system instructions for a conversation agent.
# The {{ ... }} placeholders look like Home Assistant template syntax; they are
# presumably rendered by the caller (not shown here) — confirm before relying on it.
DEFAULT_INSTRUCTIONS_PROMPT = "\n".join(
    (
        "You are a voice assistant for Home Assistant.",
        "Answer in plain text. Keep it simple and to the point.",
        'The current time is {{ now().strftime("%X") }}.',
        "Today's date is {{ now().strftime(\"%x\") }}.",
        # Empty final element keeps the trailing newline of the prompt.
        "",
    )
)
|
|
|
|
|
2024-05-26 11:35:15 +00:00
|
|
|
|
|
|
|
@callback
def async_render_no_api_prompt(hass: HomeAssistant) -> str:
    """Return the prompt to be used when no API is configured.

    The text steers the model to point the user at the AI configuration only
    when they actually try to control a device.

    Args:
        hass: Home Assistant instance (currently unused by this function).

    Returns:
        A fixed prompt string.
    """
    guidance = (
        "Only if the user wants to control a device, tell them to edit the AI configuration "
        "and allow access to Home Assistant."
    )
    return guidance
|
2024-05-20 02:11:25 +00:00
|
|
|
|
2024-05-15 23:16:47 +00:00
|
|
|
|
2024-05-19 01:14:05 +00:00
|
|
|
@singleton("llm")
@callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
    """Get all the LLM APIs, keyed by API id.

    The returned dict is mutated by ``async_register_api``, so it is presumably
    cached per ``hass`` by the ``singleton("llm")`` decorator — confirm in
    helpers.singleton. Only the built-in Assist API is seeded here.
    """
    registry: dict[str, API] = {}
    registry[LLM_API_ASSIST] = AssistAPI(hass=hass)
    return registry
|
|
|
|
|
|
|
|
|
|
|
|
@callback
def async_register_api(hass: HomeAssistant, api: API) -> None:
    """Register an API to be exposed to LLMs.

    Args:
        hass: Home Assistant instance whose API registry is updated.
        api: The API to add; it is stored under ``api.id``.

    Raises:
        HomeAssistantError: If an API with the same id is already registered.
    """
    registry = _async_get_apis(hass)
    if api.id not in registry:
        registry[api.id] = api
        return
    raise HomeAssistantError(f"API {api.id} is already registered")
|
|
|
|
|
|
|
|
|
|
|
|
@callback
def async_get_api(hass: HomeAssistant, api_id: str) -> API:
    """Get an API by its id.

    Args:
        hass: Home Assistant instance holding the API registry.
        api_id: Id the API was registered under.

    Returns:
        The registered API.

    Raises:
        HomeAssistantError: If no API is registered under ``api_id``.
    """
    registry = _async_get_apis(hass)
    if api_id in registry:
        return registry[api_id]
    raise HomeAssistantError(f"API {api_id} not found")
|
|
|
|
|
|
|
|
|
|
|
|
@callback
def async_get_apis(hass: HomeAssistant) -> list[API]:
    """Return every registered LLM API as a new list (in registry order)."""
    return [*_async_get_apis(hass).values()]
|
2024-05-15 23:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass(slots=True)
|
2024-05-19 01:14:05 +00:00
|
|
|
class ToolInput(ABC):
|
2024-05-15 23:16:47 +00:00
|
|
|
"""Tool input to be processed."""
|
|
|
|
|
|
|
|
tool_name: str
|
|
|
|
tool_args: dict[str, Any]
|
|
|
|
platform: str
|
|
|
|
context: Context | None
|
|
|
|
user_prompt: str | None
|
|
|
|
language: str | None
|
|
|
|
assistant: str | None
|
2024-05-22 01:24:46 +00:00
|
|
|
device_id: str | None
|
2024-05-15 23:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Tool:
    """LLM Tool base class.

    Subclasses provide ``name`` (and optionally ``description`` and
    ``parameters``) and implement ``async_call``.

    Note: ``Tool`` does not inherit from ``ABC``, so ``@abstractmethod`` below
    is informational only and does not block instantiation of subclasses.
    """

    name: str
    description: str | None = None
    # Schema applied to the LLM-supplied arguments before the call; an empty
    # voluptuous schema rejects any keys by default.
    parameters: vol.Schema = vol.Schema({})

    @abstractmethod
    async def async_call(
        self, hass: HomeAssistant, tool_input: ToolInput
    ) -> JsonObjectType:
        """Call the tool and return a JSON-serializable result."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Represent a string of a Tool."""
        return f"<{type(self).__name__} - {self.name}>"
|
|
|
|
|
|
|
|
|
2024-05-19 01:14:05 +00:00
|
|
|
@dataclass(slots=True, kw_only=True)
class API(ABC):
    """An API to expose to LLMs.

    Concrete APIs supply a prompt and a list of tools; ``async_call_tool``
    dispatches an LLM tool call to the matching tool.
    """

    hass: HomeAssistant
    id: str
    name: str

    @abstractmethod
    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
        """Return the prompt for the API."""
        raise NotImplementedError

    @abstractmethod
    @callback
    def async_get_tools(self) -> list[Tool]:
        """Return a list of tools."""
        raise NotImplementedError

    async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
        """Call a LLM tool, validate args and return the response.

        The call is recorded in the conversation trace before lookup.

        Raises:
            HomeAssistantError: If no tool named ``tool_input.tool_name`` exists.
            Any error raised by the tool's ``parameters`` schema on invalid args.
        """
        async_conversation_trace_append(
            ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
        )

        requested = tool_input.tool_name
        # First tool whose name matches wins, mirroring list order.
        selected = next(
            (
                candidate
                for candidate in self.async_get_tools()
                if candidate.name == requested
            ),
            None,
        )
        if selected is None:
            raise HomeAssistantError(f'Tool "{requested}" not found')

        # Validate/coerce the raw LLM arguments with the tool's schema.
        validated_args = selected.parameters(tool_input.tool_args)
        prepared_input = replace(
            tool_input,
            tool_name=selected.name,
            tool_args=validated_args,
            # Tools always receive a Context, even if the caller passed none.
            context=tool_input.context or Context(),
        )
        return await selected.async_call(self.hass, prepared_input)
|
|
|
|
|
2024-05-15 23:16:47 +00:00
|
|
|
|
|
|
|
class IntentTool(Tool):
    """LLM Tool representing an Intent."""

    def __init__(
        self,
        intent_handler: intent.IntentHandler,
    ) -> None:
        """Build the tool from an intent handler.

        The tool name is the intent type. The description falls back to a
        generic sentence when the handler has none. The handler's slot schema,
        if present, becomes the argument schema.
        """
        intent_type = intent_handler.intent_type
        self.name = intent_type
        generic_description = f"Execute Home Assistant {intent_type} intent"
        self.description = intent_handler.description or generic_description
        slot_schema = intent_handler.slot_schema
        if slot_schema:
            self.parameters = vol.Schema(slot_schema)

    async def async_call(
        self, hass: HomeAssistant, tool_input: ToolInput
    ) -> JsonObjectType:
        """Handle the intent and return the intent response as a dict."""
        # Intent slots expect each value wrapped as {"value": ...}.
        intent_slots = {
            slot_name: {"value": slot_value}
            for slot_name, slot_value in tool_input.tool_args.items()
        }

        response = await intent.async_handle(
            hass,
            tool_input.platform,
            self.name,
            intent_slots,
            tool_input.user_prompt,
            tool_input.context,
            tool_input.language,
            tool_input.assistant,
            tool_input.device_id,
        )
        return response.as_dict()
|
2024-05-19 01:14:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
class AssistAPI(API):
    """API exposing Assist API to LLMs."""

    # Intents never offered to the LLM as tools.
    IGNORE_INTENTS = {
        intent.INTENT_NEVERMIND,
        intent.INTENT_GET_STATE,
        INTENT_GET_WEATHER,
        INTENT_GET_TEMPERATURE,
    }

    def __init__(self, hass: HomeAssistant) -> None:
        """Init the class."""
        super().__init__(
            hass=hass,
            id=LLM_API_ASSIST,
            name="Assist",
        )

    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
        """Return the prompt for the API.

        If no assistant is given, or it has no exposed entities, a short
        message asks the user to expose entities. Otherwise the prompt
        contains, in order:
        - tool-usage instructions
        - the originating device's area (and floor), when known
        - the user's name, when the context has a known user
        - a YAML dump of the exposed entities
        """
        exposed_entities: dict[str, dict[str, Any]] | None = None
        if tool_input.assistant:
            exposed_entities = _get_exposed_entities(
                self.hass, tool_input.assistant
            )

        if not exposed_entities:
            return (
                "Only if the user wants to control a device, tell them to expose entities "
                "to their voice assistant in Home Assistant."
            )

        prompt = [
            (
                "Call the intent tools to control Home Assistant. "
                "Just pass the name to the intent. "
                "When controlling an area, prefer passing area name."
            )
        ]

        if location := self._async_get_device_location_prompt(tool_input.device_id):
            prompt.append(location)

        if tool_input.context and tool_input.context.user_id:
            user = await self.hass.auth.async_get_user(tool_input.context.user_id)
            if user:
                prompt.append(f"The user name is {user.name}.")

        # exposed_entities is non-empty here (early return above), so the
        # overview is always included; the former re-check was dead code.
        prompt.append("An overview of the areas and the devices in this smart home:")
        prompt.append(yaml.dump(exposed_entities))

        return "\n".join(prompt)

    @callback
    def _async_get_device_location_prompt(self, device_id: str | None) -> str | None:
        """Return a sentence naming the area (and floor) of ``device_id``.

        Returns None when there is no device id, or when the device, its area,
        or the area record cannot be found.
        """
        if not device_id:
            return None

        device = dr.async_get(self.hass).async_get(device_id)
        if not device or not device.area_id:
            return None

        area = ar.async_get(self.hass).async_get_area(device.area_id)
        if not area:
            return None

        if area.floor_id and (
            floor := fr.async_get(self.hass).async_get_floor(area.floor_id)
        ):
            return f"You are in {area.name} ({floor.name})."
        return f"You are in {area.name}."

    @callback
    def async_get_tools(self) -> list[Tool]:
        """Return a list of LLM tools, one per registered intent not ignored."""
        return [
            IntentTool(intent_handler)
            for intent_handler in intent.async_get(self.hass)
            if intent_handler.intent_type not in self.IGNORE_INTENTS
        ]
|
2024-05-27 04:27:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _get_exposed_entities(
    hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
    """Get exposed entities.

    Returns a mapping of entity_id to a description dict for every state whose
    entity is exposed to ``assistant``. Each dict has:
    - ``names``: the state's name followed by registry aliases, comma-joined
    - ``state``: the state value
    - ``areas``: only when an area is found; area name plus area aliases,
      comma-joined
    - ``attributes``: only when non-empty; the whitelisted attributes below,
      with Enum values converted through ``str()``
    """
    area_reg = ar.async_get(hass)
    entity_reg = er.async_get(hass)
    device_reg = dr.async_get(hass)

    # Only these state attributes are passed on to the LLM.
    wanted_attributes = {
        "temperature",
        "current_temperature",
        "temperature_unit",
        "brightness",
        "humidity",
        "unit_of_measurement",
        "device_class",
        "current_position",
        "percentage",
    }

    def _resolve_area(entry):  # noqa: ANN001 - registry entry type
        """Return the entry's own area, else its device's area, else None."""
        if entry.area_id and (area := area_reg.async_get_area(entry.area_id)):
            return area
        if (
            entry.device_id
            and (device := device_reg.async_get(entry.device_id))
            and device.area_id
        ):
            return area_reg.async_get_area(device.area_id)
        return None

    exposed: dict[str, dict[str, Any]] = {}

    for state in hass.states.async_all():
        entity_id = state.entity_id
        if not async_should_expose(hass, assistant, entity_id):
            continue

        registry_entry = entity_reg.async_get(entity_id)
        all_names = [state.name]
        all_area_names: list[str] = []

        if registry_entry is not None:
            all_names.extend(registry_entry.aliases)
            if found_area := _resolve_area(registry_entry):
                all_area_names.append(found_area.name)
                all_area_names.extend(found_area.aliases)

        description: dict[str, Any] = {
            "names": ", ".join(all_names),
            "state": state.state,
        }
        if all_area_names:
            description["areas"] = ", ".join(all_area_names)

        picked = {
            key: str(value) if isinstance(value, Enum) else value
            for key, value in state.attributes.items()
            if key in wanted_attributes
        }
        if picked:
            description["attributes"] = picked

        exposed[entity_id] = description

    return exposed
|