Add prompts to MCP server (#134619)
* Add prompts to MCP server
* Improve test coverage for get prompt error cases

pull/134697/head
parent c9a607aa45
commit bb97a16756
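For orientation (not part of the commit): the new prompt endpoints can be exercised from any MCP client over the integration's SSE transport, much like the tests below do through their `mcp_session` helper. Below is a minimal sketch using the `mcp` Python SDK; the URL, token value, and bearer-header handling are placeholder assumptions, not values taken from this PR.

```python
"""Minimal client sketch for the new prompt endpoints.

Assumptions: the MCP server integration is running and reachable over SSE;
the URL and token below are placeholders, not values from this PR.
"""

import asyncio

import mcp.client.session
import mcp.client.sse

MCP_SSE_URL = "http://homeassistant.local:8123/mcp_server/sse"  # placeholder
TOKEN = "<long-lived access token>"  # placeholder


async def main() -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}  # assumed auth scheme
    async with mcp.client.sse.sse_client(MCP_SSE_URL, headers=headers) as streams:
        async with mcp.client.session.ClientSession(*streams) as session:
            await session.initialize()

            # Discover the prompt advertised by handle_list_prompts()
            prompts = await session.list_prompts()
            print([prompt.name for prompt in prompts.prompts])

            # Fetch the default prompt for the advertised LLM API
            result = await session.get_prompt(name="Assist")
            print(result.messages[0].content.text)


asyncio.run(main())
```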
@@ -50,6 +50,37 @@ async def create_server(

    server = Server("home-assistant")

    @server.list_prompts()  # type: ignore[no-untyped-call, misc]
    async def handle_list_prompts() -> list[types.Prompt]:
        llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
        return [
            types.Prompt(
                name=llm_api.api.name,
                description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
            )
        ]

    @server.get_prompt()  # type: ignore[no-untyped-call, misc]
    async def handle_get_prompt(
        name: str, arguments: dict[str, str] | None
    ) -> types.GetPromptResult:
        llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
        if name != llm_api.api.name:
            raise ValueError(f"Unknown prompt: {name}")

        return types.GetPromptResult(
            description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
            messages=[
                types.PromptMessage(
                    role="assistant",
                    content=types.TextContent(
                        type="text",
                        text=llm_api.api_prompt,
                    ),
                )
            ],
        )

    @server.list_tools()  # type: ignore[no-untyped-call, misc]
    async def list_tools() -> list[types.Tool]:
        """List available time tools."""

@@ -10,6 +10,7 @@ import aiohttp

import mcp
import mcp.client.session
import mcp.client.sse
from mcp.shared.exceptions import McpError
import pytest

from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN

@@ -354,3 +355,51 @@ async def test_mcp_tool_call_failed(

    assert len(result.content) == 1
    assert result.content[0].type == "text"
    assert "Error calling tool" in result.content[0].text


async def test_prompt_list(
    hass: HomeAssistant,
    setup_integration: None,
    mcp_sse_url: str,
    hass_supervisor_access_token: str,
) -> None:
    """Test the list prompt endpoint."""

    async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
        result = await session.list_prompts()

    assert len(result.prompts) == 1
    prompt = result.prompts[0]
    assert prompt.name == "Assist"
    assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"


async def test_prompt_get(
    hass: HomeAssistant,
    setup_integration: None,
    mcp_sse_url: str,
    hass_supervisor_access_token: str,
) -> None:
    """Test the get prompt endpoint."""

    async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
        result = await session.get_prompt(name="Assist")

    assert result.description == "Default prompt for the Home Assistant LLM API Assist"
    assert len(result.messages) == 1
    assert result.messages[0].role == "assistant"
    assert result.messages[0].content.type == "text"
    assert "When controlling Home Assistant" in result.messages[0].content.text


async def test_get_unknown_prompt(
    hass: HomeAssistant,
    setup_integration: None,
    mcp_sse_url: str,
    hass_supervisor_access_token: str,
) -> None:
    """Test the get prompt endpoint with an unknown prompt name."""

    async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
        with pytest.raises(McpError):
            await session.get_prompt(name="Unknown")
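One note on the error path covered by the new test: `handle_get_prompt` raises a plain `ValueError` for an unknown prompt name, and the MCP layer surfaces it to the client as a protocol error, which is why the test expects `McpError` rather than `ValueError`. A client-side sketch of handling that case follows; the helper name is illustrative and not part of this PR.

```python
from mcp.shared.exceptions import McpError


async def get_prompt_or_none(session, name: str):
    """Return the prompt result, or None if the server rejects the name.

    'session' is an initialized MCP ClientSession as in the earlier sketch.
    """
    try:
        return await session.get_prompt(name=name)
    except McpError as err:
        # The server-side ValueError("Unknown prompt: ...") arrives here
        # as an MCP protocol error, matching what the new test asserts.
        print(f"Prompt {name!r} not available: {err}")
        return None
```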