core/tests/components/mcp/test_init.py

226 lines
7.5 KiB
Python

"""Tests for the Model Context Protocol component."""
import re
from unittest.mock import Mock, patch
import httpx
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
import pytest
import voluptuous as vol
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from .conftest import TEST_API_NAME
from tests.common import MockConfigEntry
SEARCH_MEMORY_TOOL = Tool(
name="search_memory",
description="Search memory for relevant context based on a query.",
inputSchema={
"type": "object",
"required": ["query"],
"properties": {
"query": {
"type": "string",
"description": "A free text query to search context for.",
}
},
},
)
SAVE_MEMORY_TOOL = Tool(
name="save_memory",
description="Save a memory context.",
inputSchema={
"type": "object",
"required": ["context"],
"properties": {
"context": {
"type": "object",
"description": "The context to save.",
"properties": {
"fact": {
"type": "string",
"description": "The key for the context.",
},
},
},
},
},
)
def create_llm_context() -> llm.LLMContext:
"""Create a test LLM context."""
return llm.LLMContext(
platform="test_platform",
context=Context(),
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
async def test_init(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration is initialized and can be unloaded cleanly."""
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
await hass.config_entries.async_unload(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.NOT_LOADED
async def test_mcp_server_failure(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration fails to setup if the server fails initialization."""
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"", request=None, response=httpx.Response(500)
)
with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1):
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_list_tools_failure(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration fails to load if the first data fetch returns an error."""
mock_mcp_client.return_value.list_tools.side_effect = httpx.HTTPStatusError(
"", request=None, response=httpx.Response(500)
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_llm_get_api_tools(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test MCP tools are returned as LLM API tools."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL, SAVE_MEMORY_TOOL],
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 2
tool = api_instance.tools[0]
assert tool.name == "search_memory"
assert tool.description == "Search memory for relevant context based on a query."
with pytest.raises(
vol.Invalid, match=re.escape("required key not provided @ data['query']")
):
tool.parameters({})
assert tool.parameters({"query": "frogs"}) == {"query": "frogs"}
tool = api_instance.tools[1]
assert tool.name == "save_memory"
assert tool.description == "Save a memory context."
with pytest.raises(
vol.Invalid, match=re.escape("required key not provided @ data['context']")
):
tool.parameters({})
assert tool.parameters({"context": {"fact": "User was born in February"}}) == {
"context": {"fact": "User was born in February"}
}
async def test_call_tool(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test calling an MCP Tool through the LLM API."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 1
tool = api_instance.tools[0]
assert tool.name == "search_memory"
mock_mcp_client.return_value.call_tool.return_value = CallToolResult(
content=[TextContent(type="text", text="User was born in February")]
)
result = await tool.async_call(
hass,
llm.ToolInput(
tool_name="search_memory", tool_args={"query": "User's birth month"}
),
create_llm_context(),
)
assert result == {
"content": [{"text": "User was born in February", "type": "text"}]
}
async def test_call_tool_fails(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test handling an MCP Tool call failure."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 1
tool = api_instance.tools[0]
assert tool.name == "search_memory"
mock_mcp_client.return_value.call_tool.side_effect = httpx.HTTPStatusError(
"Server error", request=None, response=httpx.Response(500)
)
with pytest.raises(
HomeAssistantError, match="Error when calling tool: Server error"
):
await tool.async_call(
hass,
llm.ToolInput(
tool_name="search_memory", tool_args={"query": "User's birth month"}
),
create_llm_context(),
)
async def test_convert_tool_schema_fails(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test a failure converting an MCP tool schema to a Home Assistant schema."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
with patch(
"homeassistant.components.mcp.coordinator.convert_to_voluptuous",
side_effect=ValueError,
):
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY