Filter entity names before intent matching (#131563)

pull/131607/head
Michael Hansen 2024-11-26 02:42:31 -06:00 committed by GitHub
parent 6947800d93
commit 752df5a8cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 166 additions and 78 deletions

View File

@ -14,8 +14,14 @@ import re
import time
from typing import IO, Any, cast
from hassil.expression import Expression, ListReference, Sequence
from hassil.intents import Intents, SlotList, TextSlotList, WildcardSlotList
from hassil.expression import Expression, ListReference, Sequence, TextChunk
from hassil.intents import (
Intents,
SlotList,
TextSlotList,
TextSlotValue,
WildcardSlotList,
)
from hassil.recognize import (
MISSING_ENTITY,
RecognizeResult,
@ -23,6 +29,7 @@ from hassil.recognize import (
recognize_best,
)
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
from hassil.trie import Trie
from hassil.util import merge_dict
from home_assistant_intents import ErrorKey, get_intents, get_languages
import yaml
@ -110,8 +117,8 @@ class IntentMatchingStage(Enum):
EXPOSED_ENTITIES_ONLY = auto()
"""Match against exposed entities only."""
ALL_ENTITIES = auto()
"""Match against all entities in Home Assistant."""
UNEXPOSED_ENTITIES = auto()
"""Match against unexposed entities in Home Assistant."""
FUZZY = auto()
"""Capture names that are not known to Home Assistant."""
@ -233,7 +240,10 @@ class DefaultAgent(ConversationEntity):
# intent -> [sentences]
self._config_intents: dict[str, Any] = config_intents
self._slot_lists: dict[str, SlotList] | None = None
self._all_entity_names: TextSlotList | None = None
# Used to filter slot lists before intent matching
self._exposed_names_trie: Trie | None = None
self._unexposed_names_trie: Trie | None = None
# Sentences that will trigger a callback (skipping intent recognition)
self._trigger_sentences: list[TriggerData] = []
@ -305,6 +315,16 @@ class DefaultAgent(ConversationEntity):
slot_lists = self._make_slot_lists()
intent_context = self._make_intent_context(user_input)
if self._exposed_names_trie is not None:
# Filter by input string
text_lower = user_input.text.strip().lower()
slot_lists["name"] = TextSlotList(
name="name",
values=[
result[2] for result in self._exposed_names_trie.find(text_lower)
],
)
start = time.monotonic()
result = await self.hass.async_add_executor_job(
@ -540,29 +560,29 @@ class DefaultAgent(ConversationEntity):
return None
# Try again with unexposed entities only (exposed names were already tried)
skip_all_entities_match = False
skip_unexposed_entities_match = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.ALL_ENTITIES
cache_value.stage == IntentMatchingStage.UNEXPOSED_ENTITIES
):
_LOGGER.debug("Got cached result for all entities")
return cache_value.result
# Continue with matching, but we know we won't succeed for unexposed
# entities.
skip_all_entities_match = True
skip_unexposed_entities_match = True
if not skip_all_entities_match:
all_entities_slot_lists = {
if not skip_unexposed_entities_match:
unexposed_entities_slot_lists = {
**slot_lists,
"name": self._get_all_entity_names(),
"name": self._get_unexposed_entity_names(user_input.text),
}
start_time = time.monotonic()
strict_result = self._recognize_strict(
user_input,
lang_intents,
all_entities_slot_lists,
unexposed_entities_slot_lists,
intent_context,
language,
)
@ -575,7 +595,7 @@ class DefaultAgent(ConversationEntity):
self._intent_cache.put(
cache_key,
IntentCacheValue(
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
result=strict_result, stage=IntentMatchingStage.UNEXPOSED_ENTITIES
),
)
@ -683,15 +703,43 @@ class DefaultAgent(ConversationEntity):
return maybe_result
def _get_all_entity_names(self) -> TextSlotList:
"""Get slot list with all entity names in Home Assistant."""
if self._all_entity_names is not None:
return self._all_entity_names
def _get_unexposed_entity_names(self, text: str) -> TextSlotList:
    """Get filtered slot list with unexposed entity names in Home Assistant.

    The trie of unexposed names is built lazily on first use and cached on
    the agent. Only names that actually occur in ``text`` are returned, so
    the intent matcher sees a much smaller "name" list.

    Args:
        text: Raw user input; it is stripped and lowercased before lookup.

    Returns:
        A "name" slot list holding the unexposed entity names found in
        ``text``.
    """
    if self._unexposed_names_trie is None:
        # Build into a local first so that a failure while enumerating
        # entities does not leave a half-populated trie cached on self.
        unexposed_trie = Trie()
        for name_tuple in self._get_entity_name_tuples(exposed=False):
            # Normalize the key the same way as the exposed names trie
            # (strip + lower) so both stages match identically.
            name_key = name_tuple[0].strip().lower()
            if not name_key:
                # A blank name can never occur in the stripped input text
                continue

            unexposed_trie.insert(
                name_key,
                # Entity names are literal text, not hassil templates.
                # Without allow_template=False a name such as
                # "Light (Kitchen)" would be parsed as a template.
                TextSlotValue.from_tuple(name_tuple, allow_template=False),
            )

        self._unexposed_names_trie = unexposed_trie

    # Build filtered slot list
    text_lower = text.strip().lower()
    return TextSlotList(
        name="name",
        values=[
            # find() results are indexed; element [2] is the stored slot value
            result[2]
            for result in self._unexposed_names_trie.find(text_lower)
        ],
    )
def _get_entity_name_tuples(
self, exposed: bool
) -> Iterable[tuple[str, str, dict[str, Any]]]:
"""Yield (input name, output name, context) tuples for entities."""
entity_registry = er.async_get(self.hass)
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
for state in self.hass.states.async_all():
entity_exposed = async_should_expose(self.hass, DOMAIN, state.entity_id)
if exposed and (not entity_exposed):
# Required exposed, entity is not
continue
if (not exposed) and entity_exposed:
# Required not exposed, entity is
continue
# Checked against "requires_context" and "excludes_context" in hassil
context = {"domain": state.domain}
if state.attributes:
# Include some attributes
@ -700,28 +748,18 @@ class DefaultAgent(ConversationEntity):
continue
context[attr] = state.attributes[attr]
if entity := entity_registry.async_get(state.entity_id):
# Skip config/hidden entities
if (entity.entity_category is not None) or (
entity.hidden_by is not None
):
continue
if (
entity := entity_registry.async_get(state.entity_id)
) and entity.aliases:
for alias in entity.aliases:
alias = alias.strip()
if not alias:
continue
if entity.aliases:
# Also add aliases
for alias in entity.aliases:
if not alias.strip():
continue
all_entity_names.append((alias, alias, context))
yield (alias, alias, context)
# Default name
all_entity_names.append((state.name, state.name, context))
self._all_entity_names = TextSlotList.from_tuples(
all_entity_names, allow_template=False
)
return self._all_entity_names
yield (state.name, state.name, context)
def _recognize_strict(
self,
@ -1013,7 +1051,8 @@ class DefaultAgent(ConversationEntity):
if self._unsub_clear_slot_list is None:
return
self._slot_lists = None
self._all_entity_names = None
self._exposed_names_trie = None
self._unexposed_names_trie = None
for unsub in self._unsub_clear_slot_list:
unsub()
self._unsub_clear_slot_list = None
@ -1029,8 +1068,6 @@ class DefaultAgent(ConversationEntity):
start = time.monotonic()
entity_registry = er.async_get(self.hass)
# Gather exposed entity names.
# We try intent recognition with only exposed names first, then unexposed names.
#
@ -1038,35 +1075,7 @@ class DefaultAgent(ConversationEntity):
# have the same name. The intent matcher doesn't gather all matching
# values for a list, just the first. So we will need to match by name no
# matter what.
exposed_entity_names = []
for state in self.hass.states.async_all():
is_exposed = async_should_expose(self.hass, DOMAIN, state.entity_id)
# Checked against "requires_context" and "excludes_context" in hassil
context = {"domain": state.domain}
if state.attributes:
# Include some attributes
for attr in DEFAULT_EXPOSED_ATTRIBUTES:
if attr not in state.attributes:
continue
context[attr] = state.attributes[attr]
if (
entity := entity_registry.async_get(state.entity_id)
) and entity.aliases:
for alias in entity.aliases:
if not alias.strip():
continue
name_tuple = (alias, alias, context)
if is_exposed:
exposed_entity_names.append(name_tuple)
# Default name
name_tuple = (state.name, state.name, context)
if is_exposed:
exposed_entity_names.append(name_tuple)
exposed_entity_names = list(self._get_entity_name_tuples(exposed=True))
_LOGGER.debug("Exposed entities: %s", exposed_entity_names)
# Expose all areas.
@ -1099,11 +1108,17 @@ class DefaultAgent(ConversationEntity):
floor_names.append((alias, floor.name))
# Build trie
self._exposed_names_trie = Trie()
name_list = TextSlotList.from_tuples(exposed_entity_names, allow_template=False)
for name_value in name_list.values:
assert isinstance(name_value.text_in, TextChunk)
name_text = name_value.text_in.text.strip().lower()
self._exposed_names_trie.insert(name_text, name_value)
self._slot_lists = {
"area": TextSlotList.from_tuples(area_names, allow_template=False),
"name": TextSlotList.from_tuples(
exposed_entity_names, allow_template=False
),
"name": name_list,
"floor": TextSlotList.from_tuples(floor_names, allow_template=False),
}

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "system",
"quality_scale": "internal",
"requirements": ["hassil==2.0.2", "home-assistant-intents==2024.11.13"]
"requirements": ["hassil==2.0.4", "home-assistant-intents==2024.11.13"]
}

View File

@ -32,7 +32,7 @@ go2rtc-client==0.1.1
ha-ffmpeg==3.2.2
habluetooth==3.6.0
hass-nabucasa==0.85.0
hassil==2.0.2
hassil==2.0.4
home-assistant-bluetooth==1.13.0
home-assistant-frontend==20241106.2
home-assistant-intents==2024.11.13

View File

@ -1096,7 +1096,7 @@ hass-nabucasa==0.85.0
hass-splunk==0.1.1
# homeassistant.components.conversation
hassil==2.0.2
hassil==2.0.4
# homeassistant.components.jewish_calendar
hdate==0.11.1

View File

@ -931,7 +931,7 @@ habluetooth==3.6.0
hass-nabucasa==0.85.0
# homeassistant.components.conversation
hassil==2.0.2
hassil==2.0.4
# homeassistant.components.jewish_calendar
hdate==0.11.1

View File

@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.4,source=/uv,target=/bin/uv \
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
-r /usr/src/homeassistant/requirements.txt \
stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.8.0 \
PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.2 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.4 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
LABEL "name"="hassfest"
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"

View File

@ -1735,7 +1735,7 @@ async def test_empty_aliases(
return_value=None,
) as mock_recognize_all:
await conversation.async_converse(
hass, "turn on lights in the kitchen", None, Context(), None
hass, "turn on kitchen light", None, Context(), None
)
assert mock_recognize_all.call_count > 0
@ -2940,3 +2940,76 @@ async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
result = await agent.async_recognize_intent(user_input)
assert result is not None
assert getattr(result, mark, None) is True
@pytest.mark.usefixtures("init_components")
async def test_entities_filtered_by_input(hass: HomeAssistant) -> None:
"""Test that entities are filtered by the input text before intent matching."""
agent = hass.data[DATA_DEFAULT_ENTITY]
assert isinstance(agent, default_agent.DefaultAgent)
# Only the switch is exposed
hass.states.async_set("light.test_light", "off")
# Second light shares the friendly name "test light" with the first one
hass.states.async_set(
"light.test_light_2", "off", attributes={ATTR_FRIENDLY_NAME: "test light"}
)
hass.states.async_set("cover.garage_door", "closed")
hass.states.async_set("switch.test_switch", "off")
expose_entity(hass, "light.test_light", False)
expose_entity(hass, "light.test_light_2", False)
expose_entity(hass, "cover.garage_door", False)
expose_entity(hass, "switch.test_switch", True)
await hass.async_block_till_done()
# test switch is exposed
user_input = ConversationInput(
text="turn on test switch",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
# recognize_best is patched to return None so every matching pass is attempted
with patch(
"homeassistant.components.conversation.default_agent.recognize_best",
return_value=None,
) as recognize_best:
await agent.async_recognize_intent(user_input)
# (1) exposed entities, (2) unexposed entities
assert len(recognize_best.call_args_list) == 2
# Only the test switch should have been considered in the first (exposed)
# pass because its name shows up in the input text.
slot_lists = recognize_best.call_args_list[0].kwargs["slot_lists"]
name_list = slot_lists["name"]
assert len(name_list.values) == 1
assert name_list.values[0].text_in.text == "test switch"
# test light is not exposed
user_input = ConversationInput(
text="turn on Test Light", # different casing for name
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
with patch(
"homeassistant.components.conversation.default_agent.recognize_best",
return_value=None,
) as recognize_best:
await agent.async_recognize_intent(user_input)
# (1) exposed entities, (2) unexposed entities
assert len(recognize_best.call_args_list) == 2
# Both test lights should have been considered in the second (unexposed)
# pass because their name shows up in the input text.
slot_lists = recognize_best.call_args_list[1].kwargs["slot_lists"]
name_list = slot_lists["name"]
assert len(name_list.values) == 2
assert name_list.values[0].text_in.text == "test light"
assert name_list.values[1].text_in.text == "test light"