Conversation config (#86326)
* Restore conversation config * Fall back to en for en_US, etc. * Simplify config passing around Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/86336/head
parent
e1483ff746
commit
255611238b
|
@ -27,6 +27,7 @@ DOMAIN = "conversation"
|
|||
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
DATA_AGENT = "conversation_agent"
|
||||
DATA_CONFIG = "conversation_config"
|
||||
|
||||
SERVICE_PROCESS = "process"
|
||||
SERVICE_RELOAD = "reload"
|
||||
|
@ -45,6 +46,19 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(DOMAIN): vol.Schema(
|
||||
{
|
||||
vol.Optional("intents"): vol.Schema(
|
||||
{cv.string: vol.All(cv.ensure_list, [cv.string])}
|
||||
)
|
||||
}
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
@core.callback
|
||||
@bind_hass
|
||||
|
@ -55,6 +69,8 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
||||
hass.data[DATA_CONFIG] = config_intents
|
||||
|
||||
async def handle_process(service: core.ServiceCall) -> None:
|
||||
"""Parse text into commands."""
|
||||
|
@ -210,7 +226,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
|||
"""Get the active conversation agent."""
|
||||
if (agent := hass.data.get(DATA_AGENT)) is None:
|
||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||
await agent.async_initialize()
|
||||
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||
return agent
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
@ -35,6 +36,21 @@ class LanguageIntents:
|
|||
loaded_components: set[str]
|
||||
|
||||
|
||||
def _get_language_variations(language: str) -> Iterable[str]:
|
||||
"""Generate language codes with and without region."""
|
||||
yield language
|
||||
|
||||
parts = re.split(r"([-_])", language)
|
||||
if len(parts) == 3:
|
||||
lang, sep, region = parts
|
||||
if sep == "_":
|
||||
# en_US -> en-US
|
||||
yield f"{lang}-{region}"
|
||||
|
||||
# en-US -> en
|
||||
yield lang
|
||||
|
||||
|
||||
class DefaultAgent(AbstractConversationAgent):
|
||||
"""Default agent for conversation agent."""
|
||||
|
||||
|
@ -44,12 +60,17 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
self._lang_intents: dict[str, LanguageIntents] = {}
|
||||
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
|
||||
async def async_initialize(self):
|
||||
# intent -> [sentences]
|
||||
self._config_intents: dict[str, Any] = {}
|
||||
|
||||
async def async_initialize(self, config_intents):
|
||||
"""Initialize the default agent."""
|
||||
if "intent" not in self.hass.config.components:
|
||||
await setup.async_setup_component(self.hass, "intent", {})
|
||||
|
||||
self.hass.data.setdefault(DOMAIN, {})
|
||||
# Intents from config may only contains sentences for HA config's language
|
||||
if config_intents:
|
||||
self._config_intents = config_intents
|
||||
|
||||
async def async_process(
|
||||
self,
|
||||
|
@ -144,17 +165,20 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
# Don't check component again
|
||||
loaded_components.add(component)
|
||||
|
||||
# Check for intents for this component with the target language
|
||||
component_intents = get_intents(component, language)
|
||||
if component_intents:
|
||||
# Merge sentences into existing dictionary
|
||||
merge_dict(intents_dict, component_intents)
|
||||
# Check for intents for this component with the target language.
|
||||
# Try en-US, en, etc.
|
||||
for language_variation in _get_language_variations(language):
|
||||
component_intents = get_intents(component, language_variation)
|
||||
if component_intents:
|
||||
# Merge sentences into existing dictionary
|
||||
merge_dict(intents_dict, component_intents)
|
||||
|
||||
# Will need to recreate graph
|
||||
intents_changed = True
|
||||
_LOGGER.debug(
|
||||
"Loaded intents component=%s, language=%s", component, language
|
||||
)
|
||||
# Will need to recreate graph
|
||||
intents_changed = True
|
||||
_LOGGER.debug(
|
||||
"Loaded intents component=%s, language=%s", component, language
|
||||
)
|
||||
break
|
||||
|
||||
# Check for custom sentences in <config>/custom_sentences/<language>/
|
||||
if lang_intents is None:
|
||||
|
@ -179,6 +203,22 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
custom_sentences_path,
|
||||
)
|
||||
|
||||
# Load sentences from HA config for default language only
|
||||
if self._config_intents and (language == self.hass.config.language):
|
||||
merge_dict(
|
||||
intents_dict,
|
||||
{
|
||||
"intents": {
|
||||
intent_name: {"data": [{"sentences": sentences}]}
|
||||
for intent_name, sentences in self._config_intents.items()
|
||||
}
|
||||
},
|
||||
)
|
||||
intents_changed = True
|
||||
_LOGGER.debug(
|
||||
"Loaded intents from configuration.yaml",
|
||||
)
|
||||
|
||||
if not intents_dict:
|
||||
return None
|
||||
|
||||
|
|
|
@ -380,6 +380,55 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
|
|||
}
|
||||
|
||||
|
||||
async def test_custom_sentences_config(hass, hass_client, hass_admin_user):
|
||||
"""Test custom sentences with a custom intent in config."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"conversation",
|
||||
{"conversation": {"intents": {"StealthMode": ["engage stealth mode"]}}},
|
||||
)
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"intent_script",
|
||||
{
|
||||
"intent_script": {
|
||||
"StealthMode": {"speech": {"text": "Stealth mode engaged"}}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Invoke intent via HTTP API
|
||||
client = await hass_client()
|
||||
resp = await client.post(
|
||||
"/api/conversation/process",
|
||||
json={"text": "engage stealth mode"},
|
||||
)
|
||||
assert resp.status == HTTPStatus.OK
|
||||
data = await resp.json()
|
||||
|
||||
assert data == {
|
||||
"response": {
|
||||
"card": {},
|
||||
"speech": {
|
||||
"plain": {
|
||||
"extra_data": None,
|
||||
"speech": "Stealth mode engaged",
|
||||
}
|
||||
},
|
||||
"language": hass.config.language,
|
||||
"response_type": "action_done",
|
||||
"data": {
|
||||
"targets": [],
|
||||
"success": [],
|
||||
"failed": [],
|
||||
},
|
||||
},
|
||||
"conversation_id": None,
|
||||
}
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
async def test_prepare_reload(hass):
|
||||
"""Test calling the reload service."""
|
||||
|
@ -414,3 +463,27 @@ async def test_prepare_fail(hass):
|
|||
|
||||
# Confirm no intents were loaded
|
||||
assert not agent._lang_intents.get("not-a-language")
|
||||
|
||||
|
||||
async def test_language_region(hass, init_components):
|
||||
"""Test calling the turn on intent."""
|
||||
hass.states.async_set("light.kitchen", "off")
|
||||
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||
|
||||
# Add fake region
|
||||
language = f"{hass.config.language}-YZ"
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{
|
||||
conversation.ATTR_TEXT: "turn on the kitchen",
|
||||
conversation.ATTR_LANGUAGE: language,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == HASS_DOMAIN
|
||||
assert call.service == "turn_on"
|
||||
assert call.data == {"entity_id": "light.kitchen"}
|
||||
|
|
Loading…
Reference in New Issue