Conversation config (#86326)

* Restore conversation config

* Fall back to en for en_US, etc.

* Simplify config passing around

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/86336/head
Michael Hansen 2023-01-20 20:39:49 -06:00 committed by GitHub
parent e1483ff746
commit 255611238b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 13 deletions

View File

@ -27,6 +27,7 @@ DOMAIN = "conversation"
REGEX_TYPE = type(re.compile(""))
DATA_AGENT = "conversation_agent"
DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload"
@ -45,6 +46,19 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
}
)
CONFIG_SCHEMA = vol.Schema(
{
vol.Optional(DOMAIN): vol.Schema(
{
vol.Optional("intents"): vol.Schema(
{cv.string: vol.All(cv.ensure_list, [cv.string])}
)
}
)
},
extra=vol.ALLOW_EXTRA,
)
@core.callback
@bind_hass
@ -55,6 +69,8 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
if config_intents := config.get(DOMAIN, {}).get("intents"):
hass.data[DATA_CONFIG] = config_intents
async def handle_process(service: core.ServiceCall) -> None:
"""Parse text into commands."""
@ -210,7 +226,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get the active conversation agent."""
if (agent := hass.data.get(DATA_AGENT)) is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize()
await agent.async_initialize(hass.data.get(DATA_CONFIG))
return agent

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
import logging
from pathlib import Path
@ -35,6 +36,21 @@ class LanguageIntents:
loaded_components: set[str]
def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region."""
yield language
parts = re.split(r"([-_])", language)
if len(parts) == 3:
lang, sep, region = parts
if sep == "_":
# en_US -> en-US
yield f"{lang}-{region}"
# en-US -> en
yield lang
class DefaultAgent(AbstractConversationAgent):
"""Default agent for conversation agent."""
@ -44,12 +60,17 @@ class DefaultAgent(AbstractConversationAgent):
self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
async def async_initialize(self):
# intent -> [sentences]
self._config_intents: dict[str, Any] = {}
async def async_initialize(self, config_intents):
"""Initialize the default agent."""
if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {})
self.hass.data.setdefault(DOMAIN, {})
# Intents from config may only contains sentences for HA config's language
if config_intents:
self._config_intents = config_intents
async def async_process(
self,
@ -144,17 +165,20 @@ class DefaultAgent(AbstractConversationAgent):
# Don't check component again
loaded_components.add(component)
# Check for intents for this component with the target language
component_intents = get_intents(component, language)
if component_intents:
# Merge sentences into existing dictionary
merge_dict(intents_dict, component_intents)
# Check for intents for this component with the target language.
# Try en-US, en, etc.
for language_variation in _get_language_variations(language):
component_intents = get_intents(component, language_variation)
if component_intents:
# Merge sentences into existing dictionary
merge_dict(intents_dict, component_intents)
# Will need to recreate graph
intents_changed = True
_LOGGER.debug(
"Loaded intents component=%s, language=%s", component, language
)
# Will need to recreate graph
intents_changed = True
_LOGGER.debug(
"Loaded intents component=%s, language=%s", component, language
)
break
# Check for custom sentences in <config>/custom_sentences/<language>/
if lang_intents is None:
@ -179,6 +203,22 @@ class DefaultAgent(AbstractConversationAgent):
custom_sentences_path,
)
# Load sentences from HA config for default language only
if self._config_intents and (language == self.hass.config.language):
merge_dict(
intents_dict,
{
"intents": {
intent_name: {"data": [{"sentences": sentences}]}
for intent_name, sentences in self._config_intents.items()
}
},
)
intents_changed = True
_LOGGER.debug(
"Loaded intents from configuration.yaml",
)
if not intents_dict:
return None

View File

@ -380,6 +380,55 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
}
async def test_custom_sentences_config(hass, hass_client, hass_admin_user):
"""Test custom sentences with a custom intent in config."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(
hass,
"conversation",
{"conversation": {"intents": {"StealthMode": ["engage stealth mode"]}}},
)
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(
hass,
"intent_script",
{
"intent_script": {
"StealthMode": {"speech": {"text": "Stealth mode engaged"}}
}
},
)
# Invoke intent via HTTP API
client = await hass_client()
resp = await client.post(
"/api/conversation/process",
json={"text": "engage stealth mode"},
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Stealth mode engaged",
}
},
"language": hass.config.language,
"response_type": "action_done",
"data": {
"targets": [],
"success": [],
"failed": [],
},
},
"conversation_id": None,
}
# pylint: disable=protected-access
async def test_prepare_reload(hass):
"""Test calling the reload service."""
@ -414,3 +463,27 @@ async def test_prepare_fail(hass):
# Confirm no intents were loaded
assert not agent._lang_intents.get("not-a-language")
async def test_language_region(hass, init_components):
"""Test calling the turn on intent."""
hass.states.async_set("light.kitchen", "off")
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
# Add fake region
language = f"{hass.config.language}-YZ"
await hass.services.async_call(
"conversation",
"process",
{
conversation.ATTR_TEXT: "turn on the kitchen",
conversation.ATTR_LANGUAGE: language,
},
)
await hass.async_block_till_done()
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.kitchen"}