Add prompts to MCP server (#134619)

* Add prompts to MCP server

* Improve test coverage for get prompt error cases
pull/134697/head
Allen Porter 2025-01-04 09:35:05 -08:00 committed by GitHub
parent c9a607aa45
commit bb97a16756
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 0 deletions

View File

@ -50,6 +50,37 @@ async def create_server(
server = Server("home-assistant")
@server.list_prompts() # type: ignore[no-untyped-call, misc]
async def handle_list_prompts() -> list[types.Prompt]:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
return [
types.Prompt(
name=llm_api.api.name,
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
)
]
@server.get_prompt() # type: ignore[no-untyped-call, misc]
async def handle_get_prompt(
name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
if name != llm_api.api.name:
raise ValueError(f"Unknown prompt: {name}")
return types.GetPromptResult(
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
messages=[
types.PromptMessage(
role="assistant",
content=types.TextContent(
type="text",
text=llm_api.api_prompt,
),
)
],
)
@server.list_tools() # type: ignore[no-untyped-call, misc]
async def list_tools() -> list[types.Tool]:
"""List available time tools."""

View File

@ -10,6 +10,7 @@ import aiohttp
import mcp
import mcp.client.session
import mcp.client.sse
from mcp.shared.exceptions import McpError
import pytest
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
@ -354,3 +355,51 @@ async def test_mcp_tool_call_failed(
assert len(result.content) == 1
assert result.content[0].type == "text"
assert "Error calling tool" in result.content[0].text
async def test_prompt_list(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the list prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.list_prompts()
assert len(result.prompts) == 1
prompt = result.prompts[0]
assert prompt.name == "Assist"
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"
async def test_prompt_get(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the get prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.get_prompt(name="Assist")
assert result.description == "Default prompt for the Home Assistant LLM API Assist"
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert result.messages[0].content.type == "text"
assert "When controlling Home Assistant" in result.messages[0].content.text
async def test_get_unknwon_prompt(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the get prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
with pytest.raises(McpError):
await session.get_prompt(name="Unknown")