Remove entity state from mcp-server prompt (#137126)

* Create a stateless assist API for MCP server

* Update stateless API

* Fix areas in exposed entity fields

* Add tests that verify areas are returned

* Revert the getstate intent

* Revert whitespace change

* Revert whitespace change

* Revert method name changes to avoid breaking openai and google tests
pull/136549/head
Allen Porter 2025-02-01 14:26:52 -08:00 committed by GitHub
parent 2c99e3778e
commit bf6f790d09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 109 additions and 14 deletions

View File

@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from . import http
from . import http, llm_api
from .const import DOMAIN
from .session import SessionManager
from .types import MCPServerConfigEntry
@ -25,6 +25,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the Model Context Protocol component.

    Runs the registration hooks of the ``http`` and ``llm_api`` modules
    against ``hass`` (in that order) and reports success.
    """
    for register in (http.async_register, llm_api.async_register_api):
        register(hass)
    return True

View File

@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig,
)
from .const import DOMAIN
from .const import DOMAIN, LLM_API, LLM_API_NAME
_LOGGER = logging.getLogger(__name__)
@ -33,6 +33,12 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle the initial step."""
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
if LLM_API not in llm_apis:
# MCP server component is not loaded yet, so make the LLM API a choice.
llm_apis = {
LLM_API: LLM_API_NAME,
**llm_apis,
}
if user_input is not None:
return self.async_create_entry(

View File

@ -2,3 +2,5 @@
# Domain string of this integration.
DOMAIN = "mcp_server"
TITLE = "Model Context Protocol Server"
# Identifier and human-readable name assigned to the stateless Assist LLM API
# (used as the API's ``id`` and ``name``).
LLM_API = "stateless_assist"
LLM_API_NAME = "Stateless Assist"

View File

@ -0,0 +1,48 @@
"""LLM API for MCP Server."""
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm
from homeassistant.util import yaml as yaml_util
from .const import LLM_API, LLM_API_NAME
# Keys kept from each exposed entity's info dict when building the prompt;
# any other key (such as entity state) is dropped.
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
def async_register_api(hass: HomeAssistant) -> None:
    """Create the stateless Assist API and register it with the LLM helper."""
    stateless_api = StatelessAssistAPI(hass)
    llm.async_register_api(hass, stateless_api)
class StatelessAssistAPI(llm.AssistAPI):
    """Assist LLM API for the MCP server whose prompt carries no entity state.

    Keeping entity state in sync with the client would be possible, but it
    may load the system unnecessarily, so the prompt lists entities without
    their state. Actions do not depend on the current state, so little
    quality is lost.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the API and replace the inherited id and name."""
        super().__init__(hass)
        self.id = LLM_API
        self.name = LLM_API_NAME

    @callback
    def _async_get_exposed_entities_prompt(
        self, llm_context: llm.LLMContext, exposed_entities: dict | None
    ) -> list[str]:
        """Return the prompt lines for the exposed entities.

        Each entity is reduced to the keys in ``EXPOSED_ENTITY_FIELDS``
        before being dumped as YAML. Returns an empty list when nothing is
        exposed.
        """
        if not exposed_entities:
            return []

        # Lazily strip every entity down to the allowed fields; the list is
        # materialized once, right before serialization.
        trimmed = (
            {
                field: value
                for field, value in info.items()
                if field in EXPOSED_ENTITY_FIELDS
            }
            for info in exposed_entities.values()
        )
        return [
            "An overview of the areas and the devices in this smart home:",
            yaml_util.dump(list(trimmed)),
        ]

View File

@ -326,12 +326,21 @@ class AssistAPI(API):
def _async_get_api_prompt(
    self, llm_context: LLMContext, exposed_entities: dict | None
) -> str:
    """Build the prompt text for the API.

    With exposed entities, the preamble lines and the exposed-entity lines
    are concatenated and joined with newlines. Without any, a short hint
    about exposing entities is returned instead.
    """
    if exposed_entities:
        sections: list[str] = []
        sections.extend(self._async_get_preable(llm_context))
        sections.extend(
            self._async_get_exposed_entities_prompt(llm_context, exposed_entities)
        )
        return "\n".join(sections)

    return (
        "Only if the user wants to control a device, tell them to expose entities "
        "to their voice assistant in Home Assistant."
    )
@callback
def _async_get_preable(self, llm_context: LLMContext) -> list[str]:
"""Return the prompt for the API."""
prompt = [
(
@ -371,13 +380,22 @@ class AssistAPI(API):
):
prompt.append("This device is not able to start timers.")
return prompt
@callback
def _async_get_exposed_entities_prompt(
    self, llm_context: LLMContext, exposed_entities: dict | None
) -> list[str]:
    """Return the prompt lines describing the exposed entities.

    Returns an empty list when ``exposed_entities`` is ``None`` or empty;
    otherwise a heading line followed by a YAML dump of the entity values.
    """
    prompt: list[str] = []

    if exposed_entities:
        prompt.append(
            "An overview of the areas and the devices in this smart home:"
        )
        prompt.append(yaml_util.dump(list(exposed_entities.values())))

    # Return the separate lines, not a joined string: the caller unpacks this
    # result into a list and joins it, so a str would be split per character.
    return prompt
@callback
def _async_get_tools(

View File

@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.mcp_server.const import DOMAIN
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from tests.common import MockConfigEntry
@ -28,7 +27,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
config_entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_LLM_HASS_API: LLM_API,
},
)
config_entry.add_to_hass(hass)

View File

@ -20,7 +20,11 @@ from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
)
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, setup_test_component_platform
@ -45,6 +49,11 @@ INITIALIZE_MESSAGE = {
}
EVENT_PREFIX = "event: "
DATA_PREFIX = "data: "
# Expected tail of the prompt text; test_prompt_get asserts that the returned
# prompt ends with this YAML entity listing.
EXPECTED_PROMPT_SUFFIX = """
- names: Kitchen Light
domain: light
areas: Kitchen
"""
@pytest.fixture
@ -59,11 +68,13 @@ async def mock_entities(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry,
setup_integration: None,
) -> None:
"""Fixture to expose entities to the conversation agent."""
entity = MockLight("kitchen", STATE_OFF)
entity = MockLight("Kitchen Light", STATE_OFF)
entity.entity_id = TEST_ENTITY
entity.unique_id = "test-light-unique-id"
setup_test_component_platform(hass, LIGHT_DOMAIN, [entity])
assert await async_setup_component(
@ -71,6 +82,9 @@ async def mock_entities(
LIGHT_DOMAIN,
{LIGHT_DOMAIN: [{"platform": "test"}]},
)
await hass.async_block_till_done()
kitchen = area_registry.async_get_or_create("Kitchen")
entity_registry.async_update_entity(TEST_ENTITY, area_id=kitchen.id)
async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True)
@ -320,7 +334,7 @@ async def test_mcp_tool_call(
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.call_tool(
name="HassTurnOn",
arguments={"name": "kitchen"},
arguments={"name": "kitchen light"},
)
assert not result.isError
@ -370,8 +384,11 @@ async def test_prompt_list(
assert len(result.prompts) == 1
prompt = result.prompts[0]
assert prompt.name == "Assist"
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"
assert prompt.name == "Stateless Assist"
assert (
prompt.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
async def test_prompt_get(
@ -383,13 +400,17 @@ async def test_prompt_get(
"""Test the get prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.get_prompt(name="Assist")
result = await session.get_prompt(name="Stateless Assist")
assert result.description == "Default prompt for the Home Assistant LLM API Assist"
assert (
result.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert result.messages[0].content.type == "text"
assert "When controlling Home Assistant" in result.messages[0].content.text
assert result.messages[0].content.text.endswith(EXPECTED_PROMPT_SUFFIX)
async def test_get_unknwon_prompt(