Remove entity state from mcp-server prompt (#137126)
* Create a stateless assist API for MCP server * Update stateless API * Fix areas in exposed entity fields * Add tests that verify areas are returned * Revert the getstate intent * Revert whitespace change * Revert whitespace change * Revert method name changes to avoid breaking openai and google testspull/136549/head
parent
2c99e3778e
commit
bf6f790d09
|
@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import http
|
||||
from . import http, llm_api
|
||||
from .const import DOMAIN
|
||||
from .session import SessionManager
|
||||
from .types import MCPServerConfigEntry
|
||||
|
@ -25,6 +25,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Model Context Protocol component."""
|
||||
http.async_register(hass)
|
||||
llm_api.async_register_api(hass)
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
|
|||
SelectSelectorConfig,
|
||||
)
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, LLM_API, LLM_API_NAME
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,6 +33,12 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
||||
if LLM_API not in llm_apis:
|
||||
# MCP server component is not loaded yet, so make the LLM API a choice.
|
||||
llm_apis = {
|
||||
LLM_API: LLM_API_NAME,
|
||||
**llm_apis,
|
||||
}
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
|
|
|
@ -2,3 +2,5 @@
|
|||
|
||||
DOMAIN = "mcp_server"
|
||||
TITLE = "Model Context Protocol Server"
|
||||
LLM_API = "stateless_assist"
|
||||
LLM_API_NAME = "Stateless Assist"
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
"""LLM API for MCP Server."""
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.util import yaml as yaml_util
|
||||
|
||||
from .const import LLM_API, LLM_API_NAME
|
||||
|
||||
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
||||
|
||||
|
||||
def async_register_api(hass: HomeAssistant) -> None:
|
||||
"""Register the LLM API."""
|
||||
llm.async_register_api(hass, StatelessAssistAPI(hass))
|
||||
|
||||
|
||||
class StatelessAssistAPI(llm.AssistAPI):
|
||||
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
||||
|
||||
Syncing the state information is possible, but may put unnecessary load on
|
||||
the system so we are instead providing the prompt without entity state. Since
|
||||
actions don't care about the current state, there is little quality loss.
|
||||
"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the StatelessAssistAPI."""
|
||||
super().__init__(hass)
|
||||
self.id = LLM_API
|
||||
self.name = LLM_API_NAME
|
||||
|
||||
@callback
|
||||
def _async_get_exposed_entities_prompt(
|
||||
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
||||
) -> list[str]:
|
||||
"""Return the prompt for the exposed entities."""
|
||||
prompt = []
|
||||
|
||||
if exposed_entities:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
entities = [
|
||||
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
||||
for entity_info in exposed_entities.values()
|
||||
]
|
||||
prompt.append(yaml_util.dump(list(entities)))
|
||||
|
||||
return prompt
|
|
@ -326,12 +326,21 @@ class AssistAPI(API):
|
|||
def _async_get_api_prompt(
|
||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||
) -> str:
|
||||
"""Return the prompt for the API."""
|
||||
if not exposed_entities:
|
||||
return (
|
||||
"Only if the user wants to control a device, tell them to expose entities "
|
||||
"to their voice assistant in Home Assistant."
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
*self._async_get_preable(llm_context),
|
||||
*self._async_get_exposed_entities_prompt(llm_context, exposed_entities),
|
||||
]
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_get_preable(self, llm_context: LLMContext) -> list[str]:
|
||||
"""Return the prompt for the API."""
|
||||
|
||||
prompt = [
|
||||
(
|
||||
|
@ -371,13 +380,22 @@ class AssistAPI(API):
|
|||
):
|
||||
prompt.append("This device is not able to start timers.")
|
||||
|
||||
return prompt
|
||||
|
||||
@callback
|
||||
def _async_get_exposed_entities_prompt(
|
||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||
) -> list[str]:
|
||||
"""Return the prompt for the API for exposed entities."""
|
||||
prompt = []
|
||||
|
||||
if exposed_entities:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
prompt.append(yaml_util.dump(list(exposed_entities.values())))
|
||||
|
||||
return "\n".join(prompt)
|
||||
return prompt
|
||||
|
||||
@callback
|
||||
def _async_get_tools(
|
||||
|
|
|
@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.mcp_server.const import DOMAIN
|
||||
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -28,7 +27,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_LLM_HASS_API: LLM_API,
|
||||
},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
|
|
@ -20,7 +20,11 @@ from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
|||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, setup_test_component_platform
|
||||
|
@ -45,6 +49,11 @@ INITIALIZE_MESSAGE = {
|
|||
}
|
||||
EVENT_PREFIX = "event: "
|
||||
DATA_PREFIX = "data: "
|
||||
EXPECTED_PROMPT_SUFFIX = """
|
||||
- names: Kitchen Light
|
||||
domain: light
|
||||
areas: Kitchen
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -59,11 +68,13 @@ async def mock_entities(
|
|||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
area_registry: ar.AreaRegistry,
|
||||
setup_integration: None,
|
||||
) -> None:
|
||||
"""Fixture to expose entities to the conversation agent."""
|
||||
entity = MockLight("kitchen", STATE_OFF)
|
||||
entity = MockLight("Kitchen Light", STATE_OFF)
|
||||
entity.entity_id = TEST_ENTITY
|
||||
entity.unique_id = "test-light-unique-id"
|
||||
setup_test_component_platform(hass, LIGHT_DOMAIN, [entity])
|
||||
|
||||
assert await async_setup_component(
|
||||
|
@ -71,6 +82,9 @@ async def mock_entities(
|
|||
LIGHT_DOMAIN,
|
||||
{LIGHT_DOMAIN: [{"platform": "test"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
kitchen = area_registry.async_get_or_create("Kitchen")
|
||||
entity_registry.async_update_entity(TEST_ENTITY, area_id=kitchen.id)
|
||||
|
||||
async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True)
|
||||
|
||||
|
@ -320,7 +334,7 @@ async def test_mcp_tool_call(
|
|||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.call_tool(
|
||||
name="HassTurnOn",
|
||||
arguments={"name": "kitchen"},
|
||||
arguments={"name": "kitchen light"},
|
||||
)
|
||||
|
||||
assert not result.isError
|
||||
|
@ -370,8 +384,11 @@ async def test_prompt_list(
|
|||
|
||||
assert len(result.prompts) == 1
|
||||
prompt = result.prompts[0]
|
||||
assert prompt.name == "Assist"
|
||||
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"
|
||||
assert prompt.name == "Stateless Assist"
|
||||
assert (
|
||||
prompt.description
|
||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||
)
|
||||
|
||||
|
||||
async def test_prompt_get(
|
||||
|
@ -383,13 +400,17 @@ async def test_prompt_get(
|
|||
"""Test the get prompt endpoint."""
|
||||
|
||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.get_prompt(name="Assist")
|
||||
result = await session.get_prompt(name="Stateless Assist")
|
||||
|
||||
assert result.description == "Default prompt for the Home Assistant LLM API Assist"
|
||||
assert (
|
||||
result.description
|
||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||
)
|
||||
assert len(result.messages) == 1
|
||||
assert result.messages[0].role == "assistant"
|
||||
assert result.messages[0].content.type == "text"
|
||||
assert "When controlling Home Assistant" in result.messages[0].content.text
|
||||
assert result.messages[0].content.text.endswith(EXPECTED_PROMPT_SUFFIX)
|
||||
|
||||
|
||||
async def test_get_unknwon_prompt(
|
||||
|
|
Loading…
Reference in New Issue