"""Tests for the OpenAI integration."""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from httpx import Response
|
|
from openai import RateLimitError
|
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|
from openai.types.chat.chat_completion_message_tool_call import (
|
|
ChatCompletionMessageToolCall,
|
|
Function,
|
|
)
|
|
from openai.types.completion_usage import CompletionUsage
|
|
import pytest
|
|
from syrupy.assertion import SnapshotAssertion
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.const import CONF_LLM_HASS_API
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import (
|
|
area_registry as ar,
|
|
device_registry as dr,
|
|
intent,
|
|
llm,
|
|
)
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
@pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
@pytest.mark.parametrize(
    "config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}]
)
async def test_default_prompt(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
    agent_id: str,
    config_entry_options: dict,
) -> None:
    """Test that the default prompt works.

    Populates areas and devices (named, service-type, disabled, nameless,
    integer-valued) and snapshots the messages sent to the OpenAI API, both
    with and without the Assist LLM API enabled.
    """
    entry = MockConfigEntry(title=None)
    entry.add_to_hass(hass)
    for i in range(3):
        area_registry.async_create(f"{i}Empty Area")

    if agent_id is None:
        agent_id = mock_config_entry.entry_id

    # Apply the parametrized options so both the plain and the Assist-API
    # configurations are actually exercised. (Previously CONF_LLM_HASS_API
    # was forced on unconditionally, making the config_entry_options
    # parametrize a no-op.)
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            **mock_config_entry.options,
            **config_entry_options,
        },
    )

    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "1234")},
        name="Test Device",
        manufacturer="Test Manufacturer",
        model="Test Model",
        suggested_area="Test Area",
    )
    # Service-type devices should be handled (typically excluded from exposure).
    for i in range(3):
        device_registry.async_get_or_create(
            config_entry_id=entry.entry_id,
            connections={("test", f"{i}abcd")},
            name="Test Service",
            manufacturer="Test Manufacturer",
            model="Test Model",
            suggested_area="Test Area",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "5678")},
        name="Test Device 2",
        manufacturer="Test Manufacturer 2",
        model="Device 2",
        suggested_area="Test Area 2",
    )
    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "9876")},
        name="Test Device 3",
        manufacturer="Test Manufacturer 3",
        model="Test Model 3A",
        suggested_area="Test Area 2",
    )
    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "qwer")},
        name="Test Device 4",
        suggested_area="Test Area 2",
    )
    # A disabled device must not appear in the prompt.
    device = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "9876-disabled")},
        name="Test Device 3",
        manufacturer="Test Manufacturer 3",
        model="Test Model 3A",
        suggested_area="Test Area 2",
    )
    device_registry.async_update_device(
        device.id, disabled_by=dr.DeviceEntryDisabler.USER
    )
    # Device without a name.
    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "9876-no-name")},
        manufacturer="Test Manufacturer NoName",
        model="Test Model NoName",
        suggested_area="Test Area 2",
    )
    # Non-string registry values must not break prompt rendering.
    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "9876-integer-values")},
        name=1,
        manufacturer=2,
        model=3,
        suggested_area="Test Area 2",
    )
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
        return_value=ChatCompletion(
            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(
                        content="Hello, how can I help you?",
                        role="assistant",
                        function_call=None,
                        tool_calls=None,
                    ),
                )
            ],
            created=1700000000,
            model="gpt-3.5-turbo-0613",
            object="chat.completion",
            system_fingerprint=None,
            usage=CompletionUsage(
                completion_tokens=9, prompt_tokens=8, total_tokens=17
            ),
        ),
    ) as mock_create:
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=agent_id
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Snapshot the exact messages sent to the API (system prompt included).
    assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
|
|
|
|
|
async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that an OpenAI API error surfaces as an error conversation result."""
    # httpx.Response requires an int status code; use 429 to match the
    # semantics of RateLimitError (the previous status_code=None was invalid).
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
        side_effect=RateLimitError(
            response=Response(status_code=429, request=""), body=None, message=None
        ),
    ):
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
|
|
|
|
|
|
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that template error handling works."""
    # The unterminated {% if %} makes the prompt template fail to render.
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
        },
    )
    with (
        patch("openai.resources.models.AsyncModels.list"),
        patch(
            "openai.resources.chat.completions.AsyncCompletions.create",
            new_callable=AsyncMock,
        ),
    ):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    # Rendering failure must be reported as an error, not raised.
    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
|
|
|
|
|
|
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test OpenAIAgent."""
    manager = conversation.get_agent_manager(hass)
    agent = manager.async_get_agent(mock_config_entry.entry_id)
    # The OpenAI agent advertises support for every language.
    assert agent.supported_languages == "*"
|
|
|
|
|
|
@patch(
    "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_call(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test function call from the assistant."""
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    # Fake LLM tool exposed through the Assist API.
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    mock_tool.async_call.return_value = "Test response"

    mock_get_tools.return_value = [mock_tool]

    def completion_result(*args, messages, **kwargs):
        """Return a tool-call completion first, then a final assistant reply.

        If the conversation history already contains a "tool" message (the
        tool result was appended by the agent), return a plain reply;
        otherwise ask for a test_tool invocation.
        """
        for message in messages:
            # Messages may be plain dicts or ChatCompletionMessage objects.
            role = message["role"] if isinstance(message, dict) else message.role
            if role == "tool":
                return ChatCompletion(
                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
                    choices=[
                        Choice(
                            finish_reason="stop",
                            index=0,
                            message=ChatCompletionMessage(
                                content="I have successfully called the function",
                                role="assistant",
                                function_call=None,
                                tool_calls=None,
                            ),
                        )
                    ],
                    created=1700000000,
                    model="gpt-4-1106-preview",
                    object="chat.completion",
                    system_fingerprint=None,
                    usage=CompletionUsage(
                        completion_tokens=9, prompt_tokens=8, total_tokens=17
                    ),
                )

        # First round: request a test_tool call with JSON-encoded arguments.
        return ChatCompletion(
            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        role="assistant",
                        function_call=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
                                function=Function(
                                    arguments='{"param1":"test_value"}',
                                    name="test_tool",
                                ),
                                type="function",
                            )
                        ],
                    ),
                )
            ],
            created=1700000000,
            model="gpt-4-1106-preview",
            object="chat.completion",
            system_fingerprint=None,
            usage=CompletionUsage(
                completion_tokens=9, prompt_tokens=8, total_tokens=17
            ),
        )

    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
        side_effect=completion_result,
    ) as mock_create:
        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Second API call must carry the JSON-serialized tool result as message 3.
    assert mock_create.mock_calls[1][2]["messages"][3] == {
        "role": "tool",
        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
        "name": "test_tool",
        "content": '"Test response"',
    }
    # The tool must have been invoked exactly once with the decoded arguments.
    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": "test_value"},
            platform="openai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
            device_id=None,
        ),
    )
|
|
|
|
|
|
@patch(
    "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_exception(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test function call with exception."""
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    # Fake LLM tool whose invocation always raises HomeAssistantError.
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")

    mock_get_tools.return_value = [mock_tool]

    def completion_result(*args, messages, **kwargs):
        """Return a tool-call completion first, then a final assistant reply.

        If the history already contains a "tool" message (the serialized
        tool error appended by the agent), return a plain reply; otherwise
        ask for a test_tool invocation.
        """
        for message in messages:
            # Messages may be plain dicts or ChatCompletionMessage objects.
            role = message["role"] if isinstance(message, dict) else message.role
            if role == "tool":
                return ChatCompletion(
                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
                    choices=[
                        Choice(
                            finish_reason="stop",
                            index=0,
                            message=ChatCompletionMessage(
                                content="There was an error calling the function",
                                role="assistant",
                                function_call=None,
                                tool_calls=None,
                            ),
                        )
                    ],
                    created=1700000000,
                    model="gpt-4-1106-preview",
                    object="chat.completion",
                    system_fingerprint=None,
                    usage=CompletionUsage(
                        completion_tokens=9, prompt_tokens=8, total_tokens=17
                    ),
                )

        # First round: request a test_tool call with JSON-encoded arguments.
        return ChatCompletion(
            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        role="assistant",
                        function_call=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
                                function=Function(
                                    arguments='{"param1":"test_value"}',
                                    name="test_tool",
                                ),
                                type="function",
                            )
                        ],
                    ),
                )
            ],
            created=1700000000,
            model="gpt-4-1106-preview",
            object="chat.completion",
            system_fingerprint=None,
            usage=CompletionUsage(
                completion_tokens=9, prompt_tokens=8, total_tokens=17
            ),
        )

    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
        side_effect=completion_result,
    ) as mock_create:
        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    # The tool exception is not fatal: the agent reports it back to the LLM.
    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Second API call must carry the serialized error as the tool message.
    assert mock_create.mock_calls[1][2]["messages"][3] == {
        "role": "tool",
        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
        "name": "test_tool",
        "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
    }
    # The tool must have been invoked exactly once with the decoded arguments.
    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": "test_value"},
            platform="openai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
            device_id=None,
        ),
    )
|
|
|
|
|
|
async def test_assist_api_tools_conversion(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that we are able to convert actual tools from Assist API."""
    # Set up a spread of components so the Assist API exposes real tools.
    for domain in (
        "intent",
        "todo",
        "light",
        "shopping_list",
        "humidifier",
        "climate",
        "media_player",
        "vacuum",
        "cover",
        "weather",
    ):
        assert await async_setup_component(hass, domain, {})

    agent_id = mock_config_entry_with_assist.entry_id
    # Canned assistant reply; we only inspect the tools passed to the API.
    completion = ChatCompletion(
        id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
        choices=[
            Choice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(
                    content="Hello, how can I help you?",
                    role="assistant",
                    function_call=None,
                    tool_calls=None,
                ),
            )
        ],
        created=1700000000,
        model="gpt-3.5-turbo-0613",
        object="chat.completion",
        system_fingerprint=None,
        usage=CompletionUsage(
            completion_tokens=9, prompt_tokens=8, total_tokens=17
        ),
    )
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
        return_value=completion,
    ) as mock_create:
        await conversation.async_converse(hass, "hello", None, None, agent_id=agent_id)

    tools = mock_create.mock_calls[0][2]["tools"]
    assert tools
|
|
|
|
|
|
async def test_unknown_hass_api(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    snapshot: SnapshotAssertion,
    mock_init_component,
) -> None:
    """Test when we reference an API that no longer exists."""
    # Point the entry at an LLM API id that is not registered.
    new_options = {
        **mock_config_entry.options,
        CONF_LLM_HASS_API: "non-existing",
    }
    hass.config_entries.async_update_entry(mock_config_entry, options=new_options)

    result = await conversation.async_converse(
        hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
    )

    # Snapshot the full result, including the error response.
    assert result == snapshot
|