Sentence trigger (#94613)

* Add async_register_trigger_sentences for default agent

* Add trigger response and trigger handler

* Check callback in test

* Clean up and move response to callback

* Add trigger test

* Drop TriggerAction

* Test we pass sentence to callback

* Match triggers once, allow multiple sentences

* Don't use trigger id

* Use async callback

* No response for now

* Use asyncio.gather for callback responses

* Fix after rebase

* Use a list for trigger sentences

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/94214/head
Michael Hansen 2023-06-22 18:29:34 -05:00 committed by GitHub
parent 29ef925d73
commit d811fa0e74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 382 additions and 2 deletions

View File

@ -3,8 +3,9 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
import functools
import logging import logging
from pathlib import Path from pathlib import Path
import re import re
@ -42,6 +43,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"] _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
[str], Awaitable[str | None]
]
def json_load(fp: IO[str]) -> JsonObjectType: def json_load(fp: IO[str]) -> JsonObjectType:
@ -60,6 +64,14 @@ class LanguageIntents:
loaded_components: set[str] loaded_components: set[str]
@dataclass(slots=True)
class TriggerData:
"""List of sentences and the callback for a trigger."""
sentences: list[str]
callback: TRIGGER_CALLBACK_TYPE
def _get_language_variations(language: str) -> Iterable[str]: def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region.""" """Generate language codes with and without region."""
yield language yield language
@ -110,6 +122,10 @@ class DefaultAgent(AbstractConversationAgent):
self._config_intents: dict[str, Any] = {} self._config_intents: dict[str, Any] = {}
self._slot_lists: dict[str, SlotList] | None = None self._slot_lists: dict[str, SlotList] | None = None
# Sentences that will trigger a callback (skipping intent recognition)
self._trigger_sentences: list[TriggerData] = []
self._trigger_intents: Intents | None = None
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
@ -174,6 +190,9 @@ class DefaultAgent(AbstractConversationAgent):
async def async_process(self, user_input: ConversationInput) -> ConversationResult: async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
if trigger_result := await self._match_triggers(user_input.text):
return trigger_result
language = user_input.language or self.hass.config.language language = user_input.language or self.hass.config.language
conversation_id = None # Not supported conversation_id = None # Not supported
@ -605,6 +624,99 @@ class DefaultAgent(AbstractConversationAgent):
response_str = lang_intents.error_responses.get(response_key) response_str = lang_intents.error_responses.get(response_key)
return response_str or _DEFAULT_ERROR_TEXT return response_str or _DEFAULT_ERROR_TEXT
def register_trigger(
self,
sentences: list[str],
callback: TRIGGER_CALLBACK_TYPE,
) -> core.CALLBACK_TYPE:
"""Register a list of sentences that will trigger a callback when recognized."""
trigger_data = TriggerData(sentences=sentences, callback=callback)
self._trigger_sentences.append(trigger_data)
# Force rebuild on next use
self._trigger_intents = None
unregister = functools.partial(self._unregister_trigger, trigger_data)
return unregister
def _rebuild_trigger_intents(self) -> None:
"""Rebuild the HassIL intents object from the current trigger sentences."""
intents_dict = {
"language": self.hass.config.language,
"intents": {
# Use trigger data index as a virtual intent name for HassIL.
# This works because the intents are rebuilt on every
# register/unregister.
str(trigger_id): {"data": [{"sentences": trigger_data.sentences}]}
for trigger_id, trigger_data in enumerate(self._trigger_sentences)
},
}
self._trigger_intents = Intents.from_dict(intents_dict)
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
def _unregister_trigger(self, trigger_data: TriggerData) -> None:
"""Unregister a set of trigger sentences."""
self._trigger_sentences.remove(trigger_data)
# Force rebuild on next use
self._trigger_intents = None
async def _match_triggers(self, sentence: str) -> ConversationResult | None:
"""Try to match sentence against registered trigger sentences.
Calls the registered callbacks if there's a match and returns a positive
conversation result.
"""
if not self._trigger_sentences:
# No triggers registered
return None
if self._trigger_intents is None:
# Need to rebuild intents before matching
self._rebuild_trigger_intents()
assert self._trigger_intents is not None
matched_triggers: set[int] = set()
for result in recognize_all(sentence, self._trigger_intents):
trigger_id = int(result.intent.name)
if trigger_id in matched_triggers:
# Already matched a sentence from this trigger
break
matched_triggers.add(trigger_id)
if not matched_triggers:
# Sentence did not match any trigger sentences
return None
_LOGGER.debug(
"'%s' matched %s trigger(s): %s",
sentence,
len(matched_triggers),
matched_triggers,
)
# Gather callback responses in parallel
trigger_responses = await asyncio.gather(
*(
self._trigger_sentences[trigger_id].callback(sentence)
for trigger_id in matched_triggers
)
)
# Use last non-empty result as speech response
speech: str | None = None
for trigger_response in trigger_responses:
speech = speech or trigger_response
response = intent.IntentResponse(language=self.hass.config.language)
response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(speech or "")
return ConversationResult(response=response)
def _make_error_result( def _make_error_result(
language: str, language: str,

View File

@ -0,0 +1,59 @@
"""Offer sentence based automation rules."""
from __future__ import annotations
from typing import Any
import voluptuous as vol
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from . import HOME_ASSISTANT_AGENT, _get_agent_manager
from .const import DOMAIN
from .default_agent import DefaultAgent
TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): DOMAIN,
vol.Required(CONF_COMMAND): vol.All(cv.ensure_list, [cv.string]),
}
)
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Listen for events based on configuration."""
trigger_data = trigger_info["trigger_data"]
sentences = config.get(CONF_COMMAND, [])
job = HassJob(action)
@callback
async def call_action(sentence: str) -> str | None:
"""Call action with right context."""
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
}
# Wait for the automation to complete
if future := hass.async_run_hass_job(
job,
{"trigger": trigger_input},
):
await future
return None
default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action)

View File

@ -1,5 +1,5 @@
"""Test for the default agent.""" """Test for the default agent."""
from unittest.mock import patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -223,3 +223,45 @@ async def test_unexposed_entities_skipped(
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
assert len(result.response.matched_states) == 1 assert len(result.response.matched_states) == 1
assert result.response.matched_states[0].entity_id == exposed_light.entity_id assert result.response.matched_states[0].entity_id == exposed_light.entity_id
async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
"""Test registering/unregistering/matching a few trigger sentences."""
trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!"
agent = await conversation._get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
assert isinstance(agent, conversation.DefaultAgent)
callback = AsyncMock(return_value=trigger_response)
unregister = agent.register_trigger(trigger_sentences, callback)
result = await conversation.async_converse(hass, "Not the trigger", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR
# Using different case and including punctuation
test_sentences = ["it's party time!", "IT IS TIME TO PARTY."]
for sentence in test_sentences:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
callback.assert_called_once_with(sentence)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence
assert result.response.speech == {
"plain": {"speech": trigger_response, "extra_data": None}
}
unregister()
# Should produce errors now
callback.reset_mock()
for sentence in test_sentences:
result = await conversation.async_converse(hass, sentence, None, Context())
assert (
result.response.response_type == intent.IntentResponseType.ERROR
), sentence
assert len(callback.mock_calls) == 0

View File

@ -0,0 +1,167 @@
"""Test conversation triggers."""
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
@pytest.fixture
def calls(hass):
"""Track calls to a mock service."""
return async_mock_service(hass, "test", "automation")
@pytest.fixture(autouse=True)
async def setup_comp(hass):
"""Initialize components."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None:
"""Test the firing of events."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": [
"Hey yo",
"Ha ha ha",
],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
}
},
)
await hass.services.async_call(
"conversation",
"process",
{
"text": "Ha ha ha",
},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"platform": "conversation",
"sentence": "Ha ha ha",
}
async def test_same_trigger_multiple_sentences(
hass: HomeAssistant, calls, setup_comp
) -> None:
"""Test matching of multiple sentences from the same trigger."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["hello", "hello[ world]"],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
}
},
)
await hass.services.async_call(
"conversation",
"process",
{
"text": "hello",
},
blocking=True,
)
# Only triggers once
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"platform": "conversation",
"sentence": "hello",
}
async def test_same_sentence_multiple_triggers(
hass: HomeAssistant, calls, setup_comp
) -> None:
"""Test use of the same sentence in multiple triggers."""
assert await async_setup_component(
hass,
"automation",
{
"automation": [
{
"trigger": {
"id": "trigger1",
"platform": "conversation",
"command": [
"hello",
],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
},
{
"trigger": {
"id": "trigger2",
"platform": "conversation",
"command": [
"hello[ world]",
],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
},
],
},
)
await hass.services.async_call(
"conversation",
"process",
{
"text": "hello",
},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 2
# The calls may come in any order
call_datas: set[tuple[str, str, str]] = set()
for call in calls:
call_data = call.data["data"]
call_datas.add((call_data["id"], call_data["platform"], call_data["sentence"]))
assert call_datas == {
("trigger1", "conversation", "hello"),
("trigger2", "conversation", "hello"),
}