Conversation config (#86326)
* Restore conversation config * Fall back to en for en_US, etc. * Simplify config passing around Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/86336/head
parent
e1483ff746
commit
255611238b
|
@ -27,6 +27,7 @@ DOMAIN = "conversation"
|
||||||
|
|
||||||
REGEX_TYPE = type(re.compile(""))
|
REGEX_TYPE = type(re.compile(""))
|
||||||
DATA_AGENT = "conversation_agent"
|
DATA_AGENT = "conversation_agent"
|
||||||
|
DATA_CONFIG = "conversation_config"
|
||||||
|
|
||||||
SERVICE_PROCESS = "process"
|
SERVICE_PROCESS = "process"
|
||||||
SERVICE_RELOAD = "reload"
|
SERVICE_RELOAD = "reload"
|
||||||
|
@ -45,6 +46,19 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional(DOMAIN): vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional("intents"): vol.Schema(
|
||||||
|
{cv.string: vol.All(cv.ensure_list, [cv.string])}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
@ -55,6 +69,8 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
|
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
||||||
|
hass.data[DATA_CONFIG] = config_intents
|
||||||
|
|
||||||
async def handle_process(service: core.ServiceCall) -> None:
|
async def handle_process(service: core.ServiceCall) -> None:
|
||||||
"""Parse text into commands."""
|
"""Parse text into commands."""
|
||||||
|
@ -210,7 +226,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
||||||
"""Get the active conversation agent."""
|
"""Get the active conversation agent."""
|
||||||
if (agent := hass.data.get(DATA_AGENT)) is None:
|
if (agent := hass.data.get(DATA_AGENT)) is None:
|
||||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||||
await agent.async_initialize()
|
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -35,6 +36,21 @@ class LanguageIntents:
|
||||||
loaded_components: set[str]
|
loaded_components: set[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_language_variations(language: str) -> Iterable[str]:
|
||||||
|
"""Generate language codes with and without region."""
|
||||||
|
yield language
|
||||||
|
|
||||||
|
parts = re.split(r"([-_])", language)
|
||||||
|
if len(parts) == 3:
|
||||||
|
lang, sep, region = parts
|
||||||
|
if sep == "_":
|
||||||
|
# en_US -> en-US
|
||||||
|
yield f"{lang}-{region}"
|
||||||
|
|
||||||
|
# en-US -> en
|
||||||
|
yield lang
|
||||||
|
|
||||||
|
|
||||||
class DefaultAgent(AbstractConversationAgent):
|
class DefaultAgent(AbstractConversationAgent):
|
||||||
"""Default agent for conversation agent."""
|
"""Default agent for conversation agent."""
|
||||||
|
|
||||||
|
@ -44,12 +60,17 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
self._lang_intents: dict[str, LanguageIntents] = {}
|
self._lang_intents: dict[str, LanguageIntents] = {}
|
||||||
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
|
|
||||||
async def async_initialize(self):
|
# intent -> [sentences]
|
||||||
|
self._config_intents: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def async_initialize(self, config_intents):
|
||||||
"""Initialize the default agent."""
|
"""Initialize the default agent."""
|
||||||
if "intent" not in self.hass.config.components:
|
if "intent" not in self.hass.config.components:
|
||||||
await setup.async_setup_component(self.hass, "intent", {})
|
await setup.async_setup_component(self.hass, "intent", {})
|
||||||
|
|
||||||
self.hass.data.setdefault(DOMAIN, {})
|
# Intents from config may only contains sentences for HA config's language
|
||||||
|
if config_intents:
|
||||||
|
self._config_intents = config_intents
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self,
|
self,
|
||||||
|
@ -144,17 +165,20 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
# Don't check component again
|
# Don't check component again
|
||||||
loaded_components.add(component)
|
loaded_components.add(component)
|
||||||
|
|
||||||
# Check for intents for this component with the target language
|
# Check for intents for this component with the target language.
|
||||||
component_intents = get_intents(component, language)
|
# Try en-US, en, etc.
|
||||||
if component_intents:
|
for language_variation in _get_language_variations(language):
|
||||||
# Merge sentences into existing dictionary
|
component_intents = get_intents(component, language_variation)
|
||||||
merge_dict(intents_dict, component_intents)
|
if component_intents:
|
||||||
|
# Merge sentences into existing dictionary
|
||||||
|
merge_dict(intents_dict, component_intents)
|
||||||
|
|
||||||
# Will need to recreate graph
|
# Will need to recreate graph
|
||||||
intents_changed = True
|
intents_changed = True
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Loaded intents component=%s, language=%s", component, language
|
"Loaded intents component=%s, language=%s", component, language
|
||||||
)
|
)
|
||||||
|
break
|
||||||
|
|
||||||
# Check for custom sentences in <config>/custom_sentences/<language>/
|
# Check for custom sentences in <config>/custom_sentences/<language>/
|
||||||
if lang_intents is None:
|
if lang_intents is None:
|
||||||
|
@ -179,6 +203,22 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
custom_sentences_path,
|
custom_sentences_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load sentences from HA config for default language only
|
||||||
|
if self._config_intents and (language == self.hass.config.language):
|
||||||
|
merge_dict(
|
||||||
|
intents_dict,
|
||||||
|
{
|
||||||
|
"intents": {
|
||||||
|
intent_name: {"data": [{"sentences": sentences}]}
|
||||||
|
for intent_name, sentences in self._config_intents.items()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
intents_changed = True
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Loaded intents from configuration.yaml",
|
||||||
|
)
|
||||||
|
|
||||||
if not intents_dict:
|
if not intents_dict:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -380,6 +380,55 @@ async def test_custom_sentences(hass, hass_client, hass_admin_user):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_custom_sentences_config(hass, hass_client, hass_admin_user):
|
||||||
|
"""Test custom sentences with a custom intent in config."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"conversation",
|
||||||
|
{"conversation": {"intents": {"StealthMode": ["engage stealth mode"]}}},
|
||||||
|
)
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"intent_script",
|
||||||
|
{
|
||||||
|
"intent_script": {
|
||||||
|
"StealthMode": {"speech": {"text": "Stealth mode engaged"}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invoke intent via HTTP API
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/conversation/process",
|
||||||
|
json={"text": "engage stealth mode"},
|
||||||
|
)
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
assert data == {
|
||||||
|
"response": {
|
||||||
|
"card": {},
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"extra_data": None,
|
||||||
|
"speech": "Stealth mode engaged",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"language": hass.config.language,
|
||||||
|
"response_type": "action_done",
|
||||||
|
"data": {
|
||||||
|
"targets": [],
|
||||||
|
"success": [],
|
||||||
|
"failed": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
async def test_prepare_reload(hass):
|
async def test_prepare_reload(hass):
|
||||||
"""Test calling the reload service."""
|
"""Test calling the reload service."""
|
||||||
|
@ -414,3 +463,27 @@ async def test_prepare_fail(hass):
|
||||||
|
|
||||||
# Confirm no intents were loaded
|
# Confirm no intents were loaded
|
||||||
assert not agent._lang_intents.get("not-a-language")
|
assert not agent._lang_intents.get("not-a-language")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_language_region(hass, init_components):
|
||||||
|
"""Test calling the turn on intent."""
|
||||||
|
hass.states.async_set("light.kitchen", "off")
|
||||||
|
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||||
|
|
||||||
|
# Add fake region
|
||||||
|
language = f"{hass.config.language}-YZ"
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{
|
||||||
|
conversation.ATTR_TEXT: "turn on the kitchen",
|
||||||
|
conversation.ATTR_LANGUAGE: language,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == HASS_DOMAIN
|
||||||
|
assert call.service == "turn_on"
|
||||||
|
assert call.data == {"entity_id": "light.kitchen"}
|
||||||
|
|
Loading…
Reference in New Issue