Conversation config (#86326)

* Restore conversation config

* Fall back to en for en_US, etc.

* Simplify config passing around

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/86336/head
Michael Hansen 2023-01-20 20:39:49 -06:00 committed by GitHub
parent e1483ff746
commit 255611238b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 13 deletions

View File

@ -27,6 +27,7 @@ DOMAIN = "conversation"
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
DATA_AGENT = "conversation_agent" DATA_AGENT = "conversation_agent"
DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process" SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload" SERVICE_RELOAD = "reload"
@ -45,6 +46,19 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
} }
) )
CONFIG_SCHEMA = vol.Schema(
{
vol.Optional(DOMAIN): vol.Schema(
{
vol.Optional("intents"): vol.Schema(
{cv.string: vol.All(cv.ensure_list, [cv.string])}
)
}
)
},
extra=vol.ALLOW_EXTRA,
)
@core.callback @core.callback
@bind_hass @bind_hass
@ -55,6 +69,8 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
if config_intents := config.get(DOMAIN, {}).get("intents"):
hass.data[DATA_CONFIG] = config_intents
async def handle_process(service: core.ServiceCall) -> None: async def handle_process(service: core.ServiceCall) -> None:
"""Parse text into commands.""" """Parse text into commands."""
@ -210,7 +226,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get the active conversation agent.""" """Get the active conversation agent."""
if (agent := hass.data.get(DATA_AGENT)) is None: if (agent := hass.data.get(DATA_AGENT)) is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass) agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize() await agent.async_initialize(hass.data.get(DATA_CONFIG))
return agent return agent

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from pathlib import Path from pathlib import Path
@ -35,6 +36,21 @@ class LanguageIntents:
loaded_components: set[str] loaded_components: set[str]
def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region."""
yield language
parts = re.split(r"([-_])", language)
if len(parts) == 3:
lang, sep, region = parts
if sep == "_":
# en_US -> en-US
yield f"{lang}-{region}"
# en-US -> en
yield lang
class DefaultAgent(AbstractConversationAgent): class DefaultAgent(AbstractConversationAgent):
"""Default agent for conversation agent.""" """Default agent for conversation agent."""
@ -44,12 +60,17 @@ class DefaultAgent(AbstractConversationAgent):
self._lang_intents: dict[str, LanguageIntents] = {} self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
async def async_initialize(self): # intent -> [sentences]
self._config_intents: dict[str, Any] = {}
async def async_initialize(self, config_intents):
"""Initialize the default agent.""" """Initialize the default agent."""
if "intent" not in self.hass.config.components: if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {}) await setup.async_setup_component(self.hass, "intent", {})
self.hass.data.setdefault(DOMAIN, {}) # Intents from config may only contains sentences for HA config's language
if config_intents:
self._config_intents = config_intents
async def async_process( async def async_process(
self, self,
@ -144,17 +165,20 @@ class DefaultAgent(AbstractConversationAgent):
# Don't check component again # Don't check component again
loaded_components.add(component) loaded_components.add(component)
# Check for intents for this component with the target language # Check for intents for this component with the target language.
component_intents = get_intents(component, language) # Try en-US, en, etc.
if component_intents: for language_variation in _get_language_variations(language):
# Merge sentences into existing dictionary component_intents = get_intents(component, language_variation)
merge_dict(intents_dict, component_intents) if component_intents:
# Merge sentences into existing dictionary
merge_dict(intents_dict, component_intents)
# Will need to recreate graph # Will need to recreate graph
intents_changed = True intents_changed = True
_LOGGER.debug( _LOGGER.debug(
"Loaded intents component=%s, language=%s", component, language "Loaded intents component=%s, language=%s", component, language
) )
break
# Check for custom sentences in <config>/custom_sentences/<language>/ # Check for custom sentences in <config>/custom_sentences/<language>/
if lang_intents is None: if lang_intents is None:
@ -179,6 +203,22 @@ class DefaultAgent(AbstractConversationAgent):
custom_sentences_path, custom_sentences_path,
) )
# Load sentences from HA config for default language only
if self._config_intents and (language == self.hass.config.language):
merge_dict(
intents_dict,
{
"intents": {
intent_name: {"data": [{"sentences": sentences}]}
for intent_name, sentences in self._config_intents.items()
}
},
)
intents_changed = True
_LOGGER.debug(
"Loaded intents from configuration.yaml",
)
if not intents_dict: if not intents_dict:
return None return None

View File

@ -380,6 +380,55 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
} }
async def test_custom_sentences_config(hass, hass_client, hass_admin_user):
"""Test custom sentences with a custom intent in config."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(
hass,
"conversation",
{"conversation": {"intents": {"StealthMode": ["engage stealth mode"]}}},
)
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(
hass,
"intent_script",
{
"intent_script": {
"StealthMode": {"speech": {"text": "Stealth mode engaged"}}
}
},
)
# Invoke intent via HTTP API
client = await hass_client()
resp = await client.post(
"/api/conversation/process",
json={"text": "engage stealth mode"},
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert data == {
"response": {
"card": {},
"speech": {
"plain": {
"extra_data": None,
"speech": "Stealth mode engaged",
}
},
"language": hass.config.language,
"response_type": "action_done",
"data": {
"targets": [],
"success": [],
"failed": [],
},
},
"conversation_id": None,
}
# pylint: disable=protected-access # pylint: disable=protected-access
async def test_prepare_reload(hass): async def test_prepare_reload(hass):
"""Test calling the reload service.""" """Test calling the reload service."""
@ -414,3 +463,27 @@ async def test_prepare_fail(hass):
# Confirm no intents were loaded # Confirm no intents were loaded
assert not agent._lang_intents.get("not-a-language") assert not agent._lang_intents.get("not-a-language")
async def test_language_region(hass, init_components):
"""Test calling the turn on intent."""
hass.states.async_set("light.kitchen", "off")
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
# Add fake region
language = f"{hass.config.language}-YZ"
await hass.services.async_call(
"conversation",
"process",
{
conversation.ATTR_TEXT: "turn on the kitchen",
conversation.ATTR_LANGUAGE: language,
},
)
await hass.async_block_till_done()
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.kitchen"}