core/homeassistant/components/conversation/__init__.py

565 lines
17 KiB
Python

"""Support for functionality to have conversations with Home Assistant."""
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
import logging
import re
from typing import Any, Literal
from hassil.recognize import RecognizeResult
import voluptuous as vol
from homeassistant import core
from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util import language as language_util
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
# Names exported by ``from homeassistant.components.conversation import *``.
__all__ = [
    "DOMAIN",
    "HOME_ASSISTANT_AGENT",
    "async_converse",
    "async_get_agent_info",
    "async_set_agent",
    "async_unset_agent",
    "async_setup",
]
# Module-level logger.
_LOGGER = logging.getLogger(__name__)

DOMAIN = "conversation"

# Keys used in service call / request payloads.
ATTR_TEXT = "text"
ATTR_LANGUAGE = "language"
ATTR_AGENT_ID = "agent_id"

# Runtime type of a compiled regular expression.
REGEX_TYPE = type(re.compile(""))

# ``hass.data`` key holding the intents configured in YAML.
DATA_CONFIG = "conversation_config"

# Service names registered by ``async_setup``.
SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload"
def agent_id_validator(value: Any) -> str:
    """Validate that *value* names a known conversation agent.

    The value is coerced with ``cv.string`` and the coerced string is checked
    against the agent manager. The coerced string (not the raw input) is
    returned, so downstream lookups use exactly the ID that was validated and
    the result always matches the ``str`` return annotation.

    Raises:
        vol.Invalid: If no agent with that ID is known (``cv.string`` also
            raises ``vol.Invalid`` for values it cannot coerce).
    """
    hass = core.async_get_hass()
    manager = _get_agent_manager(hass)
    agent_id = cv.string(value)
    if not manager.async_is_valid_agent_id(agent_id):
        raise vol.Invalid("invalid agent ID")
    return agent_id
# Optional fields accepted by both conversation services: the language to use
# and the agent that should handle the call.
_AGENT_SELECTION_FIELDS = {
    vol.Optional(ATTR_LANGUAGE): cv.string,
    vol.Optional(ATTR_AGENT_ID): agent_id_validator,
}

# ``conversation.process``: required text plus the shared selection fields.
SERVICE_PROCESS_SCHEMA = vol.Schema(
    {vol.Required(ATTR_TEXT): cv.string, **_AGENT_SELECTION_FIELDS}
)

# ``conversation.reload``: only the shared selection fields.
SERVICE_RELOAD_SCHEMA = vol.Schema(dict(_AGENT_SELECTION_FIELDS))

# YAML ``intents:`` mapping of intent name -> one or more sentence strings.
_INTENTS_CONFIG_SCHEMA = vol.Schema(
    {cv.string: vol.All(cv.ensure_list, [cv.string])}
)

CONFIG_SCHEMA = vol.Schema(
    {
        vol.Optional(DOMAIN): vol.Schema(
            {vol.Optional("intents"): _INTENTS_CONFIG_SCHEMA}
        )
    },
    extra=vol.ALLOW_EXTRA,
)
@singleton.singleton("conversation_agent")
@core.callback
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
    """Return the conversation AgentManager for this Home Assistant instance.

    The manager is constructed and its ``async_setup`` run on first use; the
    ``singleton`` decorator (keyed on ``"conversation_agent"``) returns that
    same instance on subsequent calls.
    """
    agent_manager = AgentManager(hass)
    agent_manager.async_setup()
    return agent_manager
@core.callback
@bind_hass
def async_set_agent(
    hass: core.HomeAssistant,
    config_entry: ConfigEntry,
    agent: AbstractConversationAgent,
):
    """Register *agent* as the conversation agent for *config_entry*.

    The agent is stored under the config entry's ID; registering again for the
    same entry replaces the previous agent.
    """
    manager = _get_agent_manager(hass)
    manager.async_set_agent(config_entry.entry_id, agent)
@core.callback
@bind_hass
def async_unset_agent(
    hass: core.HomeAssistant,
    config_entry: ConfigEntry,
) -> None:
    """Unregister the conversation agent provided by *config_entry*.

    Does nothing if no agent is registered under the config entry's ID.
    """
    _get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
async def async_get_conversation_languages(
    hass: HomeAssistant, agent_id: str | None = None
) -> set[str] | Literal["*"]:
    """Return the languages supported by conversation agents.

    With *agent_id*, only that agent is consulted; otherwise the union over all
    listed agents is returned. As soon as any consulted agent reports
    ``MATCH_ALL`` ("*"), ``MATCH_ALL`` is returned instead of a set.
    """
    agent_manager = _get_agent_manager(hass)

    agent_ids: Iterable[str]
    if agent_id is not None:
        agent_ids = (agent_id,)
    else:
        agent_ids = [info.id for info in agent_manager.async_get_agent_info()]

    languages: set[str] = set()
    for candidate_id in agent_ids:
        agent = await agent_manager.async_get_agent(candidate_id)
        if agent.supported_languages == MATCH_ALL:
            return MATCH_ALL
        languages.update(agent.supported_languages)
    return languages
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the conversation integration.

    Stores YAML-configured intents in ``hass.data[DATA_CONFIG]``, registers the
    ``process`` and ``reload`` services, the HTTP view and the websocket
    commands.
    """
    agent_manager = _get_agent_manager(hass)

    if config_intents := config.get(DOMAIN, {}).get("intents"):
        hass.data[DATA_CONFIG] = config_intents

    async def handle_process(service: core.ServiceCall) -> core.ServiceResponse:
        """Run the service's text through a conversation agent."""
        text = service.data[ATTR_TEXT]
        _LOGGER.debug("Processing: <%s>", text)
        try:
            result = await async_converse(
                hass=hass,
                text=text,
                conversation_id=None,
                context=service.context,
                language=service.data.get(ATTR_LANGUAGE),
                agent_id=service.data.get(ATTR_AGENT_ID),
            )
        except intent.IntentHandleError as err:
            raise HomeAssistantError(f"Error processing {text}: {err}") from err

        if service.return_response:
            return result.as_dict()

        return None

    async def handle_reload(service: core.ServiceCall) -> None:
        """Reload intents of the requested agent (the default agent if none)."""
        # SERVICE_RELOAD_SCHEMA accepts and validates agent_id, so honor it
        # instead of always reloading the default agent. None -> default agent.
        agent = await agent_manager.async_get_agent(service.data.get(ATTR_AGENT_ID))
        await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))

    hass.services.async_register(
        DOMAIN,
        SERVICE_PROCESS,
        handle_process,
        schema=SERVICE_PROCESS_SCHEMA,
        supports_response=core.SupportsResponse.OPTIONAL,
    )
    hass.services.async_register(
        DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA
    )
    hass.http.register_view(ConversationProcessView())
    websocket_api.async_register_command(hass, websocket_process)
    websocket_api.async_register_command(hass, websocket_prepare)
    websocket_api.async_register_command(hass, websocket_list_agents)
    websocket_api.async_register_command(hass, websocket_hass_agent_debug)

    return True
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/process",
        vol.Required("text"): str,
        vol.Optional("conversation_id"): vol.Any(str, None),
        vol.Optional("language"): str,
        vol.Optional("agent_id"): agent_id_validator,
    }
)
@websocket_api.async_response
async def websocket_process(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Run the text of a websocket message through a conversation agent.

    Replies with the agent's result serialized via ``as_dict()``.
    """
    conversation_result = await async_converse(
        hass=hass,
        text=msg["text"],
        conversation_id=msg.get("conversation_id"),
        context=connection.context(msg),
        language=msg.get("language"),
        agent_id=msg.get("agent_id"),
    )
    payload = conversation_result.as_dict()
    connection.send_result(msg["id"], payload)
@websocket_api.websocket_command(
    {
        # Use vol.Required like every other command in this module.
        vol.Required("type"): "conversation/prepare",
        vol.Optional("language"): str,
        vol.Optional("agent_id"): agent_id_validator,
    }
)
@websocket_api.async_response
async def websocket_prepare(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Ask a conversation agent to prepare for a language.

    Calls ``async_prepare`` on the requested agent (the default agent when no
    ``agent_id`` is given) and sends an empty result once it has finished.
    """
    manager = _get_agent_manager(hass)
    agent = await manager.async_get_agent(msg.get("agent_id"))
    await agent.async_prepare(msg.get("language"))
    connection.send_result(msg["id"])
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/agent/list",
        vol.Optional("language"): str,
        vol.Optional("country"): str,
    }
)
@websocket_api.async_response
async def websocket_list_agents(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List conversation agents and, optionally, which languages they support.

    When ``language`` (and optionally ``country``) is given, each agent's
    ``supported_languages`` is narrowed with ``language_util.matches``; agents
    that report ``MATCH_ALL`` are returned unchanged.
    """
    manager = _get_agent_manager(hass)
    country = msg.get("country")
    language = msg.get("language")

    agents: list[dict[str, Any]] = []
    for agent_info in manager.async_get_agent_info():
        agent = await manager.async_get_agent(agent_info.id)
        supported_languages = agent.supported_languages
        if language and supported_languages != MATCH_ALL:
            supported_languages = language_util.matches(
                language, supported_languages, country
            )
        agents.append(
            {
                "id": agent_info.id,
                "name": agent_info.name,
                "supported_languages": supported_languages,
            }
        )

    # Same reply helper as the other handlers in this module.
    connection.send_result(msg["id"], {"agents": agents})
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/agent/homeassistant/debug",
        vol.Required("sentences"): [str],
        vol.Optional("language"): str,
        vol.Optional("device_id"): vol.Any(str, None),
    }
)
@websocket_api.async_response
async def websocket_hass_agent_debug(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Report what the built-in agent would recognize for each given sentence."""
    agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
    assert isinstance(agent, DefaultAgent)

    recognized: list[RecognizeResult | None] = []
    for sentence in msg["sentences"]:
        user_input = ConversationInput(
            text=sentence,
            context=connection.context(msg),
            conversation_id=None,
            device_id=msg.get("device_id"),
            language=msg.get("language", hass.config.language),
        )
        recognized.append(await agent.async_recognize(user_input))

    # One entry per input sentence, in input order; None where the agent
    # returned no recognition.
    connection.send_result(
        msg["id"],
        {
            "results": [
                None if result is None else _debug_result_to_dict(hass, result)
                for result in recognized
            ]
        },
    )


def _debug_result_to_dict(
    hass: HomeAssistant, result: RecognizeResult
) -> dict[str, Any]:
    """Serialize one hassil recognition for the debug websocket command."""
    entities = result.entities
    return {
        "intent": {"name": result.intent.name},
        # Direct access to the slot values.
        "slots": {key: entity.value for key, entity in entities.items()},
        "details": {
            key: {"name": entity.name, "value": entity.value, "text": entity.text}
            for key, entity in entities.items()
        },
        "targets": {
            state.entity_id: {"matched": is_matched}
            for state, is_matched in _get_debug_targets(hass, result)
        },
    }
def _get_debug_targets(
    hass: HomeAssistant,
    result: RecognizeResult,
) -> Iterable[tuple[core.State, bool]]:
    """Yield ``(state, is_matched)`` for every state the recognition targets.

    Candidate states come from ``intent.async_match_states`` filtered by the
    recognized name/area/domain/device_class slots.
    """
    entities = result.entities

    def _slot_text(key: str) -> str | None:
        return str(entities[key].value) if key in entities else None

    def _slot_set(key: str) -> set[str] | None:
        if key not in entities:
            return None
        return set(cv.ensure_list(entities[key].value))

    name = _slot_text("name")
    area_name = _slot_text("area")
    domains = _slot_set("domain")
    device_classes = _slot_set("device_class")
    # The "state" slot is only used by HassGetState.
    state_names = _slot_set("state")

    candidates = intent.async_match_states(
        hass,
        name=name,
        area_name=area_name,
        domains=domains,
        device_classes=device_classes,
    )
    for state in candidates:
        # Without a requested state every candidate is matched; otherwise only
        # those currently in one of the requested states.
        yield state, state_names is None or state.state in state_names
class ConversationProcessView(http.HomeAssistantView):
    """HTTP endpoint that runs posted text through a conversation agent."""

    url = "/api/conversation/process"
    name = "api:conversation:process"

    @RequestDataValidator(
        vol.Schema(
            {
                vol.Required("text"): str,
                vol.Optional("conversation_id"): str,
                vol.Optional("language"): str,
                vol.Optional("agent_id"): agent_id_validator,
            }
        )
    )
    async def post(self, request, data):
        """Process the posted text and reply with the result as JSON."""
        conversation_result = await async_converse(
            request.app["hass"],
            text=data["text"],
            conversation_id=data.get("conversation_id"),
            context=self.context(request),
            language=data.get("language"),
            agent_id=data.get("agent_id"),
        )
        body = conversation_result.as_dict()
        return self.json(body)
@dataclass(frozen=True)
class AgentInfo:
    """Immutable (hashable) description of an available conversation agent."""

    # Agent ID: HOME_ASSISTANT_AGENT for the built-in agent, otherwise the
    # ID of the config entry that registered the agent.
    id: str
    # Display name: "Home Assistant" for the built-in agent, otherwise the
    # config entry title (falling back to its domain).
    name: str
@core.callback
def async_get_agent_info(
    hass: core.HomeAssistant,
    agent_id: str | None = None,
) -> AgentInfo | None:
    """Return info for *agent_id* (the default agent when ``None``).

    Returns ``None`` when the manager does not list an agent with that ID.
    """
    manager = _get_agent_manager(hass)
    wanted_id = manager.default_agent if agent_id is None else agent_id
    return next(
        (info for info in manager.async_get_agent_info() if info.id == wanted_id),
        None,
    )
async def async_converse(
    hass: core.HomeAssistant,
    text: str,
    conversation_id: str | None,
    context: core.Context,
    language: str | None = None,
    agent_id: str | None = None,
    device_id: str | None = None,
) -> ConversationResult:
    """Process *text* with a conversation agent and return its result.

    *agent_id* selects the agent (the manager's default when ``None``);
    *language* falls back to ``hass.config.language`` when ``None``.
    """
    manager = _get_agent_manager(hass)
    agent = await manager.async_get_agent(agent_id)
    language = hass.config.language if language is None else language
    _LOGGER.debug("Processing in %s: %s", language, text)
    conversation_input = ConversationInput(
        text=text,
        context=context,
        conversation_id=conversation_id,
        device_id=device_id,
        language=language,
    )
    return await agent.async_process(conversation_input)
class AgentManager:
    """Class to manage conversation agents."""

    # Agent ID used when callers do not ask for a specific agent.
    default_agent: str = HOME_ASSISTANT_AGENT
    # Lazily created built-in agent; only assigned once fully initialized.
    _builtin_agent: AbstractConversationAgent | None = None

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the conversation agents."""
        self.hass = hass
        # Registered agents keyed by config entry ID.
        self._agents: dict[str, AbstractConversationAgent] = {}
        # Serializes creation of the built-in agent.
        self._builtin_agent_init_lock = asyncio.Lock()

    def async_setup(self) -> None:
        """Set up the conversation agents."""
        async_setup_default_agent(self.hass)

    async def async_get_agent(
        self, agent_id: str | None = None
    ) -> AbstractConversationAgent:
        """Return the agent for *agent_id* (the default agent when ``None``).

        The built-in agent is created and initialized on first request.

        Raises:
            ValueError: If *agent_id* is neither the built-in agent nor a
                registered agent.
        """
        if agent_id is None:
            agent_id = self.default_agent

        if agent_id == HOME_ASSISTANT_AGENT:
            # Lock-free fast path: only safe because _builtin_agent is
            # assigned after initialization has completed (see below).
            if self._builtin_agent is not None:
                return self._builtin_agent

            async with self._builtin_agent_init_lock:
                if self._builtin_agent is not None:
                    return self._builtin_agent

                # Initialize a local instance first. Publishing it before
                # async_initialize finished would let concurrent callers (via
                # the fast path) use a half-initialized agent, and would leave
                # a broken agent cached forever if initialization raised.
                builtin_agent = DefaultAgent(self.hass)
                await builtin_agent.async_initialize(
                    self.hass.data.get(DATA_CONFIG)
                )
                self._builtin_agent = builtin_agent

            return builtin_agent

        if agent_id not in self._agents:
            raise ValueError(f"Agent {agent_id} not found")

        return self._agents[agent_id]

    @core.callback
    def async_get_agent_info(self) -> list[AgentInfo]:
        """List all agents: the built-in one first, then registered agents.

        Registered agents whose config entry no longer exists are skipped
        (with a warning).
        """
        agents: list[AgentInfo] = [
            AgentInfo(
                id=HOME_ASSISTANT_AGENT,
                name="Home Assistant",
            )
        ]
        for agent_id, agent in self._agents.items():
            config_entry = self.hass.config_entries.async_get_entry(agent_id)

            # Guard against potential bugs in conversation agents where the agent is not
            # removed from the manager when the config entry is removed
            if config_entry is None:
                _LOGGER.warning(
                    "Conversation agent %s is still loaded after config entry removal",
                    agent,
                )
                continue

            agents.append(
                AgentInfo(
                    id=agent_id,
                    name=config_entry.title or config_entry.domain,
                )
            )
        return agents

    @core.callback
    def async_is_valid_agent_id(self, agent_id: str) -> bool:
        """Return True if *agent_id* is registered or is the built-in agent."""
        return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT

    @core.callback
    def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
        """Register *agent* under *agent_id*, replacing any previous one."""
        self._agents[agent_id] = agent

    @core.callback
    def async_unset_agent(self, agent_id: str) -> None:
        """Unregister the agent under *agent_id*; no-op if none is registered."""
        self._agents.pop(agent_id, None)