Re-organize conversation integration (#114502)

* Re-organize conversation integration

* Clean up 2 more imports

* Re-export models

* Fix imports

* Uno mas

* Rename agents to models

* Fix cast test that i broke?

* Just blocking till I'm done

* Wrong place
pull/114480/head^2
Paulus Schoutsen 2024-03-31 00:05:25 -04:00 committed by GitHub
parent fb572b8413
commit f01235ef74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 579 additions and 518 deletions

View File

@ -2,43 +2,36 @@
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
import logging
import re
from typing import Any, Literal
from typing import Literal
from aiohttp import web
from hassil.recognize import (
MISSING_ENTITY,
RecognizeResult,
UnmatchedRangeEntity,
UnmatchedTextEntity,
)
import voluptuous as vol
from homeassistant import core
from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util import language as language_util
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import HOME_ASSISTANT_AGENT
from .default_agent import (
METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE,
DefaultAgent,
SentenceTriggerResult,
async_setup as async_setup_default_agent,
from .agent_manager import (
AgentInfo,
agent_id_validator,
async_converse,
get_agent_manager,
)
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
__all__ = [
"DOMAIN",
@ -48,6 +41,8 @@ __all__ = [
"async_set_agent",
"async_unset_agent",
"async_setup",
"ConversationInput",
"ConversationResult",
]
_LOGGER = logging.getLogger(__name__)
@ -60,21 +55,11 @@ ATTR_CONVERSATION_ID = "conversation_id"
DOMAIN = "conversation"
REGEX_TYPE = type(re.compile(""))
DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload"
def agent_id_validator(value: Any) -> str:
"""Validate agent ID."""
hass = core.async_get_hass()
manager = _get_agent_manager(hass)
if not manager.async_is_valid_agent_id(cv.string(value)):
raise vol.Invalid("invalid agent ID")
return value
SERVICE_PROCESS_SCHEMA = vol.Schema(
{
vol.Required(ATTR_TEXT): cv.string,
@ -106,34 +91,25 @@ CONFIG_SCHEMA = vol.Schema(
)
@singleton.singleton("conversation_agent")
@core.callback
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent."""
manager = AgentManager(hass)
manager.async_setup()
return manager
@core.callback
@callback
@bind_hass
def async_set_agent(
hass: core.HomeAssistant,
hass: HomeAssistant,
config_entry: ConfigEntry,
agent: AbstractConversationAgent,
) -> None:
"""Set the agent to handle the conversations."""
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
@core.callback
@callback
@bind_hass
def async_unset_agent(
hass: core.HomeAssistant,
hass: HomeAssistant,
config_entry: ConfigEntry,
) -> None:
"""Set the agent to handle the conversations."""
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
async def async_get_conversation_languages(
@ -145,7 +121,7 @@ async def async_get_conversation_languages(
If no agent is specified, return a set with the union of languages supported by
all conversation agents.
"""
agent_manager = _get_agent_manager(hass)
agent_manager = get_agent_manager(hass)
languages: set[str] = set()
agent_ids: Iterable[str]
@ -164,14 +140,32 @@ async def async_get_conversation_languages(
return languages
@callback
def async_get_agent_info(
hass: HomeAssistant,
agent_id: str | None = None,
) -> AgentInfo | None:
"""Get information on the agent or None if not found."""
manager = get_agent_manager(hass)
if agent_id is None:
agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info():
if agent_info.id == agent_id:
return agent_info
return None
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
agent_manager = _get_agent_manager(hass)
agent_manager = get_agent_manager(hass)
if config_intents := config.get(DOMAIN, {}).get("intents"):
hass.data[DATA_CONFIG] = config_intents
async def handle_process(service: core.ServiceCall) -> core.ServiceResponse:
async def handle_process(service: ServiceCall) -> ServiceResponse:
"""Parse text into commands."""
text = service.data[ATTR_TEXT]
_LOGGER.debug("Processing: <%s>", text)
@ -192,7 +186,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return None
async def handle_reload(service: core.ServiceCall) -> None:
async def handle_reload(service: ServiceCall) -> None:
"""Reload intents."""
agent = await agent_manager.async_get_agent()
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
@ -202,440 +196,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
SERVICE_PROCESS,
handle_process,
schema=SERVICE_PROCESS_SCHEMA,
supports_response=core.SupportsResponse.OPTIONAL,
supports_response=SupportsResponse.OPTIONAL,
)
hass.services.async_register(
DOMAIN, SERVICE_RELOAD, handle_reload, schema=SERVICE_RELOAD_SCHEMA
)
hass.http.register_view(ConversationProcessView())
websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_list_agents)
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
async_setup_conversation_http(hass)
return True
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/process",
vol.Required("text"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
@websocket_api.async_response
async def websocket_process(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Process text."""
result = await async_converse(
hass=hass,
text=msg["text"],
conversation_id=msg.get("conversation_id"),
context=connection.context(msg),
language=msg.get("language"),
agent_id=msg.get("agent_id"),
)
connection.send_result(msg["id"], result.as_dict())
@websocket_api.websocket_command(
{
"type": "conversation/prepare",
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
@websocket_api.async_response
async def websocket_prepare(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Reload intents."""
manager = _get_agent_manager(hass)
agent = await manager.async_get_agent(msg.get("agent_id"))
await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"])
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/list",
vol.Optional("language"): str,
vol.Optional("country"): str,
}
)
@websocket_api.async_response
async def websocket_list_agents(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List conversation agents and, optionally, if they support a given language."""
manager = _get_agent_manager(hass)
country = msg.get("country")
language = msg.get("language")
agents = []
for agent_info in manager.async_get_agent_info():
agent = await manager.async_get_agent(agent_info.id)
supported_languages = agent.supported_languages
if language and supported_languages != MATCH_ALL:
supported_languages = language_util.matches(
language, supported_languages, country
)
agent_dict: dict[str, Any] = {
"id": agent_info.id,
"name": agent_info.name,
"supported_languages": supported_languages,
}
agents.append(agent_dict)
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/homeassistant/debug",
vol.Required("sentences"): [str],
vol.Optional("language"): str,
vol.Optional("device_id"): vol.Any(str, None),
}
)
@websocket_api.async_response
async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Return intents that would be matched by the default agent for a list of sentences."""
agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
assert isinstance(agent, DefaultAgent)
results = [
await agent.async_recognize(
ConversationInput(
text=sentence,
context=connection.context(msg),
conversation_id=None,
device_id=msg.get("device_id"),
language=msg.get("language", hass.config.language),
)
)
for sentence in msg["sentences"]
]
# Return results for each sentence in the same order as the input.
result_dicts: list[dict[str, Any] | None] = []
for result in results:
result_dict: dict[str, Any] | None = None
if isinstance(result, SentenceTriggerResult):
result_dict = {
# Matched a user-defined sentence trigger.
# We can't provide the response here without executing the
# trigger.
"match": True,
"source": "trigger",
"sentence_template": result.sentence_template or "",
}
elif isinstance(result, RecognizeResult):
successful_match = not result.unmatched_entities
result_dict = {
# Name of the matching intent (or the closest)
"intent": {
"name": result.intent.name,
},
# Slot values that would be received by the intent
"slots": { # direct access to values
entity_key: entity.text or entity.value
for entity_key, entity in result.entities.items()
},
# Extra slot details, such as the originally matched text
"details": {
entity_key: {
"name": entity.name,
"value": entity.value,
"text": entity.text,
}
for entity_key, entity in result.entities.items()
},
# Entities/areas/etc. that would be targeted
"targets": {},
# True if match was successful
"match": successful_match,
# Text of the sentence template that matched (or was closest)
"sentence_template": "",
# When match is incomplete, this will contain the best slot guesses
"unmatched_slots": _get_unmatched_slots(result),
}
if successful_match:
result_dict["targets"] = {
state.entity_id: {"matched": is_matched}
for state, is_matched in _get_debug_targets(hass, result)
}
if result.intent_sentence is not None:
result_dict["sentence_template"] = result.intent_sentence.text
# Inspect metadata to determine if this matched a custom sentence
if result.intent_metadata and result.intent_metadata.get(
METADATA_CUSTOM_SENTENCE
):
result_dict["source"] = "custom"
result_dict["file"] = result.intent_metadata.get(METADATA_CUSTOM_FILE)
else:
result_dict["source"] = "builtin"
result_dicts.append(result_dict)
connection.send_result(msg["id"], {"results": result_dicts})
def _get_debug_targets(
hass: HomeAssistant,
result: RecognizeResult,
) -> Iterable[tuple[core.State, bool]]:
"""Yield state/is_matched pairs for a hassil recognition."""
entities = result.entities
name: str | None = None
area_name: str | None = None
domains: set[str] | None = None
device_classes: set[str] | None = None
state_names: set[str] | None = None
if "name" in entities:
name = str(entities["name"].value)
if "area" in entities:
area_name = str(entities["area"].value)
if "domain" in entities:
domains = set(cv.ensure_list(entities["domain"].value))
if "device_class" in entities:
device_classes = set(cv.ensure_list(entities["device_class"].value))
if "state" in entities:
# HassGetState only
state_names = set(cv.ensure_list(entities["state"].value))
if (
(name is None)
and (area_name is None)
and (not domains)
and (not device_classes)
and (not state_names)
):
# Avoid "matching" all entities when there is no filter
return
states = intent.async_match_states(
hass,
name=name,
area_name=area_name,
domains=domains,
device_classes=device_classes,
)
for state in states:
# For queries, a target is "matched" based on its state
is_matched = (state_names is None) or (state.state in state_names)
yield state, is_matched
def _get_unmatched_slots(
result: RecognizeResult,
) -> dict[str, str | int]:
"""Return a dict of unmatched text/range slot entities."""
unmatched_slots: dict[str, str | int] = {}
for entity in result.unmatched_entities_list:
if isinstance(entity, UnmatchedTextEntity):
if entity.text == MISSING_ENTITY:
# Don't report <missing> since these are just missing context
# slots.
continue
unmatched_slots[entity.name] = entity.text
elif isinstance(entity, UnmatchedRangeEntity):
unmatched_slots[entity.name] = entity.value
return unmatched_slots
class ConversationProcessView(http.HomeAssistantView):
"""View to process text."""
url = "/api/conversation/process"
name = "api:conversation:process"
@RequestDataValidator(
vol.Schema(
{
vol.Required("text"): str,
vol.Optional("conversation_id"): str,
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
)
async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
"""Send a request for processing."""
hass = request.app[http.KEY_HASS]
result = await async_converse(
hass,
text=data["text"],
conversation_id=data.get("conversation_id"),
context=self.context(request),
language=data.get("language"),
agent_id=data.get("agent_id"),
)
return self.json(result.as_dict())
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
@core.callback
def async_get_agent_info(
hass: core.HomeAssistant,
agent_id: str | None = None,
) -> AgentInfo | None:
"""Get information on the agent or None if not found."""
manager = _get_agent_manager(hass)
if agent_id is None:
agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info():
if agent_info.id == agent_id:
return agent_info
return None
async def async_converse(
hass: core.HomeAssistant,
text: str,
conversation_id: str | None,
context: core.Context,
language: str | None = None,
agent_id: str | None = None,
device_id: str | None = None,
) -> ConversationResult:
"""Process text and get intent."""
agent = await _get_agent_manager(hass).async_get_agent(agent_id)
if language is None:
language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
result = await agent.async_process(
ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
)
return result
class AgentManager:
"""Class to manage conversation agents."""
default_agent: str = HOME_ASSISTANT_AGENT
_builtin_agent: AbstractConversationAgent | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the conversation agents."""
self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {}
self._builtin_agent_init_lock = asyncio.Lock()
def async_setup(self) -> None:
"""Set up the conversation agents."""
async_setup_default_agent(self.hass)
async def async_get_agent(
self, agent_id: str | None = None
) -> AbstractConversationAgent:
"""Get the agent."""
if agent_id is None:
agent_id = self.default_agent
if agent_id == HOME_ASSISTANT_AGENT:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._builtin_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent
self._builtin_agent = DefaultAgent(self.hass)
await self._builtin_agent.async_initialize(
self.hass.data.get(DATA_CONFIG)
)
return self._builtin_agent
if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found")
return self._agents[agent_id]
@core.callback
def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents."""
agents: list[AgentInfo] = [
AgentInfo(
id=HOME_ASSISTANT_AGENT,
name="Home Assistant",
)
]
for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id)
# Guard against potential bugs in conversation agents where the agent is not
# removed from the manager when the config entry is removed
if config_entry is None:
_LOGGER.warning(
"Conversation agent %s is still loaded after config entry removal",
agent,
)
continue
agents.append(
AgentInfo(
id=agent_id,
name=config_entry.title or config_entry.domain,
)
)
return agents
@core.callback
def async_is_valid_agent_id(self, agent_id: str) -> bool:
"""Check if the agent id is valid."""
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT
@core.callback
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
"""Set the agent."""
self._agents[agent_id] = agent
@core.callback
def async_unset_agent(self, agent_id: str) -> None:
"""Unset the agent."""
self._agents.pop(agent_id, None)

View File

@ -0,0 +1,161 @@
"""Agent foundation for conversation integration."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import logging
from typing import Any
import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
_LOGGER = logging.getLogger(__name__)
@singleton.singleton("conversation_agent")
@callback
def get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent."""
manager = AgentManager(hass)
manager.async_setup()
return manager
def agent_id_validator(value: Any) -> str:
"""Validate agent ID."""
hass = async_get_hass()
manager = get_agent_manager(hass)
if not manager.async_is_valid_agent_id(cv.string(value)):
raise vol.Invalid("invalid agent ID")
return value
async def async_converse(
hass: HomeAssistant,
text: str,
conversation_id: str | None,
context: Context,
language: str | None = None,
agent_id: str | None = None,
device_id: str | None = None,
) -> ConversationResult:
"""Process text and get intent."""
agent = await get_agent_manager(hass).async_get_agent(agent_id)
if language is None:
language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
result = await agent.async_process(
ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
)
return result
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
class AgentManager:
"""Class to manage conversation agents."""
default_agent: str = HOME_ASSISTANT_AGENT
_builtin_agent: AbstractConversationAgent | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the conversation agents."""
self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {}
self._builtin_agent_init_lock = asyncio.Lock()
def async_setup(self) -> None:
"""Set up the conversation agents."""
async_setup_default_agent(self.hass)
async def async_get_agent(
self, agent_id: str | None = None
) -> AbstractConversationAgent:
"""Get the agent."""
if agent_id is None:
agent_id = self.default_agent
if agent_id == HOME_ASSISTANT_AGENT:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._builtin_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent
self._builtin_agent = DefaultAgent(self.hass)
await self._builtin_agent.async_initialize(
self.hass.data.get(DATA_CONFIG)
)
return self._builtin_agent
if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found")
return self._agents[agent_id]
@callback
def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents."""
agents: list[AgentInfo] = [
AgentInfo(
id=HOME_ASSISTANT_AGENT,
name="Home Assistant",
)
]
for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id)
# Guard against potential bugs in conversation agents where the agent is not
# removed from the manager when the config entry is removed
if config_entry is None:
_LOGGER.warning(
"Conversation agent %s is still loaded after config entry removal",
agent,
)
continue
agents.append(
AgentInfo(
id=agent_id,
name=config_entry.title or config_entry.domain,
)
)
return agents
@callback
def async_is_valid_agent_id(self, agent_id: str) -> bool:
"""Check if the agent id is valid."""
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT
@callback
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
"""Set the agent."""
self._agents[agent_id] = agent
@callback
def async_unset_agent(self, agent_id: str) -> None:
"""Unset the agent."""
self._agents.pop(agent_id, None)

View File

@ -3,3 +3,4 @@
DOMAIN = "conversation"
DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
HOME_ASSISTANT_AGENT = "homeassistant"
DATA_CONFIG = "conversation_config"

View File

@ -46,8 +46,8 @@ from homeassistant.helpers.event import (
)
from homeassistant.util.json import JsonObjectType, json_loads_object
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
_LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"

View File

@ -0,0 +1,325 @@
"""HTTP endpoints for conversation integration."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from aiohttp import web
from hassil.recognize import (
MISSING_ENTITY,
RecognizeResult,
UnmatchedRangeEntity,
UnmatchedTextEntity,
)
import voluptuous as vol
from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.util import language as language_util
from .agent_manager import agent_id_validator, async_converse, get_agent_manager
from .const import HOME_ASSISTANT_AGENT
from .default_agent import (
METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE,
DefaultAgent,
SentenceTriggerResult,
)
from .models import ConversationInput
@callback
def async_setup(hass: HomeAssistant) -> None:
"""Set up the HTTP API for the conversation integration."""
hass.http.register_view(ConversationProcessView())
websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_list_agents)
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/process",
vol.Required("text"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
@websocket_api.async_response
async def websocket_process(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Process text."""
result = await async_converse(
hass=hass,
text=msg["text"],
conversation_id=msg.get("conversation_id"),
context=connection.context(msg),
language=msg.get("language"),
agent_id=msg.get("agent_id"),
)
connection.send_result(msg["id"], result.as_dict())
@websocket_api.websocket_command(
{
"type": "conversation/prepare",
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
@websocket_api.async_response
async def websocket_prepare(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Reload intents."""
manager = get_agent_manager(hass)
agent = await manager.async_get_agent(msg.get("agent_id"))
await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"])
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/list",
vol.Optional("language"): str,
vol.Optional("country"): str,
}
)
@websocket_api.async_response
async def websocket_list_agents(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List conversation agents and, optionally, if they support a given language."""
manager = get_agent_manager(hass)
country = msg.get("country")
language = msg.get("language")
agents = []
for agent_info in manager.async_get_agent_info():
agent = await manager.async_get_agent(agent_info.id)
supported_languages = agent.supported_languages
if language and supported_languages != MATCH_ALL:
supported_languages = language_util.matches(
language, supported_languages, country
)
agent_dict: dict[str, Any] = {
"id": agent_info.id,
"name": agent_info.name,
"supported_languages": supported_languages,
}
agents.append(agent_dict)
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/homeassistant/debug",
vol.Required("sentences"): [str],
vol.Optional("language"): str,
vol.Optional("device_id"): vol.Any(str, None),
}
)
@websocket_api.async_response
async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Return intents that would be matched by the default agent for a list of sentences."""
agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
assert isinstance(agent, DefaultAgent)
results = [
await agent.async_recognize(
ConversationInput(
text=sentence,
context=connection.context(msg),
conversation_id=None,
device_id=msg.get("device_id"),
language=msg.get("language", hass.config.language),
)
)
for sentence in msg["sentences"]
]
# Return results for each sentence in the same order as the input.
result_dicts: list[dict[str, Any] | None] = []
for result in results:
result_dict: dict[str, Any] | None = None
if isinstance(result, SentenceTriggerResult):
result_dict = {
# Matched a user-defined sentence trigger.
# We can't provide the response here without executing the
# trigger.
"match": True,
"source": "trigger",
"sentence_template": result.sentence_template or "",
}
elif isinstance(result, RecognizeResult):
successful_match = not result.unmatched_entities
result_dict = {
# Name of the matching intent (or the closest)
"intent": {
"name": result.intent.name,
},
# Slot values that would be received by the intent
"slots": { # direct access to values
entity_key: entity.text or entity.value
for entity_key, entity in result.entities.items()
},
# Extra slot details, such as the originally matched text
"details": {
entity_key: {
"name": entity.name,
"value": entity.value,
"text": entity.text,
}
for entity_key, entity in result.entities.items()
},
# Entities/areas/etc. that would be targeted
"targets": {},
# True if match was successful
"match": successful_match,
# Text of the sentence template that matched (or was closest)
"sentence_template": "",
# When match is incomplete, this will contain the best slot guesses
"unmatched_slots": _get_unmatched_slots(result),
}
if successful_match:
result_dict["targets"] = {
state.entity_id: {"matched": is_matched}
for state, is_matched in _get_debug_targets(hass, result)
}
if result.intent_sentence is not None:
result_dict["sentence_template"] = result.intent_sentence.text
# Inspect metadata to determine if this matched a custom sentence
if result.intent_metadata and result.intent_metadata.get(
METADATA_CUSTOM_SENTENCE
):
result_dict["source"] = "custom"
result_dict["file"] = result.intent_metadata.get(METADATA_CUSTOM_FILE)
else:
result_dict["source"] = "builtin"
result_dicts.append(result_dict)
connection.send_result(msg["id"], {"results": result_dicts})
def _get_debug_targets(
hass: HomeAssistant,
result: RecognizeResult,
) -> Iterable[tuple[State, bool]]:
"""Yield state/is_matched pairs for a hassil recognition."""
entities = result.entities
name: str | None = None
area_name: str | None = None
domains: set[str] | None = None
device_classes: set[str] | None = None
state_names: set[str] | None = None
if "name" in entities:
name = str(entities["name"].value)
if "area" in entities:
area_name = str(entities["area"].value)
if "domain" in entities:
domains = set(cv.ensure_list(entities["domain"].value))
if "device_class" in entities:
device_classes = set(cv.ensure_list(entities["device_class"].value))
if "state" in entities:
# HassGetState only
state_names = set(cv.ensure_list(entities["state"].value))
if (
(name is None)
and (area_name is None)
and (not domains)
and (not device_classes)
and (not state_names)
):
# Avoid "matching" all entities when there is no filter
return
states = intent.async_match_states(
hass,
name=name,
area_name=area_name,
domains=domains,
device_classes=device_classes,
)
for state in states:
# For queries, a target is "matched" based on its state
is_matched = (state_names is None) or (state.state in state_names)
yield state, is_matched
def _get_unmatched_slots(
result: RecognizeResult,
) -> dict[str, str | int]:
"""Return a dict of unmatched text/range slot entities."""
unmatched_slots: dict[str, str | int] = {}
for entity in result.unmatched_entities_list:
if isinstance(entity, UnmatchedTextEntity):
if entity.text == MISSING_ENTITY:
# Don't report <missing> since these are just missing context
# slots.
continue
unmatched_slots[entity.name] = entity.text
elif isinstance(entity, UnmatchedRangeEntity):
unmatched_slots[entity.name] = entity.value
return unmatched_slots
class ConversationProcessView(http.HomeAssistantView):
"""View to process text."""
url = "/api/conversation/process"
name = "api:conversation:process"
@RequestDataValidator(
vol.Schema(
{
vol.Required("text"): str,
vol.Optional("conversation_id"): str,
vol.Optional("language"): str,
vol.Optional("agent_id"): agent_id_validator,
}
)
)
async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
"""Send a request for processing."""
hass = request.app[http.KEY_HASS]
result = await async_converse(
hass,
text=data["text"],
conversation_id=data.get("conversation_id"),
context=self.context(request),
language=data.get("language"),
agent_id=data.get("agent_id"),
)
return self.json(result.as_dict())

View File

@ -14,8 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from . import HOME_ASSISTANT_AGENT, _get_agent_manager
from .const import DOMAIN
from .agent_manager import get_agent_manager
from .const import DOMAIN, HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent
@ -111,7 +111,7 @@ async def async_attach_trigger(
# two trigger copies for who will provide a response.
return None
default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action)

View File

@ -453,11 +453,13 @@ async def test_stop_discovery_called_on_stop(
"""Test pychromecast.stop_discovery called on shutdown."""
# start_discovery should be called with empty config
await async_setup_cast(hass, {})
await hass.async_block_till_done()
assert castbrowser_mock.return_value.start_discovery.call_count == 1
# stop discovery should be called on shutdown
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
await hass.async_block_till_done()
assert castbrowser_mock.return_value.stop_discovery.call_count == 1

View File

@ -5,11 +5,16 @@ from __future__ import annotations
from typing import Literal
from homeassistant.components import conversation
from homeassistant.components.conversation.models import (
ConversationInput,
ConversationResult,
)
from homeassistant.components.homeassistant.exposed_entities import (
DATA_EXPOSED_ENTITIES,
ExposedEntities,
async_expose_entity,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
@ -30,24 +35,22 @@ class MockAgent(conversation.AbstractConversationAgent):
"""Return a list of supported languages."""
return self._supported_languages
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process some text."""
self.calls.append(user_input)
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(self.response)
return conversation.ConversationResult(
return ConversationResult(
response=response, conversation_id=user_input.conversation_id
)
def expose_new(hass, expose_new):
def expose_new(hass: HomeAssistant, expose_new: bool):
"""Enable exposing new entities to the default agent."""
exposed_entities: ExposedEntities = hass.data[DATA_EXPOSED_ENTITIES]
exposed_entities.async_set_expose_new_entities(conversation.DOMAIN, expose_new)
def expose_entity(hass, entity_id, should_expose):
def expose_entity(hass: HomeAssistant, entity_id: str, should_expose: bool):
"""Expose an entity to the default agent."""
async_expose_entity(hass, conversation.DOMAIN, entity_id, should_expose)

View File

@ -7,6 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.homeassistant.exposed_entities import (
async_get_assistant_settings,
)
@ -151,7 +152,7 @@ async def test_conversation_agent(
init_components,
) -> None:
"""Test DefaultAgent."""
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
with patch(
@ -253,10 +254,10 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!"
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
assert isinstance(agent, conversation.DefaultAgent)
assert isinstance(agent, default_agent.DefaultAgent)
callback = AsyncMock(return_value=trigger_response)
unregister = agent.register_trigger(trigger_sentences, callback)
@ -850,7 +851,7 @@ async def test_empty_aliases(
)
with patch(
"homeassistant.components.conversation.DefaultAgent._recognize",
"homeassistant.components.conversation.default_agent.DefaultAgent._recognize",
return_value=None,
) as mock_recognize_all:
await conversation.async_converse(

View File

@ -9,6 +9,8 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_FRIENDLY_NAME
@ -750,8 +752,8 @@ async def test_ws_prepare(
"""Test the Websocket prepare conversation API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
assert isinstance(agent, default_agent.DefaultAgent)
# No intents should be loaded yet
assert not agent._lang_intents.get(hass.config.language)
@ -852,8 +854,8 @@ async def test_prepare_reload(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare(language)
# Confirm intents are loaded
@ -880,8 +882,8 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare("not-a-language")
# Confirm no intents were loaded
@ -917,11 +919,11 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
hass.states.async_set("cover.front_door", "closed")
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(
conversation.ConversationInput(
ConversationInput(
text="open the front door",
context=Context(),
conversation_id=None,

View File

@ -5,7 +5,8 @@ import logging
import pytest
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import trigger
from homeassistant.setup import async_setup_component
@ -514,11 +515,11 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
},
)
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(
conversation.ConversationInput(
ConversationInput(
text="test sentence",
context=Context(),
conversation_id=None,

View File

@ -334,7 +334,7 @@ async def test_conversation_agent(
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
text1 = "tell me a joke"

View File

@ -152,7 +152,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test GoogleGenerativeAIAgent."""
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"

View File

@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process(
webhook_client.server.app.router._frozen = False
with patch(
"homeassistant.components.conversation.AgentManager.async_get_agent",
"homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent",
return_value=mock_conversation_agent,
):
resp = await webhook_client.post(

View File

@ -229,7 +229,7 @@ async def test_message_history_pruning(
assert isinstance(result.conversation_id, str)
conversation_ids.append(result.conversation_id)
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert isinstance(agent, ollama.OllamaAgent)
@ -284,7 +284,7 @@ async def test_message_history_unlimited(
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert isinstance(agent, ollama.OllamaAgent)
@ -340,7 +340,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test OllamaAgent."""
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == MATCH_ALL

View File

@ -194,7 +194,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test OpenAIAgent."""
agent = await conversation._get_agent_manager(hass).async_get_agent(
agent = await conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"