Filter entity names before intent matching (#131563)
parent
6947800d93
commit
752df5a8cb
|
@ -14,8 +14,14 @@ import re
|
|||
import time
|
||||
from typing import IO, Any, cast
|
||||
|
||||
from hassil.expression import Expression, ListReference, Sequence
|
||||
from hassil.intents import Intents, SlotList, TextSlotList, WildcardSlotList
|
||||
from hassil.expression import Expression, ListReference, Sequence, TextChunk
|
||||
from hassil.intents import (
|
||||
Intents,
|
||||
SlotList,
|
||||
TextSlotList,
|
||||
TextSlotValue,
|
||||
WildcardSlotList,
|
||||
)
|
||||
from hassil.recognize import (
|
||||
MISSING_ENTITY,
|
||||
RecognizeResult,
|
||||
|
@ -23,6 +29,7 @@ from hassil.recognize import (
|
|||
recognize_best,
|
||||
)
|
||||
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
|
||||
from hassil.trie import Trie
|
||||
from hassil.util import merge_dict
|
||||
from home_assistant_intents import ErrorKey, get_intents, get_languages
|
||||
import yaml
|
||||
|
@ -110,8 +117,8 @@ class IntentMatchingStage(Enum):
|
|||
EXPOSED_ENTITIES_ONLY = auto()
|
||||
"""Match against exposed entities only."""
|
||||
|
||||
ALL_ENTITIES = auto()
|
||||
"""Match against all entities in Home Assistant."""
|
||||
UNEXPOSED_ENTITIES = auto()
|
||||
"""Match against unexposed entities in Home Assistant."""
|
||||
|
||||
FUZZY = auto()
|
||||
"""Capture names that are not known to Home Assistant."""
|
||||
|
@ -233,7 +240,10 @@ class DefaultAgent(ConversationEntity):
|
|||
# intent -> [sentences]
|
||||
self._config_intents: dict[str, Any] = config_intents
|
||||
self._slot_lists: dict[str, SlotList] | None = None
|
||||
self._all_entity_names: TextSlotList | None = None
|
||||
|
||||
# Used to filter slot lists before intent matching
|
||||
self._exposed_names_trie: Trie | None = None
|
||||
self._unexposed_names_trie: Trie | None = None
|
||||
|
||||
# Sentences that will trigger a callback (skipping intent recognition)
|
||||
self._trigger_sentences: list[TriggerData] = []
|
||||
|
@ -305,6 +315,16 @@ class DefaultAgent(ConversationEntity):
|
|||
slot_lists = self._make_slot_lists()
|
||||
intent_context = self._make_intent_context(user_input)
|
||||
|
||||
if self._exposed_names_trie is not None:
|
||||
# Filter by input string
|
||||
text_lower = user_input.text.strip().lower()
|
||||
slot_lists["name"] = TextSlotList(
|
||||
name="name",
|
||||
values=[
|
||||
result[2] for result in self._exposed_names_trie.find(text_lower)
|
||||
],
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
result = await self.hass.async_add_executor_job(
|
||||
|
@ -540,29 +560,29 @@ class DefaultAgent(ConversationEntity):
|
|||
return None
|
||||
|
||||
# Try again with unexposed entities only
|
||||
skip_all_entities_match = False
|
||||
skip_unexposed_entities_match = False
|
||||
if cache_value is not None:
|
||||
if (cache_value.result is not None) and (
|
||||
cache_value.stage == IntentMatchingStage.ALL_ENTITIES
|
||||
cache_value.stage == IntentMatchingStage.UNEXPOSED_ENTITIES
|
||||
):
|
||||
_LOGGER.debug("Got cached result for all entities")
|
||||
return cache_value.result
|
||||
|
||||
# Continue with matching, but we know we won't succeed for unexposed
|
||||
# entities.
|
||||
skip_all_entities_match = True
|
||||
skip_unexposed_entities_match = True
|
||||
|
||||
if not skip_all_entities_match:
|
||||
all_entities_slot_lists = {
|
||||
if not skip_unexposed_entities_match:
|
||||
unexposed_entities_slot_lists = {
|
||||
**slot_lists,
|
||||
"name": self._get_all_entity_names(),
|
||||
"name": self._get_unexposed_entity_names(user_input.text),
|
||||
}
|
||||
|
||||
start_time = time.monotonic()
|
||||
strict_result = self._recognize_strict(
|
||||
user_input,
|
||||
lang_intents,
|
||||
all_entities_slot_lists,
|
||||
unexposed_entities_slot_lists,
|
||||
intent_context,
|
||||
language,
|
||||
)
|
||||
|
@ -575,7 +595,7 @@ class DefaultAgent(ConversationEntity):
|
|||
self._intent_cache.put(
|
||||
cache_key,
|
||||
IntentCacheValue(
|
||||
result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES
|
||||
result=strict_result, stage=IntentMatchingStage.UNEXPOSED_ENTITIES
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -683,15 +703,43 @@ class DefaultAgent(ConversationEntity):
|
|||
|
||||
return maybe_result
|
||||
|
||||
def _get_all_entity_names(self) -> TextSlotList:
|
||||
"""Get slot list with all entity names in Home Assistant."""
|
||||
if self._all_entity_names is not None:
|
||||
return self._all_entity_names
|
||||
def _get_unexposed_entity_names(self, text: str) -> TextSlotList:
    """Get filtered slot list with unexposed entity names in Home Assistant.

    The trie of unexposed entity names is built lazily on first use and
    cached on ``self._unexposed_names_trie``. Only the names that the trie
    finds inside ``text`` (compared lower-cased) end up in the returned
    "name" slot list.
    """
    names_trie = self._unexposed_names_trie
    if names_trie is None:
        # Build the trie in a local first so that a failure part-way through
        # never leaves a half-populated trie cached on the agent.
        names_trie = Trie()
        for name_tuple in self._get_entity_name_tuples(exposed=False):
            names_trie.insert(
                # Normalize the key the same way as the exposed names trie.
                name_tuple[0].strip().lower(),
                # Entity names are literal text; do not let characters such
                # as brackets be interpreted as a sentence template.
                TextSlotValue.from_tuple(name_tuple, allow_template=False),
            )
        self._unexposed_names_trie = names_trie

    # Build filtered slot list: each trie result carries the slot value at
    # index 2.
    text_lower = text.strip().lower()
    return TextSlotList(
        name="name",
        values=[result[2] for result in names_trie.find(text_lower)],
    )
|
||||
|
||||
def _get_entity_name_tuples(
|
||||
self, exposed: bool
|
||||
) -> Iterable[tuple[str, str, dict[str, Any]]]:
|
||||
"""Yield (input name, output name, context) tuples for entities."""
|
||||
entity_registry = er.async_get(self.hass)
|
||||
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
for state in self.hass.states.async_all():
|
||||
entity_exposed = async_should_expose(self.hass, DOMAIN, state.entity_id)
|
||||
if exposed and (not entity_exposed):
|
||||
# Required exposed, entity is not
|
||||
continue
|
||||
|
||||
if (not exposed) and entity_exposed:
|
||||
# Required not exposed, entity is
|
||||
continue
|
||||
|
||||
# Checked against "requires_context" and "excludes_context" in hassil
|
||||
context = {"domain": state.domain}
|
||||
if state.attributes:
|
||||
# Include some attributes
|
||||
|
@ -700,28 +748,18 @@ class DefaultAgent(ConversationEntity):
|
|||
continue
|
||||
context[attr] = state.attributes[attr]
|
||||
|
||||
if entity := entity_registry.async_get(state.entity_id):
|
||||
# Skip config/hidden entities
|
||||
if (entity.entity_category is not None) or (
|
||||
entity.hidden_by is not None
|
||||
):
|
||||
continue
|
||||
if (
|
||||
entity := entity_registry.async_get(state.entity_id)
|
||||
) and entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
if entity.aliases:
|
||||
# Also add aliases
|
||||
for alias in entity.aliases:
|
||||
if not alias.strip():
|
||||
continue
|
||||
|
||||
all_entity_names.append((alias, alias, context))
|
||||
yield (alias, alias, context)
|
||||
|
||||
# Default name
|
||||
all_entity_names.append((state.name, state.name, context))
|
||||
|
||||
self._all_entity_names = TextSlotList.from_tuples(
|
||||
all_entity_names, allow_template=False
|
||||
)
|
||||
return self._all_entity_names
|
||||
yield (state.name, state.name, context)
|
||||
|
||||
def _recognize_strict(
|
||||
self,
|
||||
|
@ -1013,7 +1051,8 @@ class DefaultAgent(ConversationEntity):
|
|||
if self._unsub_clear_slot_list is None:
|
||||
return
|
||||
self._slot_lists = None
|
||||
self._all_entity_names = None
|
||||
self._exposed_names_trie = None
|
||||
self._unexposed_names_trie = None
|
||||
for unsub in self._unsub_clear_slot_list:
|
||||
unsub()
|
||||
self._unsub_clear_slot_list = None
|
||||
|
@ -1029,8 +1068,6 @@ class DefaultAgent(ConversationEntity):
|
|||
|
||||
start = time.monotonic()
|
||||
|
||||
entity_registry = er.async_get(self.hass)
|
||||
|
||||
# Gather entity names, keeping track of exposed names.
|
||||
# We try intent recognition with only exposed names first, then unexposed names.
|
||||
#
|
||||
|
@ -1038,35 +1075,7 @@ class DefaultAgent(ConversationEntity):
|
|||
# have the same name. The intent matcher doesn't gather all matching
|
||||
# values for a list, just the first. So we will need to match by name no
|
||||
# matter what.
|
||||
exposed_entity_names = []
|
||||
for state in self.hass.states.async_all():
|
||||
is_exposed = async_should_expose(self.hass, DOMAIN, state.entity_id)
|
||||
|
||||
# Checked against "requires_context" and "excludes_context" in hassil
|
||||
context = {"domain": state.domain}
|
||||
if state.attributes:
|
||||
# Include some attributes
|
||||
for attr in DEFAULT_EXPOSED_ATTRIBUTES:
|
||||
if attr not in state.attributes:
|
||||
continue
|
||||
context[attr] = state.attributes[attr]
|
||||
|
||||
if (
|
||||
entity := entity_registry.async_get(state.entity_id)
|
||||
) and entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
if not alias.strip():
|
||||
continue
|
||||
|
||||
name_tuple = (alias, alias, context)
|
||||
if is_exposed:
|
||||
exposed_entity_names.append(name_tuple)
|
||||
|
||||
# Default name
|
||||
name_tuple = (state.name, state.name, context)
|
||||
if is_exposed:
|
||||
exposed_entity_names.append(name_tuple)
|
||||
|
||||
exposed_entity_names = list(self._get_entity_name_tuples(exposed=True))
|
||||
_LOGGER.debug("Exposed entities: %s", exposed_entity_names)
|
||||
|
||||
# Expose all areas.
|
||||
|
@ -1099,11 +1108,17 @@ class DefaultAgent(ConversationEntity):
|
|||
|
||||
floor_names.append((alias, floor.name))
|
||||
|
||||
# Build trie
|
||||
self._exposed_names_trie = Trie()
|
||||
name_list = TextSlotList.from_tuples(exposed_entity_names, allow_template=False)
|
||||
for name_value in name_list.values:
|
||||
assert isinstance(name_value.text_in, TextChunk)
|
||||
name_text = name_value.text_in.text.strip().lower()
|
||||
self._exposed_names_trie.insert(name_text, name_value)
|
||||
|
||||
self._slot_lists = {
|
||||
"area": TextSlotList.from_tuples(area_names, allow_template=False),
|
||||
"name": TextSlotList.from_tuples(
|
||||
exposed_entity_names, allow_template=False
|
||||
),
|
||||
"name": name_list,
|
||||
"floor": TextSlotList.from_tuples(floor_names, allow_template=False),
|
||||
}
|
||||
|
||||
|
|
|
@ -6,5 +6,5 @@
|
|||
"documentation": "https://www.home-assistant.io/integrations/conversation",
|
||||
"integration_type": "system",
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["hassil==2.0.2", "home-assistant-intents==2024.11.13"]
|
||||
"requirements": ["hassil==2.0.4", "home-assistant-intents==2024.11.13"]
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ go2rtc-client==0.1.1
|
|||
ha-ffmpeg==3.2.2
|
||||
habluetooth==3.6.0
|
||||
hass-nabucasa==0.85.0
|
||||
hassil==2.0.2
|
||||
hassil==2.0.4
|
||||
home-assistant-bluetooth==1.13.0
|
||||
home-assistant-frontend==20241106.2
|
||||
home-assistant-intents==2024.11.13
|
||||
|
|
|
@ -1096,7 +1096,7 @@ hass-nabucasa==0.85.0
|
|||
hass-splunk==0.1.1
|
||||
|
||||
# homeassistant.components.conversation
|
||||
hassil==2.0.2
|
||||
hassil==2.0.4
|
||||
|
||||
# homeassistant.components.jewish_calendar
|
||||
hdate==0.11.1
|
||||
|
|
|
@ -931,7 +931,7 @@ habluetooth==3.6.0
|
|||
hass-nabucasa==0.85.0
|
||||
|
||||
# homeassistant.components.conversation
|
||||
hassil==2.0.2
|
||||
hassil==2.0.4
|
||||
|
||||
# homeassistant.components.jewish_calendar
|
||||
hdate==0.11.1
|
||||
|
|
|
@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.4,source=/uv,target=/bin/uv \
|
|||
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
|
||||
-r /usr/src/homeassistant/requirements.txt \
|
||||
stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.8.0 \
|
||||
PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.2 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
|
||||
PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.4 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
|
||||
|
||||
LABEL "name"="hassfest"
|
||||
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"
|
||||
|
|
|
@ -1735,7 +1735,7 @@ async def test_empty_aliases(
|
|||
return_value=None,
|
||||
) as mock_recognize_all:
|
||||
await conversation.async_converse(
|
||||
hass, "turn on lights in the kitchen", None, Context(), None
|
||||
hass, "turn on kitchen light", None, Context(), None
|
||||
)
|
||||
|
||||
assert mock_recognize_all.call_count > 0
|
||||
|
@ -2940,3 +2940,76 @@ async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
|
|||
result = await agent.async_recognize_intent(user_input)
|
||||
assert result is not None
|
||||
assert getattr(result, mark, None) is True
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("init_components")
async def test_entities_filtered_by_input(hass: HomeAssistant) -> None:
    """Test that the "name" slot list is narrowed to names found in the input."""
    agent = hass.data[DATA_DEFAULT_ENTITY]
    assert isinstance(agent, default_agent.DefaultAgent)

    def make_input(text: str) -> ConversationInput:
        """Build a conversation input for the given sentence."""
        return ConversationInput(
            text=text,
            context=Context(),
            conversation_id=None,
            device_id=None,
            language=hass.config.language,
            agent_id=None,
        )

    # Two lights, a cover and a switch; the switch is the only exposed entity.
    hass.states.async_set("light.test_light", "off")
    hass.states.async_set(
        "light.test_light_2", "off", attributes={ATTR_FRIENDLY_NAME: "test light"}
    )
    hass.states.async_set("cover.garage_door", "closed")
    hass.states.async_set("switch.test_switch", "off")
    expose_entity(hass, "light.test_light", False)
    expose_entity(hass, "light.test_light_2", False)
    expose_entity(hass, "cover.garage_door", False)
    expose_entity(hass, "switch.test_switch", True)
    await hass.async_block_till_done()

    # Sentence naming the exposed switch
    switch_input = make_input("turn on test switch")

    with patch(
        "homeassistant.components.conversation.default_agent.recognize_best",
        return_value=None,
    ) as mock_recognize:
        await agent.async_recognize_intent(switch_input)

        # Two attempts: (1) exposed names, (2) unexposed names
        assert len(mock_recognize.call_args_list) == 2

        # The first attempt must only have seen the test switch, because it
        # is the one exposed name that appears in the input text.
        exposed_names = mock_recognize.call_args_list[0].kwargs["slot_lists"]["name"]
        assert len(exposed_names.values) == 1
        assert exposed_names.values[0].text_in.text == "test switch"

    # Sentence naming the unexposed lights, with different casing for the name
    light_input = make_input("turn on Test Light")

    with patch(
        "homeassistant.components.conversation.default_agent.recognize_best",
        return_value=None,
    ) as mock_recognize:
        await agent.async_recognize_intent(light_input)

        # Two attempts: (1) exposed names, (2) unexposed names
        assert len(mock_recognize.call_args_list) == 2

        # The second attempt must have seen both test lights, because their
        # name appears in the input text.
        unexposed_names = mock_recognize.call_args_list[1].kwargs["slot_lists"]["name"]
        assert len(unexposed_names.values) == 2
        assert unexposed_names.values[0].text_in.text == "test light"
        assert unexposed_names.values[1].text_in.text == "test light"
|
||||
|
|
Loading…
Reference in New Issue