Allow unregistering LLM APIs (#135162)
parent
ec37e1ff8d
commit
6e111d18ec
|
@ -85,7 +85,7 @@ def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
|
|||
|
||||
|
||||
@callback
|
||||
def async_register_api(hass: HomeAssistant, api: API) -> None:
|
||||
def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
|
||||
"""Register an API to be exposed to LLMs."""
|
||||
apis = _async_get_apis(hass)
|
||||
|
||||
|
@ -94,6 +94,13 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
|
|||
|
||||
apis[api.id] = api
|
||||
|
||||
@callback
|
||||
def unregister() -> None:
|
||||
"""Unregister the API."""
|
||||
apis.pop(api.id)
|
||||
|
||||
return unregister
|
||||
|
||||
|
||||
async def async_get_api(
|
||||
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
||||
|
|
|
@ -39,6 +39,14 @@ def llm_context() -> llm.LLMContext:
|
|||
)
|
||||
|
||||
|
||||
class MyAPI(llm.API):
|
||||
"""Test API."""
|
||||
|
||||
async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
|
||||
"""Return a list of tools."""
|
||||
return llm.APIInstance(self, "", [], llm_context)
|
||||
|
||||
|
||||
async def test_get_api_no_existing(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
|
@ -50,11 +58,6 @@ async def test_get_api_no_existing(
|
|||
async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||
"""Test registering an llm api."""
|
||||
|
||||
class MyAPI(llm.API):
|
||||
async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
|
||||
"""Return a list of tools."""
|
||||
return llm.APIInstance(self, "", [], llm_context)
|
||||
|
||||
api = MyAPI(hass=hass, id="test", name="Test")
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
|
@ -66,6 +69,59 @@ async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) ->
|
|||
llm.async_register_api(hass, api)
|
||||
|
||||
|
||||
async def test_unregister_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||
"""Test unregistering an llm api."""
|
||||
|
||||
unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
|
||||
assert await llm.async_get_api(hass, "test", llm_context)
|
||||
unreg()
|
||||
with pytest.raises(HomeAssistantError):
|
||||
assert await llm.async_get_api(hass, "test", llm_context)
|
||||
|
||||
|
||||
async def test_reregister_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||
"""Test unregistering an llm api then re-registering with the same id."""
|
||||
|
||||
unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
|
||||
assert await llm.async_get_api(hass, "test", llm_context)
|
||||
unreg()
|
||||
llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
|
||||
assert await llm.async_get_api(hass, "test", llm_context)
|
||||
|
||||
|
||||
async def test_unregister_twice(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test unregistering an llm api twice."""
|
||||
|
||||
unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
|
||||
assert await llm.async_get_api(hass, "test", llm_context)
|
||||
unreg()
|
||||
|
||||
# Unregistering twice is a bug that should not happen
|
||||
with pytest.raises(KeyError):
|
||||
unreg()
|
||||
|
||||
|
||||
async def test_multiple_apis(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||
"""Test registering multiple APIs."""
|
||||
|
||||
unreg1 = llm.async_register_api(hass, MyAPI(hass=hass, id="test-1", name="Test 1"))
|
||||
llm.async_register_api(hass, MyAPI(hass=hass, id="test-2", name="Test 2"))
|
||||
|
||||
# Verify both Apis are registered
|
||||
assert await llm.async_get_api(hass, "test-1", llm_context)
|
||||
assert await llm.async_get_api(hass, "test-2", llm_context)
|
||||
|
||||
# Unregister and verify only one is left
|
||||
unreg1()
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
assert await llm.async_get_api(hass, "test-1", llm_context)
|
||||
|
||||
assert await llm.async_get_api(hass, "test-2", llm_context)
|
||||
|
||||
|
||||
async def test_call_tool_no_existing(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue