Filter out certain intents from being matched in local fallback (#137763)
* Filter out certain intents from being matched in local fallback * Only filter if LLM agent can control HApull/138887/head
parent
b2e2ef3119
commit
0b6f49fec2
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
from queue import Empty, Queue
|
||||
from threading import Thread
|
||||
import time
|
||||
from typing import Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
import wave
|
||||
|
||||
import hass_nabucasa
|
||||
|
@ -30,7 +30,7 @@ from homeassistant.components import (
|
|||
from homeassistant.components.tts import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
|
@ -81,6 +81,9 @@ from .error import (
|
|||
)
|
||||
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hassil.recognize import RecognizeResult
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
||||
|
@ -123,6 +126,12 @@ STORED_PIPELINE_RUNS = 10
|
|||
SAVE_DELAY = 10
|
||||
|
||||
|
||||
@callback
|
||||
def _async_local_fallback_intent_filter(result: RecognizeResult) -> bool:
|
||||
"""Filter out intents that are not local fallback."""
|
||||
return result.intent.name in (intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_resolve_default_pipeline_settings(
|
||||
hass: HomeAssistant,
|
||||
|
@ -1084,10 +1093,22 @@ class PipelineRun:
|
|||
)
|
||||
intent_response.async_set_speech(trigger_response_text)
|
||||
|
||||
intent_filter: Callable[[RecognizeResult], bool] | None = None
|
||||
# If the LLM has API access, we filter out some sentences that are
|
||||
# interfering with LLM operation.
|
||||
if (
|
||||
intent_agent_state := self.hass.states.get(self.intent_agent)
|
||||
) and intent_agent_state.attributes.get(
|
||||
ATTR_SUPPORTED_FEATURES, 0
|
||||
) & conversation.ConversationEntityFeature.CONTROL:
|
||||
intent_filter = _async_local_fallback_intent_filter
|
||||
|
||||
# Try local intents first, if preferred.
|
||||
elif self.pipeline.prefer_local_intents and (
|
||||
intent_response := await conversation.async_handle_intents(
|
||||
self.hass, user_input
|
||||
self.hass,
|
||||
user_input,
|
||||
intent_filter=intent_filter,
|
||||
)
|
||||
):
|
||||
# Local intent matched
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from hassil.recognize import RecognizeResult
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
@ -241,7 +243,10 @@ async def async_handle_sentence_triggers(
|
|||
|
||||
|
||||
async def async_handle_intents(
|
||||
hass: HomeAssistant, user_input: ConversationInput
|
||||
hass: HomeAssistant,
|
||||
user_input: ConversationInput,
|
||||
*,
|
||||
intent_filter: Callable[[RecognizeResult], bool] | None = None,
|
||||
) -> intent.IntentResponse | None:
|
||||
"""Try to match input against registered intents and return response.
|
||||
|
||||
|
@ -250,7 +255,9 @@ async def async_handle_intents(
|
|||
default_agent = async_get_agent(hass)
|
||||
assert isinstance(default_agent, DefaultAgent)
|
||||
|
||||
return await default_agent.async_handle_intents(user_input)
|
||||
return await default_agent.async_handle_intents(
|
||||
user_input, intent_filter=intent_filter
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
|
|
@ -1324,6 +1324,8 @@ class DefaultAgent(ConversationEntity):
|
|||
async def async_handle_intents(
|
||||
self,
|
||||
user_input: ConversationInput,
|
||||
*,
|
||||
intent_filter: Callable[[RecognizeResult], bool] | None = None,
|
||||
) -> intent.IntentResponse | None:
|
||||
"""Try to match sentence against registered intents and return response.
|
||||
|
||||
|
@ -1331,7 +1333,9 @@ class DefaultAgent(ConversationEntity):
|
|||
Returns None if no match or a matching error occurred.
|
||||
"""
|
||||
result = await self.async_recognize_intent(user_input, strict_intents_only=True)
|
||||
if not isinstance(result, RecognizeResult):
|
||||
if not isinstance(result, RecognizeResult) or (
|
||||
intent_filter is not None and intent_filter(result)
|
||||
):
|
||||
# No error message on failed match
|
||||
return None
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
|||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
|
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
PipelineStore,
|
||||
_async_local_fallback_intent_filter,
|
||||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
|
@ -23,6 +25,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MANY_LANGUAGES
|
||||
|
@ -657,3 +660,40 @@ async def test_migrate_after_load(hass: HomeAssistant) -> None:
|
|||
|
||||
assert pipeline_updated.stt_engine == "stt.test"
|
||||
assert pipeline_updated.tts_engine == "tts.test"
|
||||
|
||||
|
||||
def test_fallback_intent_filter() -> None:
|
||||
"""Test that we filter the right things."""
|
||||
assert (
|
||||
_async_local_fallback_intent_filter(
|
||||
RecognizeResult(
|
||||
intent=Intent(intent.INTENT_GET_STATE),
|
||||
intent_data=IntentData([]),
|
||||
entities={},
|
||||
entities_list=[],
|
||||
)
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
_async_local_fallback_intent_filter(
|
||||
RecognizeResult(
|
||||
intent=Intent(intent.INTENT_NEVERMIND),
|
||||
intent_data=IntentData([]),
|
||||
entities={},
|
||||
entities_list=[],
|
||||
)
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
_async_local_fallback_intent_filter(
|
||||
RecognizeResult(
|
||||
intent=Intent(intent.INTENT_TURN_ON),
|
||||
intent_data=IntentData([]),
|
||||
entities={},
|
||||
entities_list=[],
|
||||
)
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
|
|
@ -3154,6 +3154,79 @@ async def test_handle_intents_with_response_errors(
|
|||
assert response is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
|
||||
async def test_handle_intents_filters_results(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
area_registry: ar.AreaRegistry,
|
||||
) -> None:
|
||||
"""Test that handle_intents can filter responses."""
|
||||
assert await async_setup_component(hass, "climate", {})
|
||||
area_registry.async_create("living room")
|
||||
|
||||
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
|
||||
|
||||
user_input = ConversationInput(
|
||||
text="What is the temperature in the living room?",
|
||||
context=Context(),
|
||||
conversation_id=None,
|
||||
device_id=None,
|
||||
language=hass.config.language,
|
||||
agent_id=None,
|
||||
)
|
||||
|
||||
mock_result = RecognizeResult(
|
||||
intent=Intent("HassTurnOn"),
|
||||
intent_data=IntentData([]),
|
||||
entities={},
|
||||
entities_list=[],
|
||||
)
|
||||
results = []
|
||||
|
||||
def _filter_intents(result):
|
||||
results.append(result)
|
||||
# We filter first, not 2nd.
|
||||
return len(results) == 1
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent",
|
||||
return_value=mock_result,
|
||||
) as mock_recognize,
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent._async_process_intent_result",
|
||||
) as mock_process,
|
||||
):
|
||||
response = await agent.async_handle_intents(
|
||||
user_input, intent_filter=_filter_intents
|
||||
)
|
||||
|
||||
assert len(mock_recognize.mock_calls) == 1
|
||||
assert len(mock_process.mock_calls) == 0
|
||||
|
||||
# It was ignored
|
||||
assert response is None
|
||||
|
||||
# Check we filtered things
|
||||
assert len(results) == 1
|
||||
assert results[0] is mock_result
|
||||
|
||||
# Second time it is not filtered
|
||||
response = await agent.async_handle_intents(
|
||||
user_input, intent_filter=_filter_intents
|
||||
)
|
||||
|
||||
assert len(mock_recognize.mock_calls) == 2
|
||||
assert len(mock_process.mock_calls) == 2
|
||||
|
||||
# Check we filtered things
|
||||
assert len(results) == 2
|
||||
assert results[1] is mock_result
|
||||
|
||||
# It was ignored
|
||||
assert response is not None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
|
||||
async def test_state_names_are_not_translated(
|
||||
hass: HomeAssistant,
|
||||
|
|
Loading…
Reference in New Issue