Sentence trigger (#94613)
* Add async_register_trigger_sentences for default agent * Add trigger response and trigger handler * Check callback in test * Clean up and move response to callback * Add trigger test * Drop TriggerAction * Test we pass sentence to callback * Match triggers once, allow multiple sentences * Don't use trigger id * Use async callback * No response for now * Use asyncio.gather for callback responses * Fix after rebase * Use a list for trigger sentences --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/94214/head
parent
29ef925d73
commit
d811fa0e74
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Awaitable, Callable, Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
@ -42,6 +43,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
||||||
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
|
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
|
||||||
|
|
||||||
REGEX_TYPE = type(re.compile(""))
|
REGEX_TYPE = type(re.compile(""))
|
||||||
|
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
|
||||||
|
[str], Awaitable[str | None]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||||
|
@ -60,6 +64,14 @@ class LanguageIntents:
|
||||||
loaded_components: set[str]
|
loaded_components: set[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class TriggerData:
|
||||||
|
"""List of sentences and the callback for a trigger."""
|
||||||
|
|
||||||
|
sentences: list[str]
|
||||||
|
callback: TRIGGER_CALLBACK_TYPE
|
||||||
|
|
||||||
|
|
||||||
def _get_language_variations(language: str) -> Iterable[str]:
|
def _get_language_variations(language: str) -> Iterable[str]:
|
||||||
"""Generate language codes with and without region."""
|
"""Generate language codes with and without region."""
|
||||||
yield language
|
yield language
|
||||||
|
@ -110,6 +122,10 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
self._config_intents: dict[str, Any] = {}
|
self._config_intents: dict[str, Any] = {}
|
||||||
self._slot_lists: dict[str, SlotList] | None = None
|
self._slot_lists: dict[str, SlotList] | None = None
|
||||||
|
|
||||||
|
# Sentences that will trigger a callback (skipping intent recognition)
|
||||||
|
self._trigger_sentences: list[TriggerData] = []
|
||||||
|
self._trigger_intents: Intents | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
@ -174,6 +190,9 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
|
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
if trigger_result := await self._match_triggers(user_input.text):
|
||||||
|
return trigger_result
|
||||||
|
|
||||||
language = user_input.language or self.hass.config.language
|
language = user_input.language or self.hass.config.language
|
||||||
conversation_id = None # Not supported
|
conversation_id = None # Not supported
|
||||||
|
|
||||||
|
@ -605,6 +624,99 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
response_str = lang_intents.error_responses.get(response_key)
|
response_str = lang_intents.error_responses.get(response_key)
|
||||||
return response_str or _DEFAULT_ERROR_TEXT
|
return response_str or _DEFAULT_ERROR_TEXT
|
||||||
|
|
||||||
|
def register_trigger(
|
||||||
|
self,
|
||||||
|
sentences: list[str],
|
||||||
|
callback: TRIGGER_CALLBACK_TYPE,
|
||||||
|
) -> core.CALLBACK_TYPE:
|
||||||
|
"""Register a list of sentences that will trigger a callback when recognized."""
|
||||||
|
trigger_data = TriggerData(sentences=sentences, callback=callback)
|
||||||
|
self._trigger_sentences.append(trigger_data)
|
||||||
|
|
||||||
|
# Force rebuild on next use
|
||||||
|
self._trigger_intents = None
|
||||||
|
|
||||||
|
unregister = functools.partial(self._unregister_trigger, trigger_data)
|
||||||
|
return unregister
|
||||||
|
|
||||||
|
def _rebuild_trigger_intents(self) -> None:
|
||||||
|
"""Rebuild the HassIL intents object from the current trigger sentences."""
|
||||||
|
intents_dict = {
|
||||||
|
"language": self.hass.config.language,
|
||||||
|
"intents": {
|
||||||
|
# Use trigger data index as a virtual intent name for HassIL.
|
||||||
|
# This works because the intents are rebuilt on every
|
||||||
|
# register/unregister.
|
||||||
|
str(trigger_id): {"data": [{"sentences": trigger_data.sentences}]}
|
||||||
|
for trigger_id, trigger_data in enumerate(self._trigger_sentences)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
self._trigger_intents = Intents.from_dict(intents_dict)
|
||||||
|
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
|
||||||
|
|
||||||
|
def _unregister_trigger(self, trigger_data: TriggerData) -> None:
|
||||||
|
"""Unregister a set of trigger sentences."""
|
||||||
|
self._trigger_sentences.remove(trigger_data)
|
||||||
|
|
||||||
|
# Force rebuild on next use
|
||||||
|
self._trigger_intents = None
|
||||||
|
|
||||||
|
async def _match_triggers(self, sentence: str) -> ConversationResult | None:
|
||||||
|
"""Try to match sentence against registered trigger sentences.
|
||||||
|
|
||||||
|
Calls the registered callbacks if there's a match and returns a positive
|
||||||
|
conversation result.
|
||||||
|
"""
|
||||||
|
if not self._trigger_sentences:
|
||||||
|
# No triggers registered
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._trigger_intents is None:
|
||||||
|
# Need to rebuild intents before matching
|
||||||
|
self._rebuild_trigger_intents()
|
||||||
|
|
||||||
|
assert self._trigger_intents is not None
|
||||||
|
|
||||||
|
matched_triggers: set[int] = set()
|
||||||
|
for result in recognize_all(sentence, self._trigger_intents):
|
||||||
|
trigger_id = int(result.intent.name)
|
||||||
|
if trigger_id in matched_triggers:
|
||||||
|
# Already matched a sentence from this trigger
|
||||||
|
break
|
||||||
|
|
||||||
|
matched_triggers.add(trigger_id)
|
||||||
|
|
||||||
|
if not matched_triggers:
|
||||||
|
# Sentence did not match any trigger sentences
|
||||||
|
return None
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"'%s' matched %s trigger(s): %s",
|
||||||
|
sentence,
|
||||||
|
len(matched_triggers),
|
||||||
|
matched_triggers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gather callback responses in parallel
|
||||||
|
trigger_responses = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
self._trigger_sentences[trigger_id].callback(sentence)
|
||||||
|
for trigger_id in matched_triggers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use last non-empty result as speech response
|
||||||
|
speech: str | None = None
|
||||||
|
for trigger_response in trigger_responses:
|
||||||
|
speech = speech or trigger_response
|
||||||
|
|
||||||
|
response = intent.IntentResponse(language=self.hass.config.language)
|
||||||
|
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||||
|
response.async_set_speech(speech or "")
|
||||||
|
|
||||||
|
return ConversationResult(response=response)
|
||||||
|
|
||||||
|
|
||||||
def _make_error_result(
|
def _make_error_result(
|
||||||
language: str,
|
language: str,
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
"""Offer sentence based automation rules."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
|
||||||
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||||
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
|
from . import HOME_ASSISTANT_AGENT, _get_agent_manager
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .default_agent import DefaultAgent
|
||||||
|
|
||||||
|
TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||||
|
{
|
||||||
|
vol.Required(CONF_PLATFORM): DOMAIN,
|
||||||
|
vol.Required(CONF_COMMAND): vol.All(cv.ensure_list, [cv.string]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_attach_trigger(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
action: TriggerActionType,
|
||||||
|
trigger_info: TriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Listen for events based on configuration."""
|
||||||
|
trigger_data = trigger_info["trigger_data"]
|
||||||
|
sentences = config.get(CONF_COMMAND, [])
|
||||||
|
|
||||||
|
job = HassJob(action)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
async def call_action(sentence: str) -> str | None:
|
||||||
|
"""Call action with right context."""
|
||||||
|
trigger_input: dict[str, Any] = { # Satisfy type checker
|
||||||
|
**trigger_data,
|
||||||
|
"platform": DOMAIN,
|
||||||
|
"sentence": sentence,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Wait for the automation to complete
|
||||||
|
if future := hass.async_run_hass_job(
|
||||||
|
job,
|
||||||
|
{"trigger": trigger_input},
|
||||||
|
):
|
||||||
|
await future
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
|
||||||
|
assert isinstance(default_agent, DefaultAgent)
|
||||||
|
|
||||||
|
return default_agent.register_trigger(sentences, call_action)
|
|
@ -1,5 +1,5 @@
|
||||||
"""Test for the default agent."""
|
"""Test for the default agent."""
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -223,3 +223,45 @@ async def test_unexposed_entities_skipped(
|
||||||
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
||||||
assert len(result.response.matched_states) == 1
|
assert len(result.response.matched_states) == 1
|
||||||
assert result.response.matched_states[0].entity_id == exposed_light.entity_id
|
assert result.response.matched_states[0].entity_id == exposed_light.entity_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
|
||||||
|
"""Test registering/unregistering/matching a few trigger sentences."""
|
||||||
|
trigger_sentences = ["It's party time", "It is time to party"]
|
||||||
|
trigger_response = "Cowabunga!"
|
||||||
|
|
||||||
|
agent = await conversation._get_agent_manager(hass).async_get_agent(
|
||||||
|
conversation.HOME_ASSISTANT_AGENT
|
||||||
|
)
|
||||||
|
assert isinstance(agent, conversation.DefaultAgent)
|
||||||
|
|
||||||
|
callback = AsyncMock(return_value=trigger_response)
|
||||||
|
unregister = agent.register_trigger(trigger_sentences, callback)
|
||||||
|
|
||||||
|
result = await conversation.async_converse(hass, "Not the trigger", None, Context())
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||||
|
|
||||||
|
# Using different case and including punctuation
|
||||||
|
test_sentences = ["it's party time!", "IT IS TIME TO PARTY."]
|
||||||
|
for sentence in test_sentences:
|
||||||
|
callback.reset_mock()
|
||||||
|
result = await conversation.async_converse(hass, sentence, None, Context())
|
||||||
|
callback.assert_called_once_with(sentence)
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), sentence
|
||||||
|
assert result.response.speech == {
|
||||||
|
"plain": {"speech": trigger_response, "extra_data": None}
|
||||||
|
}
|
||||||
|
|
||||||
|
unregister()
|
||||||
|
|
||||||
|
# Should produce errors now
|
||||||
|
callback.reset_mock()
|
||||||
|
for sentence in test_sentences:
|
||||||
|
result = await conversation.async_converse(hass, sentence, None, Context())
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ERROR
|
||||||
|
), sentence
|
||||||
|
|
||||||
|
assert len(callback.mock_calls) == 0
|
||||||
|
|
|
@ -0,0 +1,167 @@
|
||||||
|
"""Test conversation triggers."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import async_mock_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def calls(hass):
|
||||||
|
"""Track calls to a mock service."""
|
||||||
|
return async_mock_service(hass, "test", "automation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def setup_comp(hass):
|
||||||
|
"""Initialize components."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None:
|
||||||
|
"""Test the firing of events."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": [
|
||||||
|
"Hey yo",
|
||||||
|
"Ha ha ha",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {"data": "{{ trigger }}"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{
|
||||||
|
"text": "Ha ha ha",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data["data"] == {
|
||||||
|
"alias": None,
|
||||||
|
"id": "0",
|
||||||
|
"idx": "0",
|
||||||
|
"platform": "conversation",
|
||||||
|
"sentence": "Ha ha ha",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_same_trigger_multiple_sentences(
|
||||||
|
hass: HomeAssistant, calls, setup_comp
|
||||||
|
) -> None:
|
||||||
|
"""Test matching of multiple sentences from the same trigger."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": ["hello", "hello[ world]"],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {"data": "{{ trigger }}"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only triggers once
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data["data"] == {
|
||||||
|
"alias": None,
|
||||||
|
"id": "0",
|
||||||
|
"idx": "0",
|
||||||
|
"platform": "conversation",
|
||||||
|
"sentence": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_same_sentence_multiple_triggers(
|
||||||
|
hass: HomeAssistant, calls, setup_comp
|
||||||
|
) -> None:
|
||||||
|
"""Test use of the same sentence in multiple triggers."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": [
|
||||||
|
{
|
||||||
|
"trigger": {
|
||||||
|
"id": "trigger1",
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": [
|
||||||
|
"hello",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {"data": "{{ trigger }}"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"trigger": {
|
||||||
|
"id": "trigger2",
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": [
|
||||||
|
"hello[ world]",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {"data": "{{ trigger }}"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
# The calls may come in any order
|
||||||
|
call_datas: set[tuple[str, str, str]] = set()
|
||||||
|
for call in calls:
|
||||||
|
call_data = call.data["data"]
|
||||||
|
call_datas.add((call_data["id"], call_data["platform"], call_data["sentence"]))
|
||||||
|
|
||||||
|
assert call_datas == {
|
||||||
|
("trigger1", "conversation", "hello"),
|
||||||
|
("trigger2", "conversation", "hello"),
|
||||||
|
}
|
Loading…
Reference in New Issue