Sentence trigger (#94613)
* Add async_register_trigger_sentences for default agent * Add trigger response and trigger handler * Check callback in test * Clean up and move response to callback * Add trigger test * Drop TriggerAction * Test we pass sentence to callback * Match triggers once, allow multiple sentences * Don't use trigger id * Use async callback * No response for now * Use asyncio.gather for callback responses * Fix after rebase * Use a list for trigger sentences --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/94214/head
parent
29ef925d73
commit
d811fa0e74
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
@ -42,6 +43,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
|||
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
|
||||
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
|
||||
[str], Awaitable[str | None]
|
||||
]
|
||||
|
||||
|
||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||
|
@ -60,6 +64,14 @@ class LanguageIntents:
|
|||
loaded_components: set[str]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TriggerData:
|
||||
"""List of sentences and the callback for a trigger."""
|
||||
|
||||
sentences: list[str]
|
||||
callback: TRIGGER_CALLBACK_TYPE
|
||||
|
||||
|
||||
def _get_language_variations(language: str) -> Iterable[str]:
|
||||
"""Generate language codes with and without region."""
|
||||
yield language
|
||||
|
@ -110,6 +122,10 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
self._config_intents: dict[str, Any] = {}
|
||||
self._slot_lists: dict[str, SlotList] | None = None
|
||||
|
||||
# Sentences that will trigger a callback (skipping intent recognition)
|
||||
self._trigger_sentences: list[TriggerData] = []
|
||||
self._trigger_intents: Intents | None = None
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
@ -174,6 +190,9 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
if trigger_result := await self._match_triggers(user_input.text):
|
||||
return trigger_result
|
||||
|
||||
language = user_input.language or self.hass.config.language
|
||||
conversation_id = None # Not supported
|
||||
|
||||
|
@ -605,6 +624,99 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
response_str = lang_intents.error_responses.get(response_key)
|
||||
return response_str or _DEFAULT_ERROR_TEXT
|
||||
|
||||
def register_trigger(
|
||||
self,
|
||||
sentences: list[str],
|
||||
callback: TRIGGER_CALLBACK_TYPE,
|
||||
) -> core.CALLBACK_TYPE:
|
||||
"""Register a list of sentences that will trigger a callback when recognized."""
|
||||
trigger_data = TriggerData(sentences=sentences, callback=callback)
|
||||
self._trigger_sentences.append(trigger_data)
|
||||
|
||||
# Force rebuild on next use
|
||||
self._trigger_intents = None
|
||||
|
||||
unregister = functools.partial(self._unregister_trigger, trigger_data)
|
||||
return unregister
|
||||
|
||||
def _rebuild_trigger_intents(self) -> None:
|
||||
"""Rebuild the HassIL intents object from the current trigger sentences."""
|
||||
intents_dict = {
|
||||
"language": self.hass.config.language,
|
||||
"intents": {
|
||||
# Use trigger data index as a virtual intent name for HassIL.
|
||||
# This works because the intents are rebuilt on every
|
||||
# register/unregister.
|
||||
str(trigger_id): {"data": [{"sentences": trigger_data.sentences}]}
|
||||
for trigger_id, trigger_data in enumerate(self._trigger_sentences)
|
||||
},
|
||||
}
|
||||
|
||||
self._trigger_intents = Intents.from_dict(intents_dict)
|
||||
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
|
||||
|
||||
def _unregister_trigger(self, trigger_data: TriggerData) -> None:
|
||||
"""Unregister a set of trigger sentences."""
|
||||
self._trigger_sentences.remove(trigger_data)
|
||||
|
||||
# Force rebuild on next use
|
||||
self._trigger_intents = None
|
||||
|
||||
async def _match_triggers(self, sentence: str) -> ConversationResult | None:
|
||||
"""Try to match sentence against registered trigger sentences.
|
||||
|
||||
Calls the registered callbacks if there's a match and returns a positive
|
||||
conversation result.
|
||||
"""
|
||||
if not self._trigger_sentences:
|
||||
# No triggers registered
|
||||
return None
|
||||
|
||||
if self._trigger_intents is None:
|
||||
# Need to rebuild intents before matching
|
||||
self._rebuild_trigger_intents()
|
||||
|
||||
assert self._trigger_intents is not None
|
||||
|
||||
matched_triggers: set[int] = set()
|
||||
for result in recognize_all(sentence, self._trigger_intents):
|
||||
trigger_id = int(result.intent.name)
|
||||
if trigger_id in matched_triggers:
|
||||
# Already matched a sentence from this trigger
|
||||
break
|
||||
|
||||
matched_triggers.add(trigger_id)
|
||||
|
||||
if not matched_triggers:
|
||||
# Sentence did not match any trigger sentences
|
||||
return None
|
||||
|
||||
_LOGGER.debug(
|
||||
"'%s' matched %s trigger(s): %s",
|
||||
sentence,
|
||||
len(matched_triggers),
|
||||
matched_triggers,
|
||||
)
|
||||
|
||||
# Gather callback responses in parallel
|
||||
trigger_responses = await asyncio.gather(
|
||||
*(
|
||||
self._trigger_sentences[trigger_id].callback(sentence)
|
||||
for trigger_id in matched_triggers
|
||||
)
|
||||
)
|
||||
|
||||
# Use last non-empty result as speech response
|
||||
speech: str | None = None
|
||||
for trigger_response in trigger_responses:
|
||||
speech = speech or trigger_response
|
||||
|
||||
response = intent.IntentResponse(language=self.hass.config.language)
|
||||
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||
response.async_set_speech(speech or "")
|
||||
|
||||
return ConversationResult(response=response)
|
||||
|
||||
|
||||
def _make_error_result(
|
||||
language: str,
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
"""Offer sentence based automation rules."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import HOME_ASSISTANT_AGENT, _get_agent_manager
|
||||
from .const import DOMAIN
|
||||
from .default_agent import DefaultAgent
|
||||
|
||||
TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): DOMAIN,
|
||||
vol.Required(CONF_COMMAND): vol.All(cv.ensure_list, [cv.string]),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def async_attach_trigger(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen for events based on configuration."""
|
||||
trigger_data = trigger_info["trigger_data"]
|
||||
sentences = config.get(CONF_COMMAND, [])
|
||||
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
async def call_action(sentence: str) -> str | None:
|
||||
"""Call action with right context."""
|
||||
trigger_input: dict[str, Any] = { # Satisfy type checker
|
||||
**trigger_data,
|
||||
"platform": DOMAIN,
|
||||
"sentence": sentence,
|
||||
}
|
||||
|
||||
# Wait for the automation to complete
|
||||
if future := hass.async_run_hass_job(
|
||||
job,
|
||||
{"trigger": trigger_input},
|
||||
):
|
||||
await future
|
||||
|
||||
return None
|
||||
|
||||
default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
|
||||
assert isinstance(default_agent, DefaultAgent)
|
||||
|
||||
return default_agent.register_trigger(sentences, call_action)
|
|
@ -1,5 +1,5 @@
|
|||
"""Test for the default agent."""
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -223,3 +223,45 @@ async def test_unexposed_entities_skipped(
|
|||
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
|
||||
assert len(result.response.matched_states) == 1
|
||||
assert result.response.matched_states[0].entity_id == exposed_light.entity_id
|
||||
|
||||
|
||||
async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test registering/unregistering/matching a few trigger sentences."""
|
||||
trigger_sentences = ["It's party time", "It is time to party"]
|
||||
trigger_response = "Cowabunga!"
|
||||
|
||||
agent = await conversation._get_agent_manager(hass).async_get_agent(
|
||||
conversation.HOME_ASSISTANT_AGENT
|
||||
)
|
||||
assert isinstance(agent, conversation.DefaultAgent)
|
||||
|
||||
callback = AsyncMock(return_value=trigger_response)
|
||||
unregister = agent.register_trigger(trigger_sentences, callback)
|
||||
|
||||
result = await conversation.async_converse(hass, "Not the trigger", None, Context())
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
|
||||
# Using different case and including punctuation
|
||||
test_sentences = ["it's party time!", "IT IS TIME TO PARTY."]
|
||||
for sentence in test_sentences:
|
||||
callback.reset_mock()
|
||||
result = await conversation.async_converse(hass, sentence, None, Context())
|
||||
callback.assert_called_once_with(sentence)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), sentence
|
||||
assert result.response.speech == {
|
||||
"plain": {"speech": trigger_response, "extra_data": None}
|
||||
}
|
||||
|
||||
unregister()
|
||||
|
||||
# Should produce errors now
|
||||
callback.reset_mock()
|
||||
for sentence in test_sentences:
|
||||
result = await conversation.async_converse(hass, sentence, None, Context())
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ERROR
|
||||
), sentence
|
||||
|
||||
assert len(callback.mock_calls) == 0
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
"""Test conversation triggers."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calls(hass):
|
||||
"""Track calls to a mock service."""
|
||||
return async_mock_service(hass, "test", "automation")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_comp(hass):
|
||||
"""Initialize components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
|
||||
async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None:
|
||||
"""Test the firing of events."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"Hey yo",
|
||||
"Ha ha ha",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"data": "{{ trigger }}"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{
|
||||
"text": "Ha ha ha",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["data"] == {
|
||||
"alias": None,
|
||||
"id": "0",
|
||||
"idx": "0",
|
||||
"platform": "conversation",
|
||||
"sentence": "Ha ha ha",
|
||||
}
|
||||
|
||||
|
||||
async def test_same_trigger_multiple_sentences(
|
||||
hass: HomeAssistant, calls, setup_comp
|
||||
) -> None:
|
||||
"""Test matching of multiple sentences from the same trigger."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": ["hello", "hello[ world]"],
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"data": "{{ trigger }}"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{
|
||||
"text": "hello",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
# Only triggers once
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["data"] == {
|
||||
"alias": None,
|
||||
"id": "0",
|
||||
"idx": "0",
|
||||
"platform": "conversation",
|
||||
"sentence": "hello",
|
||||
}
|
||||
|
||||
|
||||
async def test_same_sentence_multiple_triggers(
|
||||
hass: HomeAssistant, calls, setup_comp
|
||||
) -> None:
|
||||
"""Test use of the same sentence in multiple triggers."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": [
|
||||
{
|
||||
"trigger": {
|
||||
"id": "trigger1",
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"hello",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"data": "{{ trigger }}"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"trigger": {
|
||||
"id": "trigger2",
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"hello[ world]",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"data": "{{ trigger }}"},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{
|
||||
"text": "hello",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
|
||||
# The calls may come in any order
|
||||
call_datas: set[tuple[str, str, str]] = set()
|
||||
for call in calls:
|
||||
call_data = call.data["data"]
|
||||
call_datas.add((call_data["id"], call_data["platform"], call_data["sentence"]))
|
||||
|
||||
assert call_datas == {
|
||||
("trigger1", "conversation", "hello"),
|
||||
("trigger2", "conversation", "hello"),
|
||||
}
|
Loading…
Reference in New Issue