Automatically fill in slots based on LLM context (#118619)
* Automatically fill in slots from LLM context * Add tests * Apply suggestions from code review Co-authored-by: Allen Porter <allen@thebends.org> --------- Co-authored-by: Allen Porter <allen@thebends.org>pull/118845/head
parent
b436fe94ae
commit
84f9bb1d63
|
@ -181,14 +181,48 @@ class IntentTool(Tool):
|
||||||
self.description = (
|
self.description = (
|
||||||
intent_handler.description or f"Execute Home Assistant {self.name} intent"
|
intent_handler.description or f"Execute Home Assistant {self.name} intent"
|
||||||
)
|
)
|
||||||
if slot_schema := intent_handler.slot_schema:
|
self.extra_slots = None
|
||||||
self.parameters = vol.Schema(slot_schema)
|
if not (slot_schema := intent_handler.slot_schema):
|
||||||
|
return
|
||||||
|
|
||||||
|
slot_schema = {**slot_schema}
|
||||||
|
extra_slots = set()
|
||||||
|
|
||||||
|
for field in ("preferred_area_id", "preferred_floor_id"):
|
||||||
|
if field in slot_schema:
|
||||||
|
extra_slots.add(field)
|
||||||
|
del slot_schema[field]
|
||||||
|
|
||||||
|
self.parameters = vol.Schema(slot_schema)
|
||||||
|
if extra_slots:
|
||||||
|
self.extra_slots = extra_slots
|
||||||
|
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Handle the intent."""
|
"""Handle the intent."""
|
||||||
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
||||||
|
|
||||||
|
if self.extra_slots and llm_context.device_id:
|
||||||
|
device_reg = dr.async_get(hass)
|
||||||
|
device = device_reg.async_get(llm_context.device_id)
|
||||||
|
|
||||||
|
area: ar.AreaEntry | None = None
|
||||||
|
floor: fr.FloorEntry | None = None
|
||||||
|
if device:
|
||||||
|
area_reg = ar.async_get(hass)
|
||||||
|
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
|
||||||
|
if area.floor_id:
|
||||||
|
floor_reg = fr.async_get(hass)
|
||||||
|
floor = floor_reg.async_get_floor(area.floor_id)
|
||||||
|
|
||||||
|
for slot_name, slot_value in (
|
||||||
|
("preferred_area_id", area.id if area else None),
|
||||||
|
("preferred_floor_id", floor.floor_id if floor else None),
|
||||||
|
):
|
||||||
|
if slot_value and slot_name in self.extra_slots:
|
||||||
|
slots[slot_name] = {"value": slot_value}
|
||||||
|
|
||||||
intent_response = await intent.async_handle(
|
intent_response = await intent.async_handle(
|
||||||
hass=hass,
|
hass=hass,
|
||||||
platform=llm_context.platform,
|
platform=llm_context.platform,
|
||||||
|
|
|
@ -77,7 +77,11 @@ async def test_call_tool_no_existing(
|
||||||
|
|
||||||
|
|
||||||
async def test_assist_api(
|
async def test_assist_api(
|
||||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
hass: HomeAssistant,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
floor_registry: fr.FloorRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test Assist API."""
|
"""Test Assist API."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
@ -97,11 +101,13 @@ async def test_assist_api(
|
||||||
user_prompt="test_text",
|
user_prompt="test_text",
|
||||||
language="*",
|
language="*",
|
||||||
assistant="conversation",
|
assistant="conversation",
|
||||||
device_id="test_device",
|
device_id=None,
|
||||||
)
|
)
|
||||||
schema = {
|
schema = {
|
||||||
vol.Optional("area"): cv.string,
|
vol.Optional("area"): cv.string,
|
||||||
vol.Optional("floor"): cv.string,
|
vol.Optional("floor"): cv.string,
|
||||||
|
vol.Optional("preferred_area_id"): cv.string,
|
||||||
|
vol.Optional("preferred_floor_id"): cv.string,
|
||||||
}
|
}
|
||||||
|
|
||||||
class MyIntentHandler(intent.IntentHandler):
|
class MyIntentHandler(intent.IntentHandler):
|
||||||
|
@ -131,7 +137,13 @@ async def test_assist_api(
|
||||||
tool = api.tools[0]
|
tool = api.tools[0]
|
||||||
assert tool.name == "test_intent"
|
assert tool.name == "test_intent"
|
||||||
assert tool.description == "Execute Home Assistant test_intent intent"
|
assert tool.description == "Execute Home Assistant test_intent intent"
|
||||||
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
|
assert tool.parameters == vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional("area"): cv.string,
|
||||||
|
vol.Optional("floor"): cv.string,
|
||||||
|
# No preferred_area_id, preferred_floor_id
|
||||||
|
}
|
||||||
|
)
|
||||||
assert str(tool) == "<IntentTool - test_intent>"
|
assert str(tool) == "<IntentTool - test_intent>"
|
||||||
|
|
||||||
assert test_context.json_fragment # To reproduce an error case in tracing
|
assert test_context.json_fragment # To reproduce an error case in tracing
|
||||||
|
@ -160,7 +172,52 @@ async def test_assist_api(
|
||||||
context=test_context,
|
context=test_context,
|
||||||
language="*",
|
language="*",
|
||||||
assistant="conversation",
|
assistant="conversation",
|
||||||
device_id="test_device",
|
device_id=None,
|
||||||
|
)
|
||||||
|
assert response == {
|
||||||
|
"data": {
|
||||||
|
"failed": [],
|
||||||
|
"success": [],
|
||||||
|
"targets": [],
|
||||||
|
},
|
||||||
|
"response_type": "action_done",
|
||||||
|
"speech": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Call with a device/area/floor
|
||||||
|
entry = MockConfigEntry(title=None)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "1234")},
|
||||||
|
suggested_area="Test Area",
|
||||||
|
)
|
||||||
|
area = area_registry.async_get_area_by_name("Test Area")
|
||||||
|
floor = floor_registry.async_create("2")
|
||||||
|
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
||||||
|
llm_context.device_id = device.id
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.intent.async_handle", return_value=intent_response
|
||||||
|
) as mock_intent_handle:
|
||||||
|
response = await api.async_call_tool(tool_input)
|
||||||
|
|
||||||
|
mock_intent_handle.assert_awaited_once_with(
|
||||||
|
hass=hass,
|
||||||
|
platform="test_platform",
|
||||||
|
intent_type="test_intent",
|
||||||
|
slots={
|
||||||
|
"area": {"value": "kitchen"},
|
||||||
|
"floor": {"value": "ground_floor"},
|
||||||
|
"preferred_area_id": {"value": area.id},
|
||||||
|
"preferred_floor_id": {"value": floor.floor_id},
|
||||||
|
},
|
||||||
|
text_input="test_text",
|
||||||
|
context=test_context,
|
||||||
|
language="*",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id=device.id,
|
||||||
)
|
)
|
||||||
assert response == {
|
assert response == {
|
||||||
"data": {
|
"data": {
|
||||||
|
|
Loading…
Reference in New Issue