Ensure script llm tool name does not start with a digit (#122349)
* Ensure script tool name does not start with a digit * Fix test namepull/122388/head
parent
0c6dc9e43b
commit
064d7261b4
|
@ -617,6 +617,9 @@ class ScriptTool(Tool):
|
|||
entity_registry = er.async_get(hass)
|
||||
|
||||
self.name = split_entity_id(script_entity_id)[1]
|
||||
if self.name[0].isdigit():
|
||||
self.name = "_" + self.name
|
||||
self._entity_id = script_entity_id
|
||||
self.parameters = vol.Schema({})
|
||||
entity_entry = entity_registry.async_get(script_entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
|
@ -717,7 +720,7 @@ class ScriptTool(Tool):
|
|||
SCRIPT_DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{
|
||||
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name,
|
||||
ATTR_ENTITY_ID: self._entity_id,
|
||||
ATTR_VARIABLES: tool_input.tool_args,
|
||||
},
|
||||
context=llm_context.context,
|
||||
|
|
|
@ -780,6 +780,46 @@ async def test_script_tool(
|
|||
}
|
||||
|
||||
|
||||
async def test_script_tool_name(hass: HomeAssistant) -> None:
|
||||
"""Test that script tool name is not started with a digit."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
|
||||
# Create a script with a unique ID
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"123456": {
|
||||
"description": "This is a test script",
|
||||
"sequence": [],
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "script.123456", True)
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "_123456"
|
||||
|
||||
|
||||
async def test_selector_serializer(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue