diff --git a/homeassistant/components/mcp_server/__init__.py b/homeassistant/components/mcp_server/__init__.py index e523f46228f..941eccbe528 100644 --- a/homeassistant/components/mcp_server/__init__.py +++ b/homeassistant/components/mcp_server/__init__.py @@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from . import http +from . import http, llm_api from .const import DOMAIN from .session import SessionManager from .types import MCPServerConfigEntry @@ -25,6 +25,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Model Context Protocol component.""" http.async_register(hass) + llm_api.async_register_api(hass) return True diff --git a/homeassistant/components/mcp_server/config_flow.py b/homeassistant/components/mcp_server/config_flow.py index 8d68c6a868a..8d8d311b874 100644 --- a/homeassistant/components/mcp_server/config_flow.py +++ b/homeassistant/components/mcp_server/config_flow.py @@ -16,7 +16,7 @@ from homeassistant.helpers.selector import ( SelectSelectorConfig, ) -from .const import DOMAIN +from .const import DOMAIN, LLM_API, LLM_API_NAME _LOGGER = logging.getLogger(__name__) @@ -33,6 +33,12 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Handle the initial step.""" llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)} + if LLM_API not in llm_apis: + # MCP server component is not loaded yet, so make the LLM API a choice. + llm_apis = { + LLM_API: LLM_API_NAME, + **llm_apis, + } if user_input is not None: return self.async_create_entry( diff --git a/homeassistant/components/mcp_server/const.py b/homeassistant/components/mcp_server/const.py index 1aa81f445a1..8958ac36616 100644 --- a/homeassistant/components/mcp_server/const.py +++ b/homeassistant/components/mcp_server/const.py @@ -2,3 +2,5 @@ DOMAIN = "mcp_server" TITLE = "Model Context Protocol Server" +LLM_API = "stateless_assist" +LLM_API_NAME = "Stateless Assist" diff --git a/homeassistant/components/mcp_server/llm_api.py b/homeassistant/components/mcp_server/llm_api.py new file mode 100644 index 00000000000..f4292744815 --- /dev/null +++ b/homeassistant/components/mcp_server/llm_api.py @@ -0,0 +1,48 @@ +"""LLM API for MCP Server.""" + +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import llm +from homeassistant.util import yaml as yaml_util + +from .const import LLM_API, LLM_API_NAME + +EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"} + + +def async_register_api(hass: HomeAssistant) -> None: + """Register the LLM API.""" + llm.async_register_api(hass, StatelessAssistAPI(hass)) + + +class StatelessAssistAPI(llm.AssistAPI): + """LLM API for MCP Server that provides the Assist API without state information in the prompt. + + Syncing the state information is possible, but may put unnecessary load on + the system so we are instead providing the prompt without entity state. Since + actions don't care about the current state, there is little quality loss. + """ + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the StatelessAssistAPI.""" + super().__init__(hass) + self.id = LLM_API + self.name = LLM_API_NAME + + @callback + def _async_get_exposed_entities_prompt( + self, llm_context: llm.LLMContext, exposed_entities: dict | None + ) -> list[str]: + """Return the prompt for the exposed entities.""" + prompt = [] + + if exposed_entities: + prompt.append( + "An overview of the areas and the devices in this smart home:" + ) + entities = [ + {k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS} + for entity_info in exposed_entities.values() + ] + prompt.append(yaml_util.dump(list(entities))) + + return prompt diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index cc397c5d428..2bca4c8528b 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -326,12 +326,21 @@ class AssistAPI(API): def _async_get_api_prompt( self, llm_context: LLMContext, exposed_entities: dict | None ) -> str: - """Return the prompt for the API.""" if not exposed_entities: return ( "Only if the user wants to control a device, tell them to expose entities " "to their voice assistant in Home Assistant." ) + return "\n".join( + [ + *self._async_get_preable(llm_context), + *self._async_get_exposed_entities_prompt(llm_context, exposed_entities), + ] + ) + + @callback + def _async_get_preable(self, llm_context: LLMContext) -> list[str]: + """Return the prompt for the API.""" prompt = [ ( @@ -371,13 +380,22 @@ class AssistAPI(API): ): prompt.append("This device is not able to start timers.") + return prompt + + @callback + def _async_get_exposed_entities_prompt( + self, llm_context: LLMContext, exposed_entities: dict | None + ) -> list[str]: + """Return the prompt for the API for exposed entities.""" + prompt = [] + if exposed_entities: prompt.append( "An overview of the areas and the devices in this smart home:" ) prompt.append(yaml_util.dump(list(exposed_entities.values()))) - return "\n".join(prompt) + return prompt @callback def _async_get_tools( diff --git a/tests/components/mcp_server/conftest.py b/tests/components/mcp_server/conftest.py index 149073f3645..5ec67fb6ce3 100644 --- a/tests/components/mcp_server/conftest.py +++ b/tests/components/mcp_server/conftest.py @@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch import pytest -from homeassistant.components.mcp_server.const import DOMAIN +from homeassistant.components.mcp_server.const import DOMAIN, LLM_API from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant -from homeassistant.helpers import llm from tests.common import MockConfigEntry @@ -28,7 +27,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: config_entry = MockConfigEntry( domain=DOMAIN, data={ - CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + CONF_LLM_HASS_API: LLM_API, }, ) config_entry.add_to_hass(hass) diff --git a/tests/components/mcp_server/test_http.py b/tests/components/mcp_server/test_http.py index a71bf42acc8..905bfaa11d7 100644 --- a/tests/components/mcp_server/test_http.py +++ b/tests/components/mcp_server/test_http.py @@ -20,7 +20,11 @@ from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, +) from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry, setup_test_component_platform @@ -45,6 +49,11 @@ INITIALIZE_MESSAGE = { } EVENT_PREFIX = "event: " DATA_PREFIX = "data: " +EXPECTED_PROMPT_SUFFIX = """ +- names: Kitchen Light + domain: light + areas: Kitchen +""" @pytest.fixture @@ -59,11 +68,13 @@ async def mock_entities( hass: HomeAssistant, device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, + area_registry: ar.AreaRegistry, setup_integration: None, ) -> None: """Fixture to expose entities to the conversation agent.""" - entity = MockLight("kitchen", STATE_OFF) + entity = MockLight("Kitchen Light", STATE_OFF) entity.entity_id = TEST_ENTITY + entity.unique_id = "test-light-unique-id" setup_test_component_platform(hass, LIGHT_DOMAIN, [entity]) assert await async_setup_component( @@ -71,6 +82,9 @@ async def mock_entities( LIGHT_DOMAIN, {LIGHT_DOMAIN: [{"platform": "test"}]}, ) + await hass.async_block_till_done() + kitchen = area_registry.async_get_or_create("Kitchen") + entity_registry.async_update_entity(TEST_ENTITY, area_id=kitchen.id) async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True) @@ -320,7 +334,7 @@ async def test_mcp_tool_call( async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: result = await session.call_tool( name="HassTurnOn", - arguments={"name": "kitchen"}, + arguments={"name": "kitchen light"}, ) assert not result.isError @@ -370,8 +384,11 @@ async def test_prompt_list( assert len(result.prompts) == 1 prompt = result.prompts[0] - assert prompt.name == "Assist" - assert prompt.description == "Default prompt for the Home Assistant LLM API Assist" + assert prompt.name == "Stateless Assist" + assert ( + prompt.description + == "Default prompt for the Home Assistant LLM API Stateless Assist" + ) async def test_prompt_get( @@ -383,13 +400,17 @@ async def test_prompt_get( """Test the get prompt endpoint.""" async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: - result = await session.get_prompt(name="Assist") + result = await session.get_prompt(name="Stateless Assist") - assert result.description == "Default prompt for the Home Assistant LLM API Assist" + assert ( + result.description + == "Default prompt for the Home Assistant LLM API Stateless Assist" + ) assert len(result.messages) == 1 assert result.messages[0].role == "assistant" assert result.messages[0].content.type == "text" assert "When controlling Home Assistant" in result.messages[0].content.text + assert result.messages[0].content.text.endswith(EXPECTED_PROMPT_SUFFIX) async def test_get_unknwon_prompt(