From 752df5a8cb8786d515210077a06e365de604e065 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 26 Nov 2024 02:42:31 -0600 Subject: [PATCH] Filter entity names before intent matching (#131563) --- .../components/conversation/default_agent.py | 159 ++++++++++-------- .../components/conversation/manifest.json | 2 +- homeassistant/package_constraints.txt | 2 +- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- script/hassfest/docker/Dockerfile | 2 +- .../conversation/test_default_agent.py | 75 ++++++++- 7 files changed, 166 insertions(+), 78 deletions(-) diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 20720b90423..c1256a1507b 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -14,8 +14,14 @@ import re import time from typing import IO, Any, cast -from hassil.expression import Expression, ListReference, Sequence -from hassil.intents import Intents, SlotList, TextSlotList, WildcardSlotList +from hassil.expression import Expression, ListReference, Sequence, TextChunk +from hassil.intents import ( + Intents, + SlotList, + TextSlotList, + TextSlotValue, + WildcardSlotList, +) from hassil.recognize import ( MISSING_ENTITY, RecognizeResult, @@ -23,6 +29,7 @@ from hassil.recognize import ( recognize_best, ) from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity +from hassil.trie import Trie from hassil.util import merge_dict from home_assistant_intents import ErrorKey, get_intents, get_languages import yaml @@ -110,8 +117,8 @@ class IntentMatchingStage(Enum): EXPOSED_ENTITIES_ONLY = auto() """Match against exposed entities only.""" - ALL_ENTITIES = auto() - """Match against all entities in Home Assistant.""" + UNEXPOSED_ENTITIES = auto() + """Match against unexposed entities in Home Assistant.""" FUZZY = auto() """Capture names that are not known to Home Assistant.""" @@ 
-233,7 +240,10 @@ class DefaultAgent(ConversationEntity): # intent -> [sentences] self._config_intents: dict[str, Any] = config_intents self._slot_lists: dict[str, SlotList] | None = None - self._all_entity_names: TextSlotList | None = None + + # Used to filter slot lists before intent matching + self._exposed_names_trie: Trie | None = None + self._unexposed_names_trie: Trie | None = None # Sentences that will trigger a callback (skipping intent recognition) self._trigger_sentences: list[TriggerData] = [] @@ -305,6 +315,16 @@ class DefaultAgent(ConversationEntity): slot_lists = self._make_slot_lists() intent_context = self._make_intent_context(user_input) + if self._exposed_names_trie is not None: + # Filter by input string + text_lower = user_input.text.strip().lower() + slot_lists["name"] = TextSlotList( + name="name", + values=[ + result[2] for result in self._exposed_names_trie.find(text_lower) + ], + ) + start = time.monotonic() result = await self.hass.async_add_executor_job( @@ -540,29 +560,29 @@ class DefaultAgent(ConversationEntity): return None # Try again with all entities (including unexposed) - skip_all_entities_match = False + skip_unexposed_entities_match = False if cache_value is not None: if (cache_value.result is not None) and ( - cache_value.stage == IntentMatchingStage.ALL_ENTITIES + cache_value.stage == IntentMatchingStage.UNEXPOSED_ENTITIES ): _LOGGER.debug("Got cached result for all entities") return cache_value.result # Continue with matching, but we know we won't succeed for all # entities. 
- skip_all_entities_match = True + skip_unexposed_entities_match = True - if not skip_all_entities_match: - all_entities_slot_lists = { + if not skip_unexposed_entities_match: + unexposed_entities_slot_lists = { **slot_lists, - "name": self._get_all_entity_names(), + "name": self._get_unexposed_entity_names(user_input.text), } start_time = time.monotonic() strict_result = self._recognize_strict( user_input, lang_intents, - all_entities_slot_lists, + unexposed_entities_slot_lists, intent_context, language, ) @@ -575,7 +595,7 @@ class DefaultAgent(ConversationEntity): self._intent_cache.put( cache_key, IntentCacheValue( - result=strict_result, stage=IntentMatchingStage.ALL_ENTITIES + result=strict_result, stage=IntentMatchingStage.UNEXPOSED_ENTITIES ), ) @@ -683,15 +703,43 @@ class DefaultAgent(ConversationEntity): return maybe_result - def _get_all_entity_names(self) -> TextSlotList: - """Get slot list with all entity names in Home Assistant.""" - if self._all_entity_names is not None: - return self._all_entity_names + def _get_unexposed_entity_names(self, text: str) -> TextSlotList: + """Get filtered slot list with unexposed entity names in Home Assistant.""" + if self._unexposed_names_trie is None: + # Build trie + self._unexposed_names_trie = Trie() + for name_tuple in self._get_entity_name_tuples(exposed=False): + self._unexposed_names_trie.insert( + name_tuple[0].lower(), + TextSlotValue.from_tuple(name_tuple), + ) + # Build filtered slot list + text_lower = text.strip().lower() + return TextSlotList( + name="name", + values=[ + result[2] for result in self._unexposed_names_trie.find(text_lower) + ], + ) + + def _get_entity_name_tuples( + self, exposed: bool + ) -> Iterable[tuple[str, str, dict[str, Any]]]: + """Yield (input name, output name, context) tuples for entities.""" entity_registry = er.async_get(self.hass) - all_entity_names: list[tuple[str, str, dict[str, Any]]] = [] for state in self.hass.states.async_all(): + entity_exposed = 
async_should_expose(self.hass, DOMAIN, state.entity_id) + if exposed and (not entity_exposed): + # Required exposed, entity is not + continue + + if (not exposed) and entity_exposed: + # Required not exposed, entity is + continue + + # Checked against "requires_context" and "excludes_context" in hassil context = {"domain": state.domain} if state.attributes: # Include some attributes @@ -700,28 +748,18 @@ class DefaultAgent(ConversationEntity): continue context[attr] = state.attributes[attr] - if entity := entity_registry.async_get(state.entity_id): - # Skip config/hidden entities - if (entity.entity_category is not None) or ( - entity.hidden_by is not None - ): - continue + if ( + entity := entity_registry.async_get(state.entity_id) + ) and entity.aliases: + for alias in entity.aliases: + alias = alias.strip() + if not alias: + continue - if entity.aliases: - # Also add aliases - for alias in entity.aliases: - if not alias.strip(): - continue - - all_entity_names.append((alias, alias, context)) + yield (alias, alias, context) # Default name - all_entity_names.append((state.name, state.name, context)) - - self._all_entity_names = TextSlotList.from_tuples( - all_entity_names, allow_template=False - ) - return self._all_entity_names + yield (state.name, state.name, context) def _recognize_strict( self, @@ -1013,7 +1051,8 @@ class DefaultAgent(ConversationEntity): if self._unsub_clear_slot_list is None: return self._slot_lists = None - self._all_entity_names = None + self._exposed_names_trie = None + self._unexposed_names_trie = None for unsub in self._unsub_clear_slot_list: unsub() self._unsub_clear_slot_list = None @@ -1029,8 +1068,6 @@ class DefaultAgent(ConversationEntity): start = time.monotonic() - entity_registry = er.async_get(self.hass) - # Gather entity names, keeping track of exposed names. # We try intent recognition with only exposed names first, then all names. # @@ -1038,35 +1075,7 @@ class DefaultAgent(ConversationEntity): # have the same name. 
The intent matcher doesn't gather all matching # values for a list, just the first. So we will need to match by name no # matter what. - exposed_entity_names = [] - for state in self.hass.states.async_all(): - is_exposed = async_should_expose(self.hass, DOMAIN, state.entity_id) - - # Checked against "requires_context" and "excludes_context" in hassil - context = {"domain": state.domain} - if state.attributes: - # Include some attributes - for attr in DEFAULT_EXPOSED_ATTRIBUTES: - if attr not in state.attributes: - continue - context[attr] = state.attributes[attr] - - if ( - entity := entity_registry.async_get(state.entity_id) - ) and entity.aliases: - for alias in entity.aliases: - if not alias.strip(): - continue - - name_tuple = (alias, alias, context) - if is_exposed: - exposed_entity_names.append(name_tuple) - - # Default name - name_tuple = (state.name, state.name, context) - if is_exposed: - exposed_entity_names.append(name_tuple) - + exposed_entity_names = list(self._get_entity_name_tuples(exposed=True)) _LOGGER.debug("Exposed entities: %s", exposed_entity_names) # Expose all areas. 
@@ -1099,11 +1108,17 @@ class DefaultAgent(ConversationEntity): floor_names.append((alias, floor.name)) + # Build trie + self._exposed_names_trie = Trie() + name_list = TextSlotList.from_tuples(exposed_entity_names, allow_template=False) + for name_value in name_list.values: + assert isinstance(name_value.text_in, TextChunk) + name_text = name_value.text_in.text.strip().lower() + self._exposed_names_trie.insert(name_text, name_value) + self._slot_lists = { "area": TextSlotList.from_tuples(area_names, allow_template=False), - "name": TextSlotList.from_tuples( - exposed_entity_names, allow_template=False - ), + "name": name_list, "floor": TextSlotList.from_tuples(floor_names, allow_template=False), } diff --git a/homeassistant/components/conversation/manifest.json b/homeassistant/components/conversation/manifest.json index 6c2d70b6a11..b45a5456825 100644 --- a/homeassistant/components/conversation/manifest.json +++ b/homeassistant/components/conversation/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://www.home-assistant.io/integrations/conversation", "integration_type": "system", "quality_scale": "internal", - "requirements": ["hassil==2.0.2", "home-assistant-intents==2024.11.13"] + "requirements": ["hassil==2.0.4", "home-assistant-intents==2024.11.13"] } diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 765fbd89a24..19bfee3c80a 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -32,7 +32,7 @@ go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 habluetooth==3.6.0 hass-nabucasa==0.85.0 -hassil==2.0.2 +hassil==2.0.4 home-assistant-bluetooth==1.13.0 home-assistant-frontend==20241106.2 home-assistant-intents==2024.11.13 diff --git a/requirements_all.txt b/requirements_all.txt index 277b9cc5020..8aeb7509395 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1096,7 +1096,7 @@ hass-nabucasa==0.85.0 hass-splunk==0.1.1 # homeassistant.components.conversation -hassil==2.0.2 
+hassil==2.0.4 # homeassistant.components.jewish_calendar hdate==0.11.1 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index a4d87e43c22..3fb24e4a46e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -931,7 +931,7 @@ habluetooth==3.6.0 hass-nabucasa==0.85.0 # homeassistant.components.conversation -hassil==2.0.2 +hassil==2.0.4 # homeassistant.components.jewish_calendar hdate==0.11.1 diff --git a/script/hassfest/docker/Dockerfile b/script/hassfest/docker/Dockerfile index b75477820ff..e6ab27de9b0 100644 --- a/script/hassfest/docker/Dockerfile +++ b/script/hassfest/docker/Dockerfile @@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.4,source=/uv,target=/bin/uv \ -c /usr/src/homeassistant/homeassistant/package_constraints.txt \ -r /usr/src/homeassistant/requirements.txt \ stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.8.0 \ - PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.2 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2 + PyTurboJPEG==1.7.5 go2rtc-client==0.1.1 ha-ffmpeg==3.2.2 hassil==2.0.4 home-assistant-intents==2024.11.13 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2 LABEL "name"="hassfest" LABEL "maintainer"="Home Assistant " diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 1e5e284a245..6990ffe7717 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -1735,7 +1735,7 @@ async def test_empty_aliases( return_value=None, ) as mock_recognize_all: await conversation.async_converse( - hass, "turn on lights in the kitchen", None, Context(), None + hass, "turn on kitchen light", None, Context(), None ) assert mock_recognize_all.call_count > 0 @@ -2940,3 +2940,76 @@ async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None: result = await agent.async_recognize_intent(user_input) 
assert result is not None assert getattr(result, mark, None) is True + + +@pytest.mark.usefixtures("init_components") +async def test_entities_filtered_by_input(hass: HomeAssistant) -> None: + """Test that entities are filtered by the input text before intent matching.""" + agent = hass.data[DATA_DEFAULT_ENTITY] + assert isinstance(agent, default_agent.DefaultAgent) + + # Only the switch is exposed + hass.states.async_set("light.test_light", "off") + hass.states.async_set( + "light.test_light_2", "off", attributes={ATTR_FRIENDLY_NAME: "test light"} + ) + hass.states.async_set("cover.garage_door", "closed") + hass.states.async_set("switch.test_switch", "off") + expose_entity(hass, "light.test_light", False) + expose_entity(hass, "light.test_light_2", False) + expose_entity(hass, "cover.garage_door", False) + expose_entity(hass, "switch.test_switch", True) + await hass.async_block_till_done() + + # test switch is exposed + user_input = ConversationInput( + text="turn on test switch", + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + + with patch( + "homeassistant.components.conversation.default_agent.recognize_best", + return_value=None, + ) as recognize_best: + await agent.async_recognize_intent(user_input) + + # (1) exposed, (2) unexposed entities + assert len(recognize_best.call_args_list) == 2 + + # Only the test switch should have been considered because its name shows + # up in the input text. 
+ slot_lists = recognize_best.call_args_list[0].kwargs["slot_lists"] + name_list = slot_lists["name"] + assert len(name_list.values) == 1 + assert name_list.values[0].text_in.text == "test switch" + + # test light is not exposed + user_input = ConversationInput( + text="turn on Test Light", # different casing for name + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + + with patch( + "homeassistant.components.conversation.default_agent.recognize_best", + return_value=None, + ) as recognize_best: + await agent.async_recognize_intent(user_input) + + # (1) exposed, (2) unexposed entities + assert len(recognize_best.call_args_list) == 2 + + # Both test lights should have been considered because their name shows + # up in the input text. + slot_lists = recognize_best.call_args_list[1].kwargs["slot_lists"] + name_list = slot_lists["name"] + assert len(name_list.values) == 2 + assert name_list.values[0].text_in.text == "test light" + assert name_list.values[1].text_in.text == "test light"