"""Support for functionality to have conversations with Home Assistant.""" from __future__ import annotations import asyncio from collections.abc import Iterable from dataclasses import dataclass import logging import re from typing import Any, Literal from hassil.recognize import RecognizeResult import voluptuous as vol from homeassistant import core from homeassistant.components import http, websocket_api from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, intent, singleton from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.util import language as language_util from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .const import HOME_ASSISTANT_AGENT from .default_agent import DefaultAgent, async_setup as async_setup_default_agent __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", "async_converse", "async_get_agent_info", "async_set_agent", "async_unset_agent", "async_setup", ] _LOGGER = logging.getLogger(__name__) ATTR_TEXT = "text" ATTR_LANGUAGE = "language" ATTR_AGENT_ID = "agent_id" DOMAIN = "conversation" REGEX_TYPE = type(re.compile("")) DATA_CONFIG = "conversation_config" SERVICE_PROCESS = "process" SERVICE_RELOAD = "reload" def agent_id_validator(value: Any) -> str: """Validate agent ID.""" hass = core.async_get_hass() manager = _get_agent_manager(hass) if not manager.async_is_valid_agent_id(cv.string(value)): raise vol.Invalid("invalid agent ID") return value SERVICE_PROCESS_SCHEMA = vol.Schema( { vol.Required(ATTR_TEXT): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string, vol.Optional(ATTR_AGENT_ID): agent_id_validator, } ) SERVICE_RELOAD_SCHEMA = vol.Schema( { vol.Optional(ATTR_LANGUAGE): cv.string, vol.Optional(ATTR_AGENT_ID): agent_id_validator, } ) CONFIG_SCHEMA = vol.Schema( { vol.Optional(DOMAIN): vol.Schema( { vol.Optional("intents"): vol.Schema( {cv.string: vol.All(cv.ensure_list, [cv.string])} ) } ) }, extra=vol.ALLOW_EXTRA, ) @singleton.singleton("conversation_agent") @core.callback def _get_agent_manager(hass: HomeAssistant) -> AgentManager: """Get the active agent.""" manager = AgentManager(hass) manager.async_setup() return manager @core.callback @bind_hass def async_set_agent( hass: core.HomeAssistant, config_entry: ConfigEntry, agent: AbstractConversationAgent, ): """Set the agent to handle the conversations.""" _get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent) @core.callback @bind_hass def async_unset_agent( hass: core.HomeAssistant, config_entry: ConfigEntry, ): """Set the agent to handle the conversations.""" _get_agent_manager(hass).async_unset_agent(config_entry.entry_id) async def async_get_conversation_languages( hass: HomeAssistant, agent_id: str | None = None ) -> set[str] | Literal["*"]: """Return languages supported by conversation agents. If an agent is specified, returns a set of languages supported by that agent. If no agent is specified, return a set with the union of languages supported by all conversation agents. """ agent_manager = _get_agent_manager(hass) languages = set() agent_ids: Iterable[str] if agent_id is None: agent_ids = iter(info.id for info in agent_manager.async_get_agent_info()) else: agent_ids = (agent_id,) for _agent_id in agent_ids: agent = await agent_manager.async_get_agent(_agent_id) if agent.supported_languages == MATCH_ALL: return MATCH_ALL for language_tag in agent.supported_languages: languages.add(language_tag) return languages async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" agent_manager = _get_agent_manager(hass) if config_intents := config.get(DOMAIN, {}).get("intents"): hass.data[DATA_CONFIG] = config_intents async def handle_process(service: core.ServiceCall) -> core.ServiceResponse: """Parse text into commands.""" text = service.data[ATTR_TEXT] _LOGGER.debug("Processing: <%s>", text) try: result = await async_converse( hass=hass, text=text, conversation_id=None, context=service.context, language=service.data.get(ATTR_LANGUAGE), agent_id=service.data.get(ATTR_AGENT_ID), ) except intent.IntentHandleError as err: raise HomeAssistantError(f"Error processing {text}: {err}") from err if service.return_response: return result.as_dict() return None async def handle_reload(service: core.ServiceCall) -> None: """Reload intents.""" agent = await agent_manager.async_get_agent() await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) hass.services.async_register( DOMAIN, SERVICE_PROCESS, handle_process, schema=SERVICE_PROCESS_SCHEMA, supports_response=core.SupportsResponse.OPTIONAL, ) hass.services.async_register( DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA ) hass.http.register_view(ConversationProcessView()) websocket_api.async_register_command(hass, websocket_process) websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_list_agents) websocket_api.async_register_command(hass, websocket_hass_agent_debug) return True @websocket_api.websocket_command( { vol.Required("type"): "conversation/process", vol.Required("text"): str, vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("language"): str, vol.Optional("agent_id"): agent_id_validator, } ) @websocket_api.async_response async def websocket_process( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any], ) -> None: """Process text.""" result = await async_converse( hass=hass, text=msg["text"], conversation_id=msg.get("conversation_id"), context=connection.context(msg), language=msg.get("language"), agent_id=msg.get("agent_id"), ) connection.send_result(msg["id"], result.as_dict()) @websocket_api.websocket_command( { "type": "conversation/prepare", vol.Optional("language"): str, vol.Optional("agent_id"): agent_id_validator, } ) @websocket_api.async_response async def websocket_prepare( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any], ) -> None: """Reload intents.""" manager = _get_agent_manager(hass) agent = await manager.async_get_agent(msg.get("agent_id")) await agent.async_prepare(msg.get("language")) connection.send_result(msg["id"]) @websocket_api.websocket_command( { vol.Required("type"): "conversation/agent/list", vol.Optional("language"): str, vol.Optional("country"): str, } ) @websocket_api.async_response async def websocket_list_agents( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """List conversation agents and, optionally, if they support a given language.""" manager = _get_agent_manager(hass) country = msg.get("country") language = msg.get("language") agents = [] for agent_info in manager.async_get_agent_info(): agent = await manager.async_get_agent(agent_info.id) supported_languages = agent.supported_languages if language and supported_languages != MATCH_ALL: supported_languages = language_util.matches( language, supported_languages, country ) agent_dict: dict[str, Any] = { "id": agent_info.id, "name": agent_info.name, "supported_languages": supported_languages, } agents.append(agent_dict) connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents})) @websocket_api.websocket_command( { vol.Required("type"): "conversation/agent/homeassistant/debug", vol.Required("sentences"): [str], vol.Optional("language"): str, vol.Optional("device_id"): vol.Any(str, None), } ) @websocket_api.async_response async def websocket_hass_agent_debug( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Return intents that would be matched by the default agent for a list of sentences.""" agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) assert isinstance(agent, DefaultAgent) results = [ await agent.async_recognize( ConversationInput( text=sentence, context=connection.context(msg), conversation_id=None, device_id=msg.get("device_id"), language=msg.get("language", hass.config.language), ) ) for sentence in msg["sentences"] ] # Return results for each sentence in the same order as the input. connection.send_result( msg["id"], { "results": [ { "intent": { "name": result.intent.name, }, "slots": { # direct access to values entity_key: entity.value for entity_key, entity in result.entities.items() }, "details": { entity_key: { "name": entity.name, "value": entity.value, "text": entity.text, } for entity_key, entity in result.entities.items() }, "targets": { state.entity_id: {"matched": is_matched} for state, is_matched in _get_debug_targets(hass, result) }, } if result is not None else None for result in results ] }, ) def _get_debug_targets( hass: HomeAssistant, result: RecognizeResult, ) -> Iterable[tuple[core.State, bool]]: """Yield state/is_matched pairs for a hassil recognition.""" entities = result.entities name: str | None = None area_name: str | None = None domains: set[str] | None = None device_classes: set[str] | None = None state_names: set[str] | None = None if "name" in entities: name = str(entities["name"].value) if "area" in entities: area_name = str(entities["area"].value) if "domain" in entities: domains = set(cv.ensure_list(entities["domain"].value)) if "device_class" in entities: device_classes = set(cv.ensure_list(entities["device_class"].value)) if "state" in entities: # HassGetState only state_names = set(cv.ensure_list(entities["state"].value)) states = intent.async_match_states( hass, name=name, area_name=area_name, domains=domains, device_classes=device_classes, ) for state in states: # For queries, a target is "matched" based on its state is_matched = (state_names is None) or (state.state in state_names) yield state, is_matched class ConversationProcessView(http.HomeAssistantView): """View to process text.""" url = "/api/conversation/process" name = "api:conversation:process" @RequestDataValidator( vol.Schema( { vol.Required("text"): str, vol.Optional("conversation_id"): str, vol.Optional("language"): str, vol.Optional("agent_id"): agent_id_validator, } ) ) async def post(self, request, data): """Send a request for processing.""" hass = request.app["hass"] result = await async_converse( hass, text=data["text"], conversation_id=data.get("conversation_id"), context=self.context(request), language=data.get("language"), agent_id=data.get("agent_id"), ) return self.json(result.as_dict()) @dataclass(frozen=True) class AgentInfo: """Container for conversation agent info.""" id: str name: str @core.callback def async_get_agent_info( hass: core.HomeAssistant, agent_id: str | None = None, ) -> AgentInfo | None: """Get information on the agent or None if not found.""" manager = _get_agent_manager(hass) if agent_id is None: agent_id = manager.default_agent for agent_info in manager.async_get_agent_info(): if agent_info.id == agent_id: return agent_info return None async def async_converse( hass: core.HomeAssistant, text: str, conversation_id: str | None, context: core.Context, language: str | None = None, agent_id: str | None = None, device_id: str | None = None, ) -> ConversationResult: """Process text and get intent.""" agent = await _get_agent_manager(hass).async_get_agent(agent_id) if language is None: language = hass.config.language _LOGGER.debug("Processing in %s: %s", language, text) result = await agent.async_process( ConversationInput( text=text, context=context, conversation_id=conversation_id, device_id=device_id, language=language, ) ) return result class AgentManager: """Class to manage conversation agents.""" default_agent: str = HOME_ASSISTANT_AGENT _builtin_agent: AbstractConversationAgent | None = None def __init__(self, hass: HomeAssistant) -> None: """Initialize the conversation agents.""" self.hass = hass self._agents: dict[str, AbstractConversationAgent] = {} self._builtin_agent_init_lock = asyncio.Lock() def async_setup(self) -> None: """Set up the conversation agents.""" async_setup_default_agent(self.hass) async def async_get_agent( self, agent_id: str | None = None ) -> AbstractConversationAgent: """Get the agent.""" if agent_id is None: agent_id = self.default_agent if agent_id == HOME_ASSISTANT_AGENT: if self._builtin_agent is not None: return self._builtin_agent async with self._builtin_agent_init_lock: if self._builtin_agent is not None: return self._builtin_agent self._builtin_agent = DefaultAgent(self.hass) await self._builtin_agent.async_initialize( self.hass.data.get(DATA_CONFIG) ) return self._builtin_agent if agent_id not in self._agents: raise ValueError(f"Agent {agent_id} not found") return self._agents[agent_id] @core.callback def async_get_agent_info(self) -> list[AgentInfo]: """List all agents.""" agents: list[AgentInfo] = [ AgentInfo( id=HOME_ASSISTANT_AGENT, name="Home Assistant", ) ] for agent_id, agent in self._agents.items(): config_entry = self.hass.config_entries.async_get_entry(agent_id) # Guard against potential bugs in conversation agents where the agent is not # removed from the manager when the config entry is removed if config_entry is None: _LOGGER.warning( "Conversation agent %s is still loaded after config entry removal", agent, ) continue agents.append( AgentInfo( id=agent_id, name=config_entry.title or config_entry.domain, ) ) return agents @core.callback def async_is_valid_agent_id(self, agent_id: str) -> bool: """Check if the agent id is valid.""" return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT @core.callback def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: """Set the agent.""" self._agents[agent_id] = agent @core.callback def async_unset_agent(self, agent_id: str) -> None: """Unset the agent.""" self._agents.pop(agent_id, None)