diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index ea376923f9d..cc397c5d428 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -49,9 +49,9 @@ from . import ( ) from .singleton import singleton -SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey( - "llm_script_parameters_cache" -) +ACTION_PARAMETERS_CACHE: HassKey[ + dict[str, dict[str, tuple[str | None, vol.Schema]]] +] = HassKey("llm_action_parameters_cache") LLM_API_ASSIST = "assist" @@ -624,104 +624,105 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901 return {"type": "string"} -def _get_cached_script_parameters( - hass: HomeAssistant, entity_id: str +def _get_cached_action_parameters( + hass: HomeAssistant, domain: str, action: str ) -> tuple[str | None, vol.Schema]: - """Get script description and schema.""" - entity_registry = er.async_get(hass) - + """Get action description and schema.""" description = None parameters = vol.Schema({}) - entity_entry = entity_registry.async_get(entity_id) - if entity_entry and entity_entry.unique_id: - parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE) - if parameters_cache is None: - parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {} + parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE) - @callback - def clear_cache(event: Event) -> None: - """Clear script parameter cache on script reload or delete.""" - if ( - event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN - and event.data[ATTR_SERVICE] in parameters_cache - ): - parameters_cache.pop(event.data[ATTR_SERVICE]) + if parameters_cache is None: + parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {} - cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) + @callback + def clear_cache(event: Event) -> None: + """Clear action parameter cache on action removal.""" + if ( + event.data[ATTR_DOMAIN] in parameters_cache + and event.data[ATTR_SERVICE] + in parameters_cache[event.data[ATTR_DOMAIN]] + ): + parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE]) - @callback - def on_homeassistant_close(event: Event) -> None: - """Cleanup.""" - cancel() + cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close - ) + @callback + def on_homeassistant_close(event: Event) -> None: + """Cleanup.""" + cancel() - if entity_entry.unique_id in parameters_cache: - return parameters_cache[entity_entry.unique_id] + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close) - if service_desc := service.async_get_cached_service_description( - hass, SCRIPT_DOMAIN, entity_entry.unique_id - ): - description = service_desc.get("description") - schema: dict[vol.Marker, Any] = {} - fields = service_desc.get("fields", {}) + if domain in parameters_cache and action in parameters_cache[domain]: + return parameters_cache[domain][action] - for field, config in fields.items(): - field_description = config.get("description") - if not field_description: - field_description = config.get("name") - key: vol.Marker - if config.get("required"): - key = vol.Required(field, description=field_description) - else: - key = vol.Optional(field, description=field_description) - if "selector" in config: - schema[key] = selector.selector(config["selector"]) - else: - schema[key] = cv.string + if action_desc := service.async_get_cached_service_description( + hass, domain, action + ): + description = action_desc.get("description") + schema: dict[vol.Marker, Any] = {} + fields = action_desc.get("fields", {}) - parameters = vol.Schema(schema) + for field, config in fields.items(): + field_description = config.get("description") + if not field_description: + field_description = config.get("name") + key: vol.Marker + if config.get("required"): + key = vol.Required(field, description=field_description) + else: + key = vol.Optional(field, description=field_description) + if "selector" in config: + schema[key] = selector.selector(config["selector"]) + else: + schema[key] = cv.string - aliases: list[str] = [] - if entity_entry.name: - aliases.append(entity_entry.name) - if entity_entry.aliases: - aliases.extend(entity_entry.aliases) - if aliases: - if description: - description = description + ". Aliases: " + str(list(aliases)) - else: - description = "Aliases: " + str(list(aliases)) + parameters = vol.Schema(schema) - parameters_cache[entity_entry.unique_id] = (description, parameters) + if domain == SCRIPT_DOMAIN: + entity_registry = er.async_get(hass) + if ( + entity_id := entity_registry.async_get_entity_id(domain, domain, action) + ) and (entity_entry := entity_registry.async_get(entity_id)): + aliases: list[str] = [] + if entity_entry.name: + aliases.append(entity_entry.name) + if entity_entry.aliases: + aliases.extend(entity_entry.aliases) + if aliases: + if description: + description = description + ". Aliases: " + str(list(aliases)) + else: + description = "Aliases: " + str(list(aliases)) + + parameters_cache.setdefault(domain, {})[action] = (description, parameters) return description, parameters -class ScriptTool(Tool): - """LLM Tool representing a Script.""" +class ActionTool(Tool): + """LLM Tool representing an action.""" def __init__( self, hass: HomeAssistant, - script_entity_id: str, + domain: str, + action: str, ) -> None: """Init the class.""" - self._object_id = self.name = split_entity_id(script_entity_id)[1] - if self.name[0].isdigit(): - self.name = "_" + self.name - - self.description, self.parameters = _get_cached_script_parameters( - hass, script_entity_id + self._domain = domain + self._action = action + self.name = f"{domain}.{action}" + self.description, self.parameters = _get_cached_action_parameters( + hass, domain, action ) async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: - """Run the script.""" + """Call the action.""" for field, validator in self.parameters.schema.items(): if field not in tool_input.tool_args: @@ -753,8 +754,8 @@ class ScriptTool(Tool): tool_input.tool_args[field] = floor result = await hass.services.async_call( - SCRIPT_DOMAIN, - self._object_id, + self._domain, + self._action, tool_input.tool_args, context=llm_context.context, blocking=True, @@ -764,6 +765,30 @@ class ScriptTool(Tool): return {"success": True, "result": result} +class ScriptTool(ActionTool): + """LLM Tool representing a Script.""" + + def __init__( + self, + hass: HomeAssistant, + script_entity_id: str, + ) -> None: + """Init the class.""" + script_name = split_entity_id(script_entity_id)[1] + + action = script_name + entity_registry = er.async_get(hass) + entity_entry = entity_registry.async_get(script_entity_id) + if entity_entry and entity_entry.unique_id: + action = entity_entry.unique_id + + super().__init__(hass, SCRIPT_DOMAIN, action) + + self.name = script_name + if self.name[0].isdigit(): + self.name = "_" + self.name + + class CalendarGetEventsTool(Tool): """LLM Tool allowing querying a calendar.""" diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 57e151ba8eb..e288026b67b 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -745,7 +745,7 @@ async def test_script_tool( area = area_registry.async_create("Living room") floor = floor_registry.async_create("2") - assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data + assert llm.ACTION_PARAMETERS_CACHE not in hass.data api = await llm.async_get_api(hass, "assist", llm_context) @@ -769,7 +769,7 @@ async def test_script_tool( } assert tool.parameters.schema == schema - assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { + assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == { "test_script": ( "This is a test script. Aliases: ['script name', 'script alias']", vol.Schema(schema), @@ -866,7 +866,7 @@ async def test_script_tool( ): await hass.services.async_call("script", "reload", blocking=True) - assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {} + assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {} api = await llm.async_get_api(hass, "assist", llm_context) @@ -882,7 +882,7 @@ async def test_script_tool( schema = {vol.Required("beer", description="Number of beers"): cv.string} assert tool.parameters.schema == schema - assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { + assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == { "test_script": ( "This is a new test script. Aliases: ['script name', 'script alias']", vol.Schema(schema),