Ensure intent tools have safe names (#119144)

pull/119160/head
Paulus Schoutsen 2024-06-08 11:53:47 -04:00 committed by GitHub
parent fff2c1115d
commit c49ca5ed56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 2 deletions

View File

@ -5,8 +5,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import cache, partial
from typing import Any from typing import Any
import slugify as unicode_slug
import voluptuous as vol import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
@ -175,10 +177,11 @@ class IntentTool(Tool):
def __init__( def __init__(
self, self,
name: str,
intent_handler: intent.IntentHandler, intent_handler: intent.IntentHandler,
) -> None: ) -> None:
"""Init the class.""" """Init the class."""
self.name = intent_handler.intent_type self.name = name
self.description = ( self.description = (
intent_handler.description or f"Execute Home Assistant {self.name} intent" intent_handler.description or f"Execute Home Assistant {self.name} intent"
) )
@ -261,6 +264,9 @@ class AssistAPI(API):
id=LLM_API_ASSIST, id=LLM_API_ASSIST,
name="Assist", name="Assist",
) )
self.cached_slugify = cache(
partial(unicode_slug.slugify, separator="_", lowercase=False)
)
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API.""" """Return the instance of the API."""
@ -373,7 +379,10 @@ class AssistAPI(API):
or intent_handler.platforms & exposed_domains or intent_handler.platforms & exposed_domains
] ]
return [IntentTool(intent_handler) for intent_handler in intent_handlers] return [
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers
]
def _get_exposed_entities( def _get_exposed_entities(

View File

@ -249,6 +249,39 @@ async def test_assist_api_get_timer_tools(
assert "HassStartTimer" in [tool.name for tool in api.tools] assert "HassStartTimer" in [tool.name for tool in api.tools]
async def test_assist_api_tools(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test getting timer tools with Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
llm_context.device_id = "test_device"
async_register_timer_handler(hass, "test_device", lambda *args: None)
class MyIntentHandler(intent.IntentHandler):
intent_type = "Super crazy intent with unique nåme"
description = "my intent handler"
intent.async_register(hass, MyIntentHandler())
api = await llm.async_get_api(hass, "assist", llm_context)
assert [tool.name for tool in api.tools] == [
"HassTurnOn",
"HassTurnOff",
"HassSetPosition",
"HassStartTimer",
"HassCancelTimer",
"HassIncreaseTimer",
"HassDecreaseTimer",
"HassPauseTimer",
"HassUnpauseTimer",
"HassTimerStatus",
"Super_crazy_intent_with_unique_name",
]
async def test_assist_api_description( async def test_assist_api_description(
hass: HomeAssistant, llm_context: llm.LLMContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None: