# Original file: core/homeassistant/helpers/llm.py (463 lines, 14 KiB, Python)

"""Module to coordinate llm tools."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any
import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
from homeassistant.util.json import JsonObjectType
from . import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
service,
)
from .singleton import singleton
# Registry id of the built-in Assist API.
LLM_API_ASSIST = "assist"

# Prompt text stating the current time and date. The ``{{ ... }}`` parts look
# like Home Assistant template syntax; nothing in this module renders them —
# presumably the caller does (TODO confirm).
BASE_PROMPT = (
    'Current time is {{ now().strftime("%H:%M:%S") }}. '
    "Today's date is {{ now().strftime(\"%Y-%m-%d\") }}.\n"
)

# Default instruction text; every sentence group ends with a newline.
DEFAULT_INSTRUCTIONS_PROMPT = (
    "You are a voice assistant for Home Assistant.\n"
    "Answer in plain text. Keep it simple and to the point.\n"
)
@callback
def async_render_no_api_prompt(hass: HomeAssistant) -> str:
    """Build the prompt text used when no LLM API is configured.

    The text is constant; ``hass`` is not read by this implementation.
    """
    message = (
        "Only if the user wants to control a device, tell them to edit the AI configuration "
        "and allow access to Home Assistant."
    )
    return message
@singleton("llm")
@callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
    """Return the mutable API-id -> API mapping, seeded with the Assist API.

    ``async_register_api`` adds entries to the returned dict, so callers rely on
    ``singleton("llm")`` handing back the same dict on every call — presumably
    cached per ``hass`` instance; see ``.singleton`` to confirm.
    """
    registry: dict[str, API] = {}
    registry[LLM_API_ASSIST] = AssistAPI(hass=hass)
    return registry
@callback
def async_register_api(hass: HomeAssistant, api: API) -> None:
    """Add ``api`` to the shared LLM API registry under ``api.id``.

    Raises:
        HomeAssistantError: an API with the same id is already registered.
    """
    registry = _async_get_apis(hass)
    api_id = api.id
    if api_id in registry:
        raise HomeAssistantError(f"API {api_id} is already registered")
    registry[api_id] = api
async def async_get_api(
    hass: HomeAssistant, api_id: str, llm_context: LLMContext
) -> APIInstance:
    """Look up a registered API by id and build an instance for ``llm_context``.

    Raises:
        HomeAssistantError: no API is registered under ``api_id``.
    """
    # Registered values are always API objects, so None means "missing".
    api = _async_get_apis(hass).get(api_id)
    if api is None:
        raise HomeAssistantError(f"API {api_id} not found")
    return await api.async_get_api_instance(llm_context)
@callback
def async_get_apis(hass: HomeAssistant) -> list[API]:
    """Return a new list of all registered LLM APIs, in registration order."""
    return [*_async_get_apis(hass).values()]
@dataclass(slots=True)
class LLMContext:
"""Tool input to be processed."""
platform: str
context: Context | None
user_prompt: str | None
language: str | None
assistant: str | None
device_id: str | None
@dataclass(slots=True)
class ToolInput:
"""Tool input to be processed."""
tool_name: str
tool_args: dict[str, Any]
class Tool:
    """Base class for a tool that an LLM can call.

    Subclasses set ``name`` (and optionally ``description`` / ``parameters``)
    and implement ``async_call``.
    """

    # Tool name; APIInstance.async_call_tool looks tools up by it.
    name: str
    # Optional human-readable description of the tool.
    description: str | None = None
    # Voluptuous schema of the tool's arguments. This default instance is a
    # class attribute, shared by every tool that does not assign its own.
    parameters: vol.Schema = vol.Schema({})

    @abstractmethod
    async def async_call(
        self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
    ) -> JsonObjectType:
        """Run the tool and return its JSON object result.

        Note: ``Tool`` does not derive from ``ABC``, so ``@abstractmethod`` is
        not enforced at instantiation; this base implementation raises
        NotImplementedError when called.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return ``<ClassName - name>`` for debugging output."""
        class_name = type(self).__name__
        return f"<{class_name} - {self.name}>"
@dataclass
class APIInstance:
    """An API bound to one request: its prompt, context and available tools."""

    # API this instance came from; its ``hass`` is passed to every tool call.
    api: API
    # Prompt text describing how to use this API for the request.
    api_prompt: str
    # Request context, forwarded unchanged to every tool call.
    llm_context: LLMContext
    # Tools offered for this request, looked up by ``Tool.name``.
    tools: list[Tool]

    async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
        """Call the tool named in ``tool_input`` and return its response.

        The call is appended to the conversation trace before the lookup.
        ``tool_input`` is forwarded to the tool unchanged: this method does NOT
        validate ``tool_args`` against ``tool.parameters``, so any validation
        has to happen elsewhere.

        Raises:
            HomeAssistantError: no tool in ``tools`` has the requested name.
        """
        async_conversation_trace_append(
            ConversationTraceEventType.LLM_TOOL_CALL,
            {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
        )

        # First tool with a matching name wins.
        tool = next(
            (
                candidate
                for candidate in self.tools
                if candidate.name == tool_input.tool_name
            ),
            None,
        )
        if tool is None:
            raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')

        return await tool.async_call(self.api.hass, tool_input, self.llm_context)
@dataclass(slots=True, kw_only=True)
class API(ABC):
"""An API to expose to LLMs."""
hass: HomeAssistant
id: str
name: str
@abstractmethod
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API."""
raise NotImplementedError
class IntentTool(Tool):
    """LLM tool that forwards each call to a registered intent handler."""

    def __init__(
        self,
        intent_handler: intent.IntentHandler,
    ) -> None:
        """Derive name, description and argument schema from ``intent_handler``.

        The ``preferred_area_id`` / ``preferred_floor_id`` slots are removed from
        ``parameters``; the ones present are remembered in ``extra_slots`` so
        ``async_call`` can fill them from the calling device's location.
        ``extra_slots`` stays None when the handler has neither.
        """
        self.name = intent_handler.intent_type
        fallback_description = f"Execute Home Assistant {self.name} intent"
        self.description = intent_handler.description or fallback_description
        self.extra_slots = None

        handler_schema = intent_handler.slot_schema
        if not handler_schema:
            # Keep the class-level empty ``parameters`` schema.
            return

        # Shallow copy so the handler's own schema is left untouched.
        visible_schema = {**handler_schema}
        hidden: set[str] = set()
        for slot_name in ("preferred_area_id", "preferred_floor_id"):
            if slot_name not in visible_schema:
                continue
            hidden.add(slot_name)
            visible_schema.pop(slot_name)

        self.parameters = vol.Schema(visible_schema)
        if hidden:
            self.extra_slots = hidden

    async def async_call(
        self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
    ) -> JsonObjectType:
        """Handle the intent with the tool arguments and return the response dict.

        Every tool argument becomes a ``{"value": ...}`` slot. When this tool has
        hidden location slots and the request names a device, those slots are
        filled with the device's area id / floor id where known. The
        ``language`` and ``card`` keys are removed from the returned dict.
        """
        slots = {
            slot_name: {"value": slot_value}
            for slot_name, slot_value in tool_input.tool_args.items()
        }

        if self.extra_slots and llm_context.device_id:
            area, floor = self._async_device_area_floor(hass, llm_context.device_id)
            preferred = (
                ("preferred_area_id", area.id if area else None),
                ("preferred_floor_id", floor.floor_id if floor else None),
            )
            for slot_name, slot_value in preferred:
                if slot_value and slot_name in self.extra_slots:
                    slots[slot_name] = {"value": slot_value}

        intent_response = await intent.async_handle(
            hass=hass,
            platform=llm_context.platform,
            intent_type=self.name,
            slots=slots,
            text_input=llm_context.user_prompt,
            context=llm_context.context,
            language=llm_context.language,
            assistant=llm_context.assistant,
            device_id=llm_context.device_id,
        )
        response = intent_response.as_dict()
        for removed_key in ("language", "card"):
            del response[removed_key]
        return response

    @staticmethod
    def _async_device_area_floor(
        hass: HomeAssistant, device_id: str
    ) -> tuple[ar.AreaEntry | None, fr.FloorEntry | None]:
        """Return the area and floor assigned to ``device_id``; either may be None."""
        device = dr.async_get(hass).async_get(device_id)
        if not device or not device.area_id:
            return None, None
        area = ar.async_get(hass).async_get_area(device.area_id)
        if not area or not area.floor_id:
            return area, None
        return area, fr.async_get(hass).async_get_floor(area.floor_id)
class AssistAPI(API):
    """API that offers Home Assistant intents to LLMs as tools."""

    # Intents that are never turned into tools by this API.
    IGNORE_INTENTS = {
        INTENT_GET_TEMPERATURE,
        INTENT_GET_WEATHER,
        INTENT_OPEN_COVER,  # deprecated
        INTENT_CLOSE_COVER,  # deprecated
        intent.INTENT_GET_STATE,
        intent.INTENT_NEVERMIND,
        intent.INTENT_TOGGLE,
    }

    def __init__(self, hass: HomeAssistant) -> None:
        """Create the Assist API for ``hass`` (id LLM_API_ASSIST, name "Assist")."""
        super().__init__(hass=hass, id=LLM_API_ASSIST, name="Assist")

    async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
        """Build the prompt and tool list for one request.

        Exposed entities are collected only when ``llm_context.assistant`` is
        set; otherwise prompt and tools are built with ``None`` in their place.
        """
        exposed_entities: dict | None = None
        if llm_context.assistant:
            exposed_entities = _get_exposed_entities(self.hass, llm_context.assistant)

        api_prompt = self._async_get_api_prompt(llm_context, exposed_entities)
        tools = self._async_get_tools(llm_context, exposed_entities)
        return APIInstance(
            api=self,
            api_prompt=api_prompt,
            llm_context=llm_context,
            tools=tools,
        )

    @callback
    def _async_get_api_prompt(
        self, llm_context: LLMContext, exposed_entities: dict | None
    ) -> str:
        """Return the prompt text for this request.

        With no (or empty) exposed entities the prompt only asks the model to
        point the user at entity exposure. Otherwise it joins, one per line:
        tool-usage rules, an area/floor hint for the requesting device (or an
        ask-for-area rule), a note when timers are unsupported, and a YAML dump
        of the exposed entities.
        """
        if not exposed_entities:
            return (
                "Only if the user wants to control a device, tell them to expose entities "
                "to their voice assistant in Home Assistant."
            )

        sections = [
            (
                "When controlling Home Assistant always call the intent tools. "
                "Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
                "When controlling a device, prefer passing just its name and its domain "
                "(what comes before the dot in its entity id). "
                "When controlling an area, prefer passing just area name and domain."
            )
        ]

        area, floor = self._async_device_area_floor(llm_context.device_id)
        extra = "and all generic commands like 'turn on the lights' should target this area."
        if area and floor:
            sections.append(f"You are in area {area.name} (floor {floor.name}) {extra}")
        elif area:
            sections.append(f"You are in area {area.name} {extra}")
        else:
            sections.append(
                "When a user asks to turn on all devices of a specific type, "
                "ask user to specify an area, unless there is only one device of that type."
            )

        if not (
            llm_context.device_id
            and async_device_supports_timers(self.hass, llm_context.device_id)
        ):
            sections.append("This device does not support timers.")

        # exposed_entities is non-empty here thanks to the early return above.
        sections.extend(
            [
                "An overview of the areas and the devices in this smart home:",
                yaml.dump(exposed_entities),
            ]
        )
        return "\n".join(sections)

    @callback
    def _async_device_area_floor(
        self, device_id: str | None
    ) -> tuple[ar.AreaEntry | None, fr.FloorEntry | None]:
        """Return the area and floor assigned to ``device_id``; either may be None."""
        if not device_id:
            return None, None
        device = dr.async_get(self.hass).async_get(device_id)
        if not device or not device.area_id:
            return None, None
        area = ar.async_get(self.hass).async_get_area(device.area_id)
        if not area or not area.floor_id:
            return area, None
        return area, fr.async_get(self.hass).async_get_floor(area.floor_id)

    @callback
    def _async_get_tools(
        self, llm_context: LLMContext, exposed_entities: dict | None
    ) -> list[Tool]:
        """Return one IntentTool per applicable registered intent handler.

        Handlers in IGNORE_INTENTS are skipped, and so are the timer intents
        unless the requesting device supports timers. When ``exposed_entities``
        is not None, a handler restricted to specific platforms is kept only if
        one of them is the domain of an exposed entity.
        """
        skipped = self.IGNORE_INTENTS
        if not (
            llm_context.device_id
            and async_device_supports_timers(self.hass, llm_context.device_id)
        ):
            skipped = skipped | {
                intent.INTENT_START_TIMER,
                intent.INTENT_CANCEL_TIMER,
                intent.INTENT_INCREASE_TIMER,
                intent.INTENT_DECREASE_TIMER,
                intent.INTENT_PAUSE_TIMER,
                intent.INTENT_UNPAUSE_TIMER,
                intent.INTENT_TIMER_STATUS,
            }

        exposed_domains: set[str] | None = None
        if exposed_entities is not None:
            exposed_domains = {
                entity_id.partition(".")[0] for entity_id in exposed_entities
            }

        tools: list[Tool] = []
        for handler in intent.async_get(self.hass):
            if handler.intent_type in skipped:
                continue
            if (
                exposed_domains is not None
                and handler.platforms is not None
                and not (handler.platforms & exposed_domains)
            ):
                continue
            tools.append(IntentTool(handler))
        return tools
def _get_exposed_entities(
    hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
    """Describe every entity exposed to ``assistant``, keyed by entity id.

    Each description has ``names`` (friendly name plus registry aliases) and
    ``state``; it also gets ``description`` (scripts with a cached service
    description), ``areas`` (the entity's own area, else its device's area,
    plus that area's aliases) and ``attributes`` (a fixed allow-list, Enum
    values converted to str) when those are non-empty.
    """
    area_registry = ar.async_get(hass)
    entity_registry = er.async_get(hass)
    device_registry = dr.async_get(hass)

    # State attributes copied into each description when present.
    wanted_attributes = {
        "temperature",
        "current_temperature",
        "temperature_unit",
        "brightness",
        "humidity",
        "unit_of_measurement",
        "device_class",
        "current_position",
        "percentage",
        "volume_level",
        "media_title",
        "media_artist",
        "media_album_name",
    }

    def _area_of(entity_entry: er.RegistryEntry) -> ar.AreaEntry | None:
        """Return the entity's own area, falling back to its device's area."""
        if entity_entry.area_id and (
            own_area := area_registry.async_get_area(entity_entry.area_id)
        ):
            return own_area
        if not entity_entry.device_id:
            return None
        device = device_registry.async_get(entity_entry.device_id)
        if not device or not device.area_id:
            return None
        return area_registry.async_get_area(device.area_id)

    def _script_description(state_domain: str, entity_entry: er.RegistryEntry) -> str | None:
        """Return the cached description of a script entity, if any."""
        if state_domain != "script" or not entity_entry.unique_id:
            return None
        service_desc = service.async_get_cached_service_description(
            hass, "script", entity_entry.unique_id
        )
        if not service_desc:
            return None
        return service_desc.get("description")

    entities: dict[str, dict[str, Any]] = {}
    for state in hass.states.async_all():
        entity_id = state.entity_id
        if not async_should_expose(hass, assistant, entity_id):
            continue

        names = [state.name]
        area_names: list[str] = []
        description: str | None = None

        entity_entry = entity_registry.async_get(entity_id)
        if entity_entry is not None:
            names += entity_entry.aliases
            if area := _area_of(entity_entry):
                area_names = [area.name, *area.aliases]
            description = _script_description(state.domain, entity_entry)

        # Key order matters for the YAML dump: names, state, then optional keys.
        info: dict[str, Any] = {"names": ", ".join(names), "state": state.state}
        if description:
            info["description"] = description
        if area_names:
            info["areas"] = ", ".join(area_names)

        attributes = {
            attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
            for attr_name, attr_value in state.attributes.items()
            if attr_name in wanted_attributes
        }
        if attributes:
            info["attributes"] = attributes

        entities[entity_id] = info

    return entities