Ensure script llm tool name does not start with a digit (#122349)

* Ensure script tool name does not start with a digit

* Fix test name
pull/122388/head
Denis Shulyaka 2024-07-22 12:11:09 +03:00 committed by GitHub
parent 0c6dc9e43b
commit 064d7261b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 1 deletions

View File

@ -617,6 +617,9 @@ class ScriptTool(Tool):
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
self.name = split_entity_id(script_entity_id)[1] self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit():
self.name = "_" + self.name
self._entity_id = script_entity_id
self.parameters = vol.Schema({}) self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id) entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id: if entity_entry and entity_entry.unique_id:
@ -717,7 +720,7 @@ class ScriptTool(Tool):
SCRIPT_DOMAIN, SCRIPT_DOMAIN,
SERVICE_TURN_ON, SERVICE_TURN_ON,
{ {
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name, ATTR_ENTITY_ID: self._entity_id,
ATTR_VARIABLES: tool_input.tool_args, ATTR_VARIABLES: tool_input.tool_args,
}, },
context=llm_context.context, context=llm_context.context,

View File

@ -780,6 +780,46 @@ async def test_script_tool(
} }
async def test_script_tool_name(hass: HomeAssistant) -> None:
"""Test that script tool name is not started with a digit."""
assert await async_setup_component(hass, "homeassistant", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
# Create a script with a unique ID
assert await async_setup_component(
hass,
"script",
{
"script": {
"123456": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers", "required": True},
},
},
}
},
)
async_expose_entity(hass, "conversation", "script.123456", True)
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "_123456"
async def test_selector_serializer( async def test_selector_serializer(
hass: HomeAssistant, llm_context: llm.LLMContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None: