Filter out certain intents from being matched in local fallback (#137763)

* Filter out certain intents from being matched in local fallback

* Only filter if LLM agent can control HA
pull/138887/head
Paulus Schoutsen 2025-02-19 15:27:42 -05:00 committed by GitHub
parent b2e2ef3119
commit 0b6f49fec2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 151 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from pathlib import Path
from queue import Empty, Queue
from threading import Thread
import time
from typing import Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast
import wave
import hass_nabucasa
@ -30,7 +30,7 @@ from homeassistant.components import (
from homeassistant.components.tts import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.const import MATCH_ALL
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent
@ -81,6 +81,9 @@ from .error import (
)
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
if TYPE_CHECKING:
from hassil.recognize import RecognizeResult
_LOGGER = logging.getLogger(__name__)
STORAGE_KEY = f"{DOMAIN}.pipelines"
@ -123,6 +126,12 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
@callback
def _async_local_fallback_intent_filter(result: RecognizeResult) -> bool:
"""Filter out intents that are not local fallback."""
return result.intent.name in (intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND)
@callback
def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
@ -1084,10 +1093,22 @@ class PipelineRun:
)
intent_response.async_set_speech(trigger_response_text)
intent_filter: Callable[[RecognizeResult], bool] | None = None
# If the LLM has API access, we filter out some sentences that are
# interfering with LLM operation.
if (
intent_agent_state := self.hass.states.get(self.intent_agent)
) and intent_agent_state.attributes.get(
ATTR_SUPPORTED_FEATURES, 0
) & conversation.ConversationEntityFeature.CONTROL:
intent_filter = _async_local_fallback_intent_filter
# Try local intents first, if preferred.
elif self.pipeline.prefer_local_intents and (
intent_response := await conversation.async_handle_intents(
self.hass, user_input
self.hass,
user_input,
intent_filter=intent_filter,
)
):
# Local intent matched

View File

@ -2,10 +2,12 @@
from __future__ import annotations
from collections.abc import Callable
import logging
import re
from typing import Literal
from hassil.recognize import RecognizeResult
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
@ -241,7 +243,10 @@ async def async_handle_sentence_triggers(
async def async_handle_intents(
hass: HomeAssistant, user_input: ConversationInput
hass: HomeAssistant,
user_input: ConversationInput,
*,
intent_filter: Callable[[RecognizeResult], bool] | None = None,
) -> intent.IntentResponse | None:
"""Try to match input against registered intents and return response.
@ -250,7 +255,9 @@ async def async_handle_intents(
default_agent = async_get_agent(hass)
assert isinstance(default_agent, DefaultAgent)
return await default_agent.async_handle_intents(user_input)
return await default_agent.async_handle_intents(
user_input, intent_filter=intent_filter
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -1324,6 +1324,8 @@ class DefaultAgent(ConversationEntity):
async def async_handle_intents(
self,
user_input: ConversationInput,
*,
intent_filter: Callable[[RecognizeResult], bool] | None = None,
) -> intent.IntentResponse | None:
"""Try to match sentence against registered intents and return response.
@ -1331,7 +1333,9 @@ class DefaultAgent(ConversationEntity):
Returns None if no match or a matching error occurred.
"""
result = await self.async_recognize_intent(user_input, strict_intents_only=True)
if not isinstance(result, RecognizeResult):
if not isinstance(result, RecognizeResult) or (
intent_filter is not None and intent_filter(result)
):
# No error message on failed match
return None

View File

@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import ANY, patch
from hassil.recognize import Intent, IntentData, RecognizeResult
import pytest
from homeassistant.components import conversation
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
PipelineStore,
_async_local_fallback_intent_filter,
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
@ -23,6 +25,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
@ -657,3 +660,40 @@ async def test_migrate_after_load(hass: HomeAssistant) -> None:
assert pipeline_updated.stt_engine == "stt.test"
assert pipeline_updated.tts_engine == "tts.test"
def test_fallback_intent_filter() -> None:
"""Test that we filter the right things."""
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_GET_STATE),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is True
)
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_NEVERMIND),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is True
)
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_TURN_ON),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is False
)

View File

@ -3154,6 +3154,79 @@ async def test_handle_intents_with_response_errors(
assert response is None
@pytest.mark.usefixtures("init_components")
async def test_handle_intents_filters_results(
hass: HomeAssistant,
init_components: None,
area_registry: ar.AreaRegistry,
) -> None:
"""Test that handle_intents can filter responses."""
assert await async_setup_component(hass, "climate", {})
area_registry.async_create("living room")
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
user_input = ConversationInput(
text="What is the temperature in the living room?",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
mock_result = RecognizeResult(
intent=Intent("HassTurnOn"),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
results = []
def _filter_intents(result):
results.append(result)
# We filter first, not 2nd.
return len(results) == 1
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent",
return_value=mock_result,
) as mock_recognize,
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent._async_process_intent_result",
) as mock_process,
):
response = await agent.async_handle_intents(
user_input, intent_filter=_filter_intents
)
assert len(mock_recognize.mock_calls) == 1
assert len(mock_process.mock_calls) == 0
# It was ignored
assert response is None
# Check we filtered things
assert len(results) == 1
assert results[0] is mock_result
# Second time it is not filtered
response = await agent.async_handle_intents(
user_input, intent_filter=_filter_intents
)
assert len(mock_recognize.mock_calls) == 2
assert len(mock_process.mock_calls) == 2
# Check we filtered things
assert len(results) == 2
assert results[1] is mock_result
# It was ignored
assert response is not None
@pytest.mark.usefixtures("init_components")
async def test_state_names_are_not_translated(
hass: HomeAssistant,