367 lines
14 KiB
Python
367 lines
14 KiB
Python
"""Tests for the Ollama integration."""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from httpx import ConnectError
|
|
from ollama import Message, ResponseError
|
|
import pytest
|
|
|
|
from homeassistant.components import conversation, ollama
|
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
|
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.helpers import (
|
|
area_registry as ar,
|
|
device_registry as dr,
|
|
entity_registry as er,
|
|
intent,
|
|
)
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
async def test_chat(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Check the arguments the agent passes to ``AsyncClient.chat``.

    Three areas are created with one light each; the office light is then
    un-exposed from conversation. A single conversation turn is run against a
    patched chat call, and the test verifies the model name, the
    system/user message pair, and that only exposed lights reach the prompt.
    """
    # Register the kitchen, bedroom and office areas, in that order.
    area_ids = {}
    for area_name in ("kitchen", "bedroom", "office"):
        created_area = area_registry.async_get_or_create(f"{area_name}_id")
        area_ids[area_name] = area_registry.async_update(
            created_area.id, name=area_name
        ).id

    demo_entry = MockConfigEntry()
    demo_entry.add_to_hass(hass)
    kitchen_device = device_registry.async_get_or_create(
        config_entry_id=demo_entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(
        kitchen_device.id, area_id=area_ids["kitchen"]
    )

    def add_light(unique_id: str, friendly_name: str, **placement: str) -> str:
        """Register a demo light, apply *placement*, set it "on"; return its id."""
        registered = entity_registry.async_get_or_create("light", "demo", unique_id)
        registered = entity_registry.async_update_entity(
            registered.entity_id, **placement
        )
        hass.states.async_set(
            registered.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: friendly_name}
        )
        return registered.entity_id

    # The kitchen light is placed via its device; the others directly by area.
    add_light("1234", "kitchen light", device_id=kitchen_device.id)
    add_light("5678", "bedroom light", area_id=area_ids["bedroom"])
    # The office light exists but is hidden from the conversation integration.
    office_light_id = add_light("ABCD", "office light", area_id=area_ids["office"])
    async_expose_entity(hass, conversation.DOMAIN, office_light_id, False)

    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    with patch("ollama.AsyncClient.chat", return_value=canned_reply) as mock_chat:
        result = await conversation.async_converse(
            hass,
            "test message",
            None,
            Context(),
            agent_id=mock_config_entry.entry_id,
        )

    mock_chat.assert_called_once()
    chat_kwargs = mock_chat.call_args.kwargs
    system_prompt = chat_kwargs["messages"][0]["content"]

    assert chat_kwargs["model"] == "test model"
    assert chat_kwargs["messages"] == [
        Message({"role": "system", "content": system_prompt}),
        Message({"role": "user", "content": "test message"}),
    ]

    # Exposed lights are described in the prompt; the hidden office light is
    # not, and neither is the word "office" at all.
    for exposed_name in ("kitchen light", "bedroom light"):
        assert exposed_name in system_prompt
    for hidden_text in ("office light", "office"):
        assert hidden_text not in system_prompt

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert result.response.speech["plain"]["speech"] == "test response"
|
|
|
|
|
|
async def test_message_history_trimming(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that a single message history is trimmed according to the config.

    Five turns are sent on one conversation id. ``mock_init_component`` sets
    ``max_history`` to 2, so from the fourth call on, the history sent ahead of
    the new user message is limited to the two most recent user/assistant
    exchanges. Every call is compared in full (roles, contents and length).
    """
    response_idx = 0

    def response(*args, **kwargs) -> dict:
        """Return a distinct assistant reply ("response N") for each call."""
        nonlocal response_idx
        response_idx += 1
        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}

    with patch(
        "ollama.AsyncClient.chat",
        side_effect=response,
    ) as mock_chat:
        # mock_init_component sets "max_history" to 2
        for i in range(5):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id="1234",
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

    assert mock_chat.call_count == 5

    # Reduce each chat call to its ordered (role, content) pairs so that a
    # single list comparison checks roles, contents and message count at once.
    sent = [
        [(msg["role"], msg["content"]) for msg in call.kwargs["messages"]]
        for call in mock_chat.call_args_list
    ]
    prompt = sent[0][0][1]

    def expected(first_turn: int, last_turn: int) -> list[tuple[str, str]]:
        """Build the messages expected when turns first..last are sent.

        Each turn before ``last_turn`` contributes its user message followed by
        the assistant reply; ``last_turn`` contributes only the new user message.
        """
        messages = [("system", prompt)]
        for turn in range(first_turn, last_turn):
            messages.append(("user", f"message {turn}"))
            messages.append(("assistant", f"response {turn}"))
        messages.append(("user", f"message {last_turn}"))
        return messages

    # system + user-1
    assert sent[0] == expected(1, 1)

    # Full history
    # system + user-1 + assistant-1 + user-2
    assert sent[1] == expected(1, 2)

    # Full history
    # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
    assert sent[2] == expected(1, 3)

    # Trimmed to the two most recent exchanges plus the new user message.
    # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
    assert sent[3] == expected(2, 4)

    # Trimmed to the two most recent exchanges plus the new user message.
    # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
    assert sent[4] == expected(3, 5)
|
|
|
|
|
|
async def test_message_history_pruning(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Verify that stale conversation histories are dropped by the agent.

    Three conversations are started, and two of them are aged by two hours.
    After one more conversation turn, only the untouched history and the new
    one may remain.
    """
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    two_hours = 2 * 60 * 60

    with patch("ollama.AsyncClient.chat", return_value=canned_reply):
        # Open three independent conversations.
        conversation_ids: list[str] = []
        for turn in range(1, 4):
            reply = await conversation.async_converse(
                hass,
                f"message {turn}",
                conversation_id=None,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                reply.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), reply
            assert isinstance(reply.conversation_id, str)
            conversation_ids.append(reply.conversation_id)

        agent_manager = conversation._get_agent_manager(hass)
        agent = await agent_manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)
        assert len(agent._history) == 3
        assert agent._history.keys() == set(conversation_ids)

        # Age the two oldest histories so that the next turn prunes them.
        stale_ids, newest_id = conversation_ids[:2], conversation_ids[-1]
        for stale_id in stale_ids:
            agent._history[stale_id].timestamp -= two_hours

        # Next turn, on a brand-new conversation.
        reply = await conversation.async_converse(
            hass,
            "test message",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )
        assert (
            reply.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), reply

        # Only the untouched conversation and the new one are left.
        assert len(agent._history) == 2
        assert newest_id in agent._history
        assert reply.conversation_id in agent._history
|
|
|
|
|
|
async def test_message_history_unlimited(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Ensure that a ``max_history`` of 0 keeps every turn of a conversation."""
    conversation_id = "1234"
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    unlimited_options = {ollama.CONF_MAX_HISTORY: 0}

    with (
        patch("ollama.AsyncClient.chat", return_value=canned_reply),
        patch.object(mock_config_entry, "options", unlimited_options),
    ):
        for turn in range(1, 101):
            reply = await conversation.async_converse(
                hass,
                f"message {turn}",
                conversation_id=conversation_id,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                reply.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), reply

        agent_manager = conversation._get_agent_manager(hass)
        agent = await agent_manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)

        # Exactly one conversation, which has retained all 100 user turns.
        assert len(agent._history) == 1
        assert conversation_id in agent._history
        assert agent._history[conversation_id].num_user_messages == 100
|
|
|
|
|
|
async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Check that a ``ResponseError`` from chat yields an "unknown" error result."""
    chat_failure = ResponseError("test error")
    with patch(
        "ollama.AsyncClient.chat", new_callable=AsyncMock, side_effect=chat_failure
    ):
        outcome = await conversation.async_converse(
            hass,
            "hello",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )

    response = outcome.response
    assert response.response_type == intent.IntentResponseType.ERROR, outcome
    assert response.error_code == "unknown", outcome
|
|
|
|
|
|
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Check that a prompt template that cannot render yields an error result."""
    # Deliberately malformed: the "{% if %}" block is never closed.
    broken_prompt = "talk like a {% if True %}smarthome{% else %}pirate please."
    hass.config_entries.async_update_entry(
        mock_config_entry, options={"prompt": broken_prompt}
    )

    with patch("ollama.AsyncClient.list"):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        outcome = await conversation.async_converse(
            hass,
            "hello",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )

    response = outcome.response
    assert response.response_type == intent.IntentResponseType.ERROR, outcome
    assert response.error_code == "unknown", outcome
|
|
|
|
|
|
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that the OllamaAgent advertises support for every language."""
    agent_manager = conversation._get_agent_manager(hass)
    ollama_agent = await agent_manager.async_get_agent(mock_config_entry.entry_id)
    assert ollama_agent.supported_languages == MATCH_ALL
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
        (ConnectError(message="Connect error"), "Connect error"),
        (RuntimeError("Runtime error"), "Runtime error"),
    ],
)
async def test_init_error(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    caplog: pytest.LogCaptureFixture,
    side_effect: Exception,
    error: str,
) -> None:
    """Check that an exception from ``AsyncClient.list`` during setup is logged.

    ``async_setup_component`` must still report success, and the exception's
    message must appear in the captured log output.
    """
    with patch("ollama.AsyncClient.list", side_effect=side_effect):
        setup_ok = await async_setup_component(hass, ollama.DOMAIN, {})
        assert setup_ok
        await hass.async_block_till_done()

    assert error in caplog.text
|