"""Module to coordinate llm tools.""" from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field as dc_field from datetime import timedelta from decimal import Decimal from enum import Enum from functools import cache, partial from typing import Any, cast import slugify as unicode_slug import voluptuous as vol from voluptuous_openapi import UNSUPPORTED, convert from homeassistant.components.calendar import ( DOMAIN as CALENDAR_DOMAIN, SERVICE_GET_EVENTS, ) from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.homeassistant import async_should_expose from homeassistant.components.intent import async_device_supports_timers from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN from homeassistant.components.weather import INTENT_GET_WEATHER from homeassistant.const import ( ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED, ) from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id from homeassistant.exceptions import HomeAssistantError from homeassistant.util import dt as dt_util, yaml as yaml_util from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType from homeassistant.util.ulid import ulid_now from . import ( area_registry as ar, config_validation as cv, device_registry as dr, entity_registry as er, floor_registry as fr, intent, selector, service, ) from .singleton import singleton ACTION_PARAMETERS_CACHE: HassKey[ dict[str, dict[str, tuple[str | None, vol.Schema]]] ] = HassKey("llm_action_parameters_cache") LLM_API_ASSIST = "assist" BASE_PROMPT = ( 'Current time is {{ now().strftime("%H:%M:%S") }}. ' 'Today\'s date is {{ now().strftime("%Y-%m-%d") }}.\n' ) DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant. Answer questions about the world truthfully. Answer in plain text. Keep it simple and to the point. """ @callback def async_render_no_api_prompt(hass: HomeAssistant) -> str: """Return the prompt to be used when no API is configured. No longer used since Home Assistant 2024.7. """ return "" @singleton("llm") @callback def _async_get_apis(hass: HomeAssistant) -> dict[str, API]: """Get all the LLM APIs.""" return { LLM_API_ASSIST: AssistAPI(hass=hass), } @callback def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]: """Register an API to be exposed to LLMs.""" apis = _async_get_apis(hass) if api.id in apis: raise HomeAssistantError(f"API {api.id} is already registered") apis[api.id] = api @callback def unregister() -> None: """Unregister the API.""" apis.pop(api.id) return unregister async def async_get_api( hass: HomeAssistant, api_id: str, llm_context: LLMContext ) -> APIInstance: """Get an API.""" apis = _async_get_apis(hass) if api_id not in apis: raise HomeAssistantError(f"API {api_id} not found") return await apis[api_id].async_get_api_instance(llm_context) @callback def async_get_apis(hass: HomeAssistant) -> list[API]: """Get all the LLM APIs.""" return list(_async_get_apis(hass).values()) @dataclass(slots=True) class LLMContext: """Tool input to be processed.""" platform: str context: Context | None user_prompt: str | None language: str | None assistant: str | None device_id: str | None @dataclass(slots=True) class ToolInput: """Tool input to be processed.""" tool_name: str tool_args: dict[str, Any] # Using lambda for default to allow patching in tests id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda class Tool: """LLM Tool base class.""" name: str description: str | None = None parameters: vol.Schema = vol.Schema({}) @abstractmethod async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Call the tool.""" raise NotImplementedError def __repr__(self) -> str: """Represent a string of a Tool.""" return f"<{self.__class__.__name__} - {self.name}>" @dataclass class APIInstance: """Instance of an API to be used by an LLM.""" api: API api_prompt: str llm_context: LLMContext tools: list[Tool] custom_serializer: Callable[[Any], Any] | None = None async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" # pylint: disable=import-outside-toplevel from homeassistant.components.conversation import ( ConversationTraceEventType, async_conversation_trace_append, ) async_conversation_trace_append( ConversationTraceEventType.TOOL_CALL, {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args}, ) for tool in self.tools: if tool.name == tool_input.tool_name: break else: raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') return await tool.async_call(self.api.hass, tool_input, self.llm_context) @dataclass(slots=True, kw_only=True) class API(ABC): """An API to expose to LLMs.""" hass: HomeAssistant id: str name: str @abstractmethod async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: """Return the instance of the API.""" raise NotImplementedError class IntentTool(Tool): """LLM Tool representing an Intent.""" def __init__( self, name: str, intent_handler: intent.IntentHandler, ) -> None: """Init the class.""" self.name = name self.description = ( intent_handler.description or f"Execute Home Assistant {self.name} intent" ) self.extra_slots = None if not (slot_schema := intent_handler.slot_schema): return slot_schema = {**slot_schema} extra_slots = set() for field in ("preferred_area_id", "preferred_floor_id"): if field in slot_schema: extra_slots.add(field) del slot_schema[field] self.parameters = vol.Schema(slot_schema) if extra_slots: self.extra_slots = extra_slots async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Handle the intent.""" slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} if self.extra_slots and llm_context.device_id: device_reg = dr.async_get(hass) device = device_reg.async_get(llm_context.device_id) area: ar.AreaEntry | None = None floor: fr.FloorEntry | None = None if device: area_reg = ar.async_get(hass) if device.area_id and (area := area_reg.async_get_area(device.area_id)): if area.floor_id: floor_reg = fr.async_get(hass) floor = floor_reg.async_get_floor(area.floor_id) for slot_name, slot_value in ( ("preferred_area_id", area.id if area else None), ("preferred_floor_id", floor.floor_id if floor else None), ): if slot_value and slot_name in self.extra_slots: slots[slot_name] = {"value": slot_value} intent_response = await intent.async_handle( hass=hass, platform=llm_context.platform, intent_type=self.name, slots=slots, text_input=llm_context.user_prompt, context=llm_context.context, language=llm_context.language, assistant=llm_context.assistant, device_id=llm_context.device_id, ) response = intent_response.as_dict() del response["language"] del response["card"] return response class AssistAPI(API): """API exposing Assist API to LLMs.""" IGNORE_INTENTS = { intent.INTENT_GET_TEMPERATURE, INTENT_GET_WEATHER, INTENT_OPEN_COVER, # deprecated INTENT_CLOSE_COVER, # deprecated intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND, intent.INTENT_TOGGLE, intent.INTENT_GET_CURRENT_DATE, intent.INTENT_GET_CURRENT_TIME, intent.INTENT_RESPOND, } def __init__(self, hass: HomeAssistant) -> None: """Init the class.""" super().__init__( hass=hass, id=LLM_API_ASSIST, name="Assist", ) self.cached_slugify = cache( partial(unicode_slug.slugify, separator="_", lowercase=False) ) async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: """Return the instance of the API.""" if llm_context.assistant: exposed_entities: dict | None = _get_exposed_entities( self.hass, llm_context.assistant ) else: exposed_entities = None return APIInstance( api=self, api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), llm_context=llm_context, tools=self._async_get_tools(llm_context, exposed_entities), custom_serializer=_selector_serializer, ) @callback def _async_get_api_prompt( self, llm_context: LLMContext, exposed_entities: dict | None ) -> str: if not exposed_entities or not exposed_entities["entities"]: return ( "Only if the user wants to control a device, tell them to expose entities " "to their voice assistant in Home Assistant." ) return "\n".join( [ *self._async_get_preable(llm_context), *self._async_get_exposed_entities_prompt(llm_context, exposed_entities), ] ) @callback def _async_get_preable(self, llm_context: LLMContext) -> list[str]: """Return the prompt for the API.""" prompt = [ ( "When controlling Home Assistant always call the intent tools. " "Use HassTurnOn to lock and HassTurnOff to unlock a lock. " "When controlling a device, prefer passing just name and domain. " "When controlling an area, prefer passing just area name and domain." ) ] area: ar.AreaEntry | None = None floor: fr.FloorEntry | None = None if llm_context.device_id: device_reg = dr.async_get(self.hass) device = device_reg.async_get(llm_context.device_id) if device: area_reg = ar.async_get(self.hass) if device.area_id and (area := area_reg.async_get_area(device.area_id)): floor_reg = fr.async_get(self.hass) if area.floor_id: floor = floor_reg.async_get_floor(area.floor_id) extra = "and all generic commands like 'turn on the lights' should target this area." if floor and area: prompt.append(f"You are in area {area.name} (floor {floor.name}) {extra}") elif area: prompt.append(f"You are in area {area.name} {extra}") else: prompt.append( "When a user asks to turn on all devices of a specific type, " "ask user to specify an area, unless there is only one device of that type." ) if not llm_context.device_id or not async_device_supports_timers( self.hass, llm_context.device_id ): prompt.append("This device is not able to start timers.") return prompt @callback def _async_get_exposed_entities_prompt( self, llm_context: LLMContext, exposed_entities: dict | None ) -> list[str]: """Return the prompt for the API for exposed entities.""" prompt = [] if exposed_entities and exposed_entities["entities"]: prompt.append( "An overview of the areas and the devices in this smart home:" ) prompt.append(yaml_util.dump(list(exposed_entities["entities"].values()))) return prompt @callback def _async_get_tools( self, llm_context: LLMContext, exposed_entities: dict | None ) -> list[Tool]: """Return a list of LLM tools.""" ignore_intents = self.IGNORE_INTENTS if not llm_context.device_id or not async_device_supports_timers( self.hass, llm_context.device_id ): ignore_intents = ignore_intents | { intent.INTENT_START_TIMER, intent.INTENT_CANCEL_TIMER, intent.INTENT_INCREASE_TIMER, intent.INTENT_DECREASE_TIMER, intent.INTENT_PAUSE_TIMER, intent.INTENT_UNPAUSE_TIMER, intent.INTENT_TIMER_STATUS, } intent_handlers = [ intent_handler for intent_handler in intent.async_get(self.hass) if intent_handler.intent_type not in ignore_intents ] exposed_domains: set[str] | None = None if exposed_entities is not None: exposed_domains = { info["domain"] for info in exposed_entities["entities"].values() } intent_handlers = [ intent_handler for intent_handler in intent_handlers if intent_handler.platforms is None or intent_handler.platforms & exposed_domains ] tools: list[Tool] = [ IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler) for intent_handler in intent_handlers ] if exposed_entities: if exposed_entities[CALENDAR_DOMAIN]: names = [] for info in exposed_entities[CALENDAR_DOMAIN].values(): names.extend(info["names"].split(", ")) tools.append(CalendarGetEventsTool(names)) tools.extend( ScriptTool(self.hass, script_entity_id) for script_entity_id in exposed_entities[SCRIPT_DOMAIN] ) return tools def _get_exposed_entities( hass: HomeAssistant, assistant: str ) -> dict[str, dict[str, dict[str, Any]]]: """Get exposed entities. Splits out calendars and scripts. """ area_registry = ar.async_get(hass) entity_registry = er.async_get(hass) device_registry = dr.async_get(hass) interesting_attributes = { "temperature", "current_temperature", "temperature_unit", "brightness", "humidity", "unit_of_measurement", "device_class", "current_position", "percentage", "volume_level", "media_title", "media_artist", "media_album_name", } entities = {} data: dict[str, dict[str, Any]] = { SCRIPT_DOMAIN: {}, CALENDAR_DOMAIN: {}, } for state in hass.states.async_all(): if not async_should_expose(hass, assistant, state.entity_id): continue description: str | None = None entity_entry = entity_registry.async_get(state.entity_id) names = [state.name] area_names = [] if entity_entry is not None: names.extend(entity_entry.aliases) if entity_entry.area_id and ( area := area_registry.async_get_area(entity_entry.area_id) ): # Entity is in area area_names.append(area.name) area_names.extend(area.aliases) elif entity_entry.device_id and ( device := device_registry.async_get(entity_entry.device_id) ): # Check device area if device.area_id and ( area := area_registry.async_get_area(device.area_id) ): area_names.append(area.name) area_names.extend(area.aliases) info: dict[str, Any] = { "names": ", ".join(names), "domain": state.domain, "state": state.state, } if description: info["description"] = description if area_names: info["areas"] = ", ".join(area_names) if attributes := { attr_name: ( str(attr_value) if isinstance(attr_value, (Enum, Decimal, int)) else attr_value ) for attr_name, attr_value in state.attributes.items() if attr_name in interesting_attributes }: info["attributes"] = attributes if state.domain in data: data[state.domain][state.entity_id] = info else: entities[state.entity_id] = info data["entities"] = entities return data def _selector_serializer(schema: Any) -> Any: # noqa: C901 """Convert selectors into OpenAPI schema.""" if not isinstance(schema, selector.Selector): return UNSUPPORTED if isinstance(schema, selector.BackupLocationSelector): return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"} if isinstance(schema, selector.BooleanSelector): return {"type": "boolean"} if isinstance(schema, selector.ColorRGBSelector): return { "type": "array", "items": {"type": "number"}, "minItems": 3, "maxItems": 3, "format": "RGB", } if isinstance(schema, selector.ConditionSelector): return convert(cv.CONDITIONS_SCHEMA) if isinstance(schema, selector.ConstantSelector): return convert(vol.Schema(schema.config["value"])) result: dict[str, Any] if isinstance(schema, selector.ColorTempSelector): result = {"type": "number"} if "min" in schema.config: result["minimum"] = schema.config["min"] elif "min_mireds" in schema.config: result["minimum"] = schema.config["min_mireds"] if "max" in schema.config: result["maximum"] = schema.config["max"] elif "max_mireds" in schema.config: result["maximum"] = schema.config["max_mireds"] return result if isinstance(schema, selector.CountrySelector): if schema.config.get("countries"): return {"type": "string", "enum": schema.config["countries"]} return {"type": "string", "format": "ISO 3166-1 alpha-2"} if isinstance(schema, selector.DateSelector): return {"type": "string", "format": "date"} if isinstance(schema, selector.DateTimeSelector): return {"type": "string", "format": "date-time"} if isinstance(schema, selector.DurationSelector): return convert(cv.time_period_dict) if isinstance(schema, selector.EntitySelector): if schema.config.get("multiple"): return {"type": "array", "items": {"type": "string", "format": "entity_id"}} return {"type": "string", "format": "entity_id"} if isinstance(schema, selector.LanguageSelector): if schema.config.get("languages"): return {"type": "string", "enum": schema.config["languages"]} return {"type": "string", "format": "RFC 5646"} if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)): return convert(schema.DATA_SCHEMA) if isinstance(schema, selector.NumberSelector): result = {"type": "number"} if "min" in schema.config: result["minimum"] = schema.config["min"] if "max" in schema.config: result["maximum"] = schema.config["max"] return result if isinstance(schema, selector.ObjectSelector): return {"type": "object", "additionalProperties": True} if isinstance(schema, selector.SelectSelector): options = [ x["value"] if isinstance(x, dict) else x for x in schema.config["options"] ] if schema.config.get("multiple"): return { "type": "array", "items": {"type": "string", "enum": options}, "uniqueItems": True, } return {"type": "string", "enum": options} if isinstance(schema, selector.TargetSelector): return convert(cv.TARGET_SERVICE_FIELDS) if isinstance(schema, selector.TemplateSelector): return {"type": "string", "format": "jinja2"} if isinstance(schema, selector.TimeSelector): return {"type": "string", "format": "time"} if isinstance(schema, selector.TriggerSelector): return {"type": "array", "items": {"type": "string"}} if schema.config.get("multiple"): return {"type": "array", "items": {"type": "string"}} return {"type": "string"} def _get_cached_action_parameters( hass: HomeAssistant, domain: str, action: str ) -> tuple[str | None, vol.Schema]: """Get action description and schema.""" description = None parameters = vol.Schema({}) parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE) if parameters_cache is None: parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {} @callback def clear_cache(event: Event) -> None: """Clear action parameter cache on action removal.""" if ( event.data[ATTR_DOMAIN] in parameters_cache and event.data[ATTR_SERVICE] in parameters_cache[event.data[ATTR_DOMAIN]] ): parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE]) cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) @callback def on_homeassistant_close(event: Event) -> None: """Cleanup.""" cancel() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close) if domain in parameters_cache and action in parameters_cache[domain]: return parameters_cache[domain][action] if action_desc := service.async_get_cached_service_description( hass, domain, action ): description = action_desc.get("description") schema: dict[vol.Marker, Any] = {} fields = action_desc.get("fields", {}) for field, config in fields.items(): field_description = config.get("description") if not field_description: field_description = config.get("name") key: vol.Marker if config.get("required"): key = vol.Required(field, description=field_description) else: key = vol.Optional(field, description=field_description) if "selector" in config: schema[key] = selector.selector(config["selector"]) else: schema[key] = cv.string parameters = vol.Schema(schema) if domain == SCRIPT_DOMAIN: entity_registry = er.async_get(hass) if ( entity_id := entity_registry.async_get_entity_id(domain, domain, action) ) and (entity_entry := entity_registry.async_get(entity_id)): aliases: list[str] = [] if entity_entry.name: aliases.append(entity_entry.name) if entity_entry.aliases: aliases.extend(entity_entry.aliases) if aliases: if description: description = description + ". Aliases: " + str(list(aliases)) else: description = "Aliases: " + str(list(aliases)) parameters_cache.setdefault(domain, {})[action] = (description, parameters) return description, parameters class ActionTool(Tool): """LLM Tool representing an action.""" def __init__( self, hass: HomeAssistant, domain: str, action: str, ) -> None: """Init the class.""" self._domain = domain self._action = action self.name = f"{domain}.{action}" self.description, self.parameters = _get_cached_action_parameters( hass, domain, action ) async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Call the action.""" for field, validator in self.parameters.schema.items(): if field not in tool_input.tool_args: continue if isinstance(validator, selector.AreaSelector): area_reg = ar.async_get(hass) if validator.config.get("multiple"): areas: list[ar.AreaEntry] = [] for area in tool_input.tool_args[field]: areas.extend(intent.find_areas(area, area_reg)) tool_input.tool_args[field] = list({area.id for area in areas}) else: area = tool_input.tool_args[field] area = list(intent.find_areas(area, area_reg))[0].id tool_input.tool_args[field] = area elif isinstance(validator, selector.FloorSelector): floor_reg = fr.async_get(hass) if validator.config.get("multiple"): floors: list[fr.FloorEntry] = [] for floor in tool_input.tool_args[field]: floors.extend(intent.find_floors(floor, floor_reg)) tool_input.tool_args[field] = list( {floor.floor_id for floor in floors} ) else: floor = tool_input.tool_args[field] floor = list(intent.find_floors(floor, floor_reg))[0].floor_id tool_input.tool_args[field] = floor result = await hass.services.async_call( self._domain, self._action, tool_input.tool_args, context=llm_context.context, blocking=True, return_response=True, ) return {"success": True, "result": result} class ScriptTool(ActionTool): """LLM Tool representing a Script.""" def __init__( self, hass: HomeAssistant, script_entity_id: str, ) -> None: """Init the class.""" script_name = split_entity_id(script_entity_id)[1] action = script_name entity_registry = er.async_get(hass) entity_entry = entity_registry.async_get(script_entity_id) if entity_entry and entity_entry.unique_id: action = entity_entry.unique_id super().__init__(hass, SCRIPT_DOMAIN, action) self.name = script_name if self.name[0].isdigit(): self.name = "_" + self.name class CalendarGetEventsTool(Tool): """LLM Tool allowing querying a calendar.""" name = "calendar_get_events" description = ( "Get events from a calendar. " "When asked if something happens, search the whole week. " "Results are RFC 5545 which means 'end' is exclusive." ) def __init__(self, calendars: list[str]) -> None: """Init the get events tool.""" self.parameters = vol.Schema( { vol.Required("calendar"): vol.In(calendars), vol.Required("range"): vol.In(["today", "week"]), } ) async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Query a calendar.""" data = self.parameters(tool_input.tool_args) result = intent.async_match_targets( hass, intent.MatchTargetsConstraints( name=data["calendar"], domains=[CALENDAR_DOMAIN], assistant=llm_context.assistant, ), ) if not result.is_match: return {"success": False, "error": "Calendar not found"} entity_id = result.states[0].entity_id if data["range"] == "today": start = dt_util.now() end = dt_util.start_of_local_day() + timedelta(days=1) elif data["range"] == "week": start = dt_util.now() end = dt_util.start_of_local_day() + timedelta(days=7) service_data = { "entity_id": entity_id, "start_date_time": start.isoformat(), "end_date_time": end.isoformat(), } service_result = await hass.services.async_call( CALENDAR_DOMAIN, SERVICE_GET_EVENTS, service_data, context=llm_context.context, blocking=True, return_response=True, ) events = [ event if "T" in event["start"] else {**event, "all_day": True} for event in cast(dict, service_result)[entity_id]["events"] ] return {"success": True, "result": events}