From 84f9bb1d639963615f0f85fc60cc5684dd6612c1 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Jun 2024 10:36:41 -0400 Subject: [PATCH] Automatically fill in slots based on LLM context (#118619) * Automatically fill in slots from LLM context * Add tests * Apply suggestions from code review Co-authored-by: Allen Porter --------- Co-authored-by: Allen Porter --- homeassistant/helpers/llm.py | 38 +++++++++++++++++++-- tests/helpers/test_llm.py | 65 +++++++++++++++++++++++++++++++++--- 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index ec1bfb7dbc4..37233b0d407 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -181,14 +181,48 @@ class IntentTool(Tool): self.description = ( intent_handler.description or f"Execute Home Assistant {self.name} intent" ) - if slot_schema := intent_handler.slot_schema: - self.parameters = vol.Schema(slot_schema) + self.extra_slots = None + if not (slot_schema := intent_handler.slot_schema): + return + + slot_schema = {**slot_schema} + extra_slots = set() + + for field in ("preferred_area_id", "preferred_floor_id"): + if field in slot_schema: + extra_slots.add(field) + del slot_schema[field] + + self.parameters = vol.Schema(slot_schema) + if extra_slots: + self.extra_slots = extra_slots async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Handle the intent.""" slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} + + if self.extra_slots and llm_context.device_id: + device_reg = dr.async_get(hass) + device = device_reg.async_get(llm_context.device_id) + + area: ar.AreaEntry | None = None + floor: fr.FloorEntry | None = None + if device: + area_reg = ar.async_get(hass) + if device.area_id and (area := area_reg.async_get_area(device.area_id)): + if area.floor_id: + floor_reg = fr.async_get(hass) + floor = floor_reg.async_get_floor(area.floor_id) + + for slot_name, slot_value in ( + ("preferred_area_id", area.id if area else None), + ("preferred_floor_id", floor.floor_id if floor else None), + ): + if slot_value and slot_name in self.extra_slots: + slots[slot_name] = {"value": slot_value} + intent_response = await intent.async_handle( hass=hass, platform=llm_context.platform, diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 9ad58441277..6c9451bc843 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -77,7 +77,11 @@ async def test_call_tool_no_existing( async def test_assist_api( - hass: HomeAssistant, entity_registry: er.EntityRegistry + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + device_registry: dr.DeviceRegistry, + area_registry: ar.AreaRegistry, + floor_registry: fr.FloorRegistry, ) -> None: """Test Assist API.""" assert await async_setup_component(hass, "homeassistant", {}) @@ -97,11 +101,13 @@ async def test_assist_api( user_prompt="test_text", language="*", assistant="conversation", - device_id="test_device", + device_id=None, ) schema = { vol.Optional("area"): cv.string, vol.Optional("floor"): cv.string, + vol.Optional("preferred_area_id"): cv.string, + vol.Optional("preferred_floor_id"): cv.string, } class MyIntentHandler(intent.IntentHandler): @@ -131,7 +137,13 @@ async def test_assist_api( tool = api.tools[0] assert tool.name == "test_intent" assert tool.description == "Execute Home Assistant test_intent intent" - assert tool.parameters == vol.Schema(intent_handler.slot_schema) + assert tool.parameters == vol.Schema( + { + vol.Optional("area"): cv.string, + vol.Optional("floor"): cv.string, + # No preferred_area_id, preferred_floor_id + } + ) assert str(tool) == "" assert test_context.json_fragment # To reproduce an error case in tracing @@ -160,7 +172,52 @@ async def test_assist_api( context=test_context, language="*", assistant="conversation", - device_id="test_device", + device_id=None, + ) + assert response == { + "data": { + "failed": [], + "success": [], + "targets": [], + }, + "response_type": "action_done", + "speech": {}, + } + + # Call with a device/area/floor + entry = MockConfigEntry(title=None) + entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "1234")}, + suggested_area="Test Area", + ) + area = area_registry.async_get_area_by_name("Test Area") + floor = floor_registry.async_create("2") + area_registry.async_update(area.id, floor_id=floor.floor_id) + llm_context.device_id = device.id + + with patch( + "homeassistant.helpers.intent.async_handle", return_value=intent_response + ) as mock_intent_handle: + response = await api.async_call_tool(tool_input) + + mock_intent_handle.assert_awaited_once_with( + hass=hass, + platform="test_platform", + intent_type="test_intent", + slots={ + "area": {"value": "kitchen"}, + "floor": {"value": "ground_floor"}, + "preferred_area_id": {"value": area.id}, + "preferred_floor_id": {"value": floor.floor_id}, + }, + text_input="test_text", + context=test_context, + language="*", + assistant="conversation", + device_id=device.id, ) assert response == { "data": {