Add OpenAI conversation entity (#114942)

* Add OpenAI conversation entity

* Add migration
pull/115322/head
Paulus Schoutsen 2024-04-09 11:10:03 -04:00 committed by GitHub
parent 51d5d51248
commit 2df6f1849f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 425 additions and 334 deletions

View File

@ -2,53 +2,30 @@
from __future__ import annotations from __future__ import annotations
import logging
from typing import Literal
import openai import openai
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL from homeassistant.const import CONF_API_KEY, Platform
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse, ServiceResponse,
SupportsResponse, SupportsResponse,
) )
from homeassistant.exceptions import ( from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
ConfigEntryNotReady,
HomeAssistantError,
TemplateError,
)
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
intent,
issue_registry as ir, issue_registry as ir,
selector, selector,
template,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from .const import ( from .const import DOMAIN, LOGGER
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
SERVICE_GENERATE_IMAGE = "generate_image" SERVICE_GENERATE_IMAGE = "generate_image"
PLATFORMS = (Platform.CONVERSATION,)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@ -120,108 +97,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try: try:
await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list) await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
except openai.AuthenticationError as err: except openai.AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err) LOGGER.error("Invalid API key: %s", err)
return False return False
except openai.OpenAIError as err: except openai.OpenAIError as err:
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry)) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload OpenAI.""" """Unload OpenAI."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
conversation.async_unset_agent(hass, entry) conversation.async_unset_agent(hass, entry)
return True return True
class OpenAIAgent(conversation.AbstractConversationAgent):
"""OpenAI conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.history: dict[str, list[dict]] = {}
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
try:
prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = [{"role": "system", "content": prompt}]
messages.append({"role": "user", "content": user_input.text})
_LOGGER.debug("Prompt for %s: %s", model, messages)
client = self.hass.data[DOMAIN][self.entry.entry_id]
try:
result = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=conversation_id,
)
except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Response %s", result)
response = result.choices[0].message.model_dump(include={"role", "content"})
messages.append(response)
self.history[conversation_id] = messages
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["content"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)

View File

@ -1,6 +1,9 @@
"""Constants for the OpenAI Conversation integration.""" """Constants for the OpenAI Conversation integration."""
import logging
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
LOGGER = logging.getLogger(__name__)
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.

View File

@ -0,0 +1,145 @@
"""Conversation support for OpenAI."""
from typing import Literal
import openai
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up conversation entities."""
agent = OpenAIConversationEntity(hass, config_entry)
async_add_entities([agent])
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
"""OpenAI conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.history: dict[str, list[dict]] = {}
self._attr_name = entry.title
self._attr_unique_id = entry.entry_id
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
assist_pipeline.async_migrate_engine(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
conversation.async_set_agent(self.hass, self.entry, self)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
try:
prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = [{"role": "system", "content": prompt}]
messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt for %s: %s", model, messages)
client = self.hass.data[DOMAIN][self.entry.entry_id]
try:
result = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=conversation_id,
)
except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
LOGGER.debug("Response %s", result)
response = result.choices[0].message.model_dump(include={"role", "content"})
messages.append(response)
self.history[conversation_id] = messages
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["content"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)

View File

@ -1,6 +1,7 @@
{ {
"domain": "openai_conversation", "domain": "openai_conversation",
"name": "OpenAI Conversation", "name": "OpenAI Conversation",
"after_dependencies": ["assist_pipeline"],
"codeowners": ["@balloob"], "codeowners": ["@balloob"],
"config_flow": true, "config_flow": true,
"dependencies": ["conversation"], "dependencies": ["conversation"],

View File

@ -14,6 +14,7 @@ from tests.common import MockConfigEntry
def mock_config_entry(hass): def mock_config_entry(hass):
"""Mock a config entry.""" """Mock a config entry."""
entry = MockConfigEntry( entry = MockConfigEntry(
title="OpenAI",
domain="openai_conversation", domain="openai_conversation",
data={ data={
"api_key": "bla", "api_key": "bla",

View File

@ -0,0 +1,67 @@
# serializer version: 1
# name: test_default_prompt[None]
list([
dict({
'content': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Hello, how can I help you?',
'role': 'assistant',
}),
])
# ---
# name: test_default_prompt[conversation.openai]
list([
dict({
'content': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Hello, how can I help you?',
'role': 'assistant',
}),
])
# ---

View File

@ -1,34 +0,0 @@
# serializer version: 1
# name: test_default_prompt
list([
dict({
'content': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Hello, how can I help you?',
'role': 'assistant',
}),
])
# ---

View File

@ -0,0 +1,196 @@
"""Tests for the OpenAI integration."""
from unittest.mock import AsyncMock, patch
from httpx import Response
from openai import RateLimitError
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from tests.common import MockConfigEntry
@pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
agent_id: str,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
if agent_id is None:
agent_id = mock_config_entry.entry_id
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="Hello, how can I help you?",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-3.5-turbo-0613",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
),
) as mock_create:
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=agent_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[0][2]["messages"] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that the default prompt works."""
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=RateLimitError(
response=Response(status_code=None, request=""), body=None, message=None
),
):
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
with (
patch(
"openai.resources.models.AsyncModels.list",
),
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test OpenAIAgent."""
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"

View File

@ -1,6 +1,6 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from unittest.mock import AsyncMock, patch from unittest.mock import patch
from httpx import Response from httpx import Response
from openai import ( from openai import (
@ -9,197 +9,17 @@ from openai import (
BadRequestError, BadRequestError,
RateLimitError, RateLimitError,
) )
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from openai.types.image import Image from openai.types.image import Image
from openai.types.images_response import ImagesResponse from openai.types.images_response import ImagesResponse
import pytest import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="Hello, how can I help you?",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-3.5-turbo-0613",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
),
) as mock_create:
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[0][2]["messages"] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that the default prompt works."""
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=RateLimitError(
response=Response(status_code=None, request=""), body=None, message=None
),
):
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
with (
patch(
"openai.resources.models.AsyncModels.list",
),
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test OpenAIAgent."""
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("service_data", "expected_args"), ("service_data", "expected_args"),
[ [