diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index b0bbc8e7fec..336d6287f18 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -3,8 +3,9 @@ from __future__ import annotations import asyncio from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Awaitable, Callable, Iterable from dataclasses import dataclass +import functools import logging from pathlib import Path import re @@ -42,6 +43,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"] REGEX_TYPE = type(re.compile("")) +TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name + [str], Awaitable[str | None] +] def json_load(fp: IO[str]) -> JsonObjectType: @@ -60,6 +64,14 @@ class LanguageIntents: loaded_components: set[str] +@dataclass(slots=True) +class TriggerData: + """List of sentences and the callback for a trigger.""" + + sentences: list[str] + callback: TRIGGER_CALLBACK_TYPE + + def _get_language_variations(language: str) -> Iterable[str]: """Generate language codes with and without region.""" yield language @@ -110,6 +122,10 @@ class DefaultAgent(AbstractConversationAgent): self._config_intents: dict[str, Any] = {} self._slot_lists: dict[str, SlotList] | None = None + # Sentences that will trigger a callback (skipping intent recognition) + self._trigger_sentences: list[TriggerData] = [] + self._trigger_intents: Intents | None = None + @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" @@ -174,6 +190,9 @@ class DefaultAgent(AbstractConversationAgent): async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" + if trigger_result := await self._match_triggers(user_input.text): + return trigger_result + language = user_input.language or self.hass.config.language conversation_id = None # Not supported @@ -605,6 +624,99 @@ class DefaultAgent(AbstractConversationAgent): response_str = lang_intents.error_responses.get(response_key) return response_str or _DEFAULT_ERROR_TEXT + def register_trigger( + self, + sentences: list[str], + callback: TRIGGER_CALLBACK_TYPE, + ) -> core.CALLBACK_TYPE: + """Register a list of sentences that will trigger a callback when recognized.""" + trigger_data = TriggerData(sentences=sentences, callback=callback) + self._trigger_sentences.append(trigger_data) + + # Force rebuild on next use + self._trigger_intents = None + + unregister = functools.partial(self._unregister_trigger, trigger_data) + return unregister + + def _rebuild_trigger_intents(self) -> None: + """Rebuild the HassIL intents object from the current trigger sentences.""" + intents_dict = { + "language": self.hass.config.language, + "intents": { + # Use trigger data index as a virtual intent name for HassIL. + # This works because the intents are rebuilt on every + # register/unregister. + str(trigger_id): {"data": [{"sentences": trigger_data.sentences}]} + for trigger_id, trigger_data in enumerate(self._trigger_sentences) + }, + } + + self._trigger_intents = Intents.from_dict(intents_dict) + _LOGGER.debug("Rebuilt trigger intents: %s", intents_dict) + + def _unregister_trigger(self, trigger_data: TriggerData) -> None: + """Unregister a set of trigger sentences.""" + self._trigger_sentences.remove(trigger_data) + + # Force rebuild on next use + self._trigger_intents = None + + async def _match_triggers(self, sentence: str) -> ConversationResult | None: + """Try to match sentence against registered trigger sentences. + + Calls the registered callbacks if there's a match and returns a positive + conversation result. + """ + if not self._trigger_sentences: + # No triggers registered + return None + + if self._trigger_intents is None: + # Need to rebuild intents before matching + self._rebuild_trigger_intents() + + assert self._trigger_intents is not None + + matched_triggers: set[int] = set() + for result in recognize_all(sentence, self._trigger_intents): + trigger_id = int(result.intent.name) + if trigger_id in matched_triggers: + # Already matched a sentence from this trigger + break + + matched_triggers.add(trigger_id) + + if not matched_triggers: + # Sentence did not match any trigger sentences + return None + + _LOGGER.debug( + "'%s' matched %s trigger(s): %s", + sentence, + len(matched_triggers), + matched_triggers, + ) + + # Gather callback responses in parallel + trigger_responses = await asyncio.gather( + *( + self._trigger_sentences[trigger_id].callback(sentence) + for trigger_id in matched_triggers + ) + ) + + # Use last non-empty result as speech response + speech: str | None = None + for trigger_response in trigger_responses: + speech = speech or trigger_response + + response = intent.IntentResponse(language=self.hass.config.language) + response.response_type = intent.IntentResponseType.ACTION_DONE + response.async_set_speech(speech or "") + + return ConversationResult(response=response) + def _make_error_result( language: str, diff --git a/homeassistant/components/conversation/trigger.py b/homeassistant/components/conversation/trigger.py new file mode 100644 index 00000000000..c12808efa53 --- /dev/null +++ b/homeassistant/components/conversation/trigger.py @@ -0,0 +1,59 @@ +"""Offer sentence based automation rules.""" +from __future__ import annotations + +from typing import Any + +import voluptuous as vol + +from homeassistant.const import CONF_COMMAND, CONF_PLATFORM +from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo +from homeassistant.helpers.typing import ConfigType + +from . import HOME_ASSISTANT_AGENT, _get_agent_manager +from .const import DOMAIN +from .default_agent import DefaultAgent + +TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( + { + vol.Required(CONF_PLATFORM): DOMAIN, + vol.Required(CONF_COMMAND): vol.All(cv.ensure_list, [cv.string]), + } +) + + +async def async_attach_trigger( + hass: HomeAssistant, + config: ConfigType, + action: TriggerActionType, + trigger_info: TriggerInfo, +) -> CALLBACK_TYPE: + """Listen for events based on configuration.""" + trigger_data = trigger_info["trigger_data"] + sentences = config.get(CONF_COMMAND, []) + + job = HassJob(action) + + @callback + async def call_action(sentence: str) -> str | None: + """Call action with right context.""" + trigger_input: dict[str, Any] = { # Satisfy type checker + **trigger_data, + "platform": DOMAIN, + "sentence": sentence, + } + + # Wait for the automation to complete + if future := hass.async_run_hass_job( + job, + {"trigger": trigger_input}, + ): + await future + + return None + + default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) + assert isinstance(default_agent, DefaultAgent) + + return default_agent.register_trigger(sentences, call_action) diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 58fe9371e11..899fd761d5e 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -1,5 +1,5 @@ """Test for the default agent.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -223,3 +223,45 @@ async def test_unexposed_entities_skipped( assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER assert len(result.response.matched_states) == 1 assert result.response.matched_states[0].entity_id == exposed_light.entity_id + + +async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None: + """Test registering/unregistering/matching a few trigger sentences.""" + trigger_sentences = ["It's party time", "It is time to party"] + trigger_response = "Cowabunga!" + + agent = await conversation._get_agent_manager(hass).async_get_agent( + conversation.HOME_ASSISTANT_AGENT + ) + assert isinstance(agent, conversation.DefaultAgent) + + callback = AsyncMock(return_value=trigger_response) + unregister = agent.register_trigger(trigger_sentences, callback) + + result = await conversation.async_converse(hass, "Not the trigger", None, Context()) + assert result.response.response_type == intent.IntentResponseType.ERROR + + # Using different case and including punctuation + test_sentences = ["it's party time!", "IT IS TIME TO PARTY."] + for sentence in test_sentences: + callback.reset_mock() + result = await conversation.async_converse(hass, sentence, None, Context()) + callback.assert_called_once_with(sentence) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), sentence + assert result.response.speech == { + "plain": {"speech": trigger_response, "extra_data": None} + } + + unregister() + + # Should produce errors now + callback.reset_mock() + for sentence in test_sentences: + result = await conversation.async_converse(hass, sentence, None, Context()) + assert ( + result.response.response_type == intent.IntentResponseType.ERROR + ), sentence + + assert len(callback.mock_calls) == 0 diff --git a/tests/components/conversation/test_trigger.py b/tests/components/conversation/test_trigger.py new file mode 100644 index 00000000000..74a5e4df8e2 --- /dev/null +++ b/tests/components/conversation/test_trigger.py @@ -0,0 +1,167 @@ +"""Test conversation triggers.""" +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + +from tests.common import async_mock_service + + +@pytest.fixture +def calls(hass): + """Track calls to a mock service.""" + return async_mock_service(hass, "test", "automation") + + +@pytest.fixture(autouse=True) +async def setup_comp(hass): + """Initialize components.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "conversation", {}) + + +async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None: + """Test the firing of events.""" + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "conversation", + "command": [ + "Hey yo", + "Ha ha ha", + ], + }, + "action": { + "service": "test.automation", + "data_template": {"data": "{{ trigger }}"}, + }, + } + }, + ) + + await hass.services.async_call( + "conversation", + "process", + { + "text": "Ha ha ha", + }, + blocking=True, + ) + + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["data"] == { + "alias": None, + "id": "0", + "idx": "0", + "platform": "conversation", + "sentence": "Ha ha ha", + } + + +async def test_same_trigger_multiple_sentences( + hass: HomeAssistant, calls, setup_comp +) -> None: + """Test matching of multiple sentences from the same trigger.""" + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "conversation", + "command": ["hello", "hello[ world]"], + }, + "action": { + "service": "test.automation", + "data_template": {"data": "{{ trigger }}"}, + }, + } + }, + ) + + await hass.services.async_call( + "conversation", + "process", + { + "text": "hello", + }, + blocking=True, + ) + + # Only triggers once + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["data"] == { + "alias": None, + "id": "0", + "idx": "0", + "platform": "conversation", + "sentence": "hello", + } + + +async def test_same_sentence_multiple_triggers( + hass: HomeAssistant, calls, setup_comp +) -> None: + """Test use of the same sentence in multiple triggers.""" + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + { + "trigger": { + "id": "trigger1", + "platform": "conversation", + "command": [ + "hello", + ], + }, + "action": { + "service": "test.automation", + "data_template": {"data": "{{ trigger }}"}, + }, + }, + { + "trigger": { + "id": "trigger2", + "platform": "conversation", + "command": [ + "hello[ world]", + ], + }, + "action": { + "service": "test.automation", + "data_template": {"data": "{{ trigger }}"}, + }, + }, + ], + }, + ) + + await hass.services.async_call( + "conversation", + "process", + { + "text": "hello", + }, + blocking=True, + ) + + await hass.async_block_till_done() + assert len(calls) == 2 + + # The calls may come in any order + call_datas: set[tuple[str, str, str]] = set() + for call in calls: + call_data = call.data["data"] + call_datas.add((call_data["id"], call_data["platform"], call_data["sentence"])) + + assert call_datas == { + ("trigger1", "conversation", "hello"), + ("trigger2", "conversation", "hello"), + }