diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 07e872a0f5d..ffbfc1799c5 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -2,53 +2,30 @@ from __future__ import annotations -import logging -from typing import Literal - import openai import voluptuous as vol from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_API_KEY, MATCH_ALL +from homeassistant.const import CONF_API_KEY, Platform from homeassistant.core import ( HomeAssistant, ServiceCall, ServiceResponse, SupportsResponse, ) -from homeassistant.exceptions import ( - ConfigEntryNotReady, - HomeAssistantError, - TemplateError, -) +from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.helpers import ( config_validation as cv, - intent, issue_registry as ir, selector, - template, ) from homeassistant.helpers.typing import ConfigType -from homeassistant.util import ulid -from .const import ( - CONF_CHAT_MODEL, - CONF_MAX_TOKENS, - CONF_PROMPT, - CONF_TEMPERATURE, - CONF_TOP_P, - DEFAULT_CHAT_MODEL, - DEFAULT_MAX_TOKENS, - DEFAULT_PROMPT, - DEFAULT_TEMPERATURE, - DEFAULT_TOP_P, - DOMAIN, -) +from .const import DOMAIN, LOGGER -_LOGGER = logging.getLogger(__name__) SERVICE_GENERATE_IMAGE = "generate_image" - +PLATFORMS = (Platform.CONVERSATION,) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) @@ -120,108 +97,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: try: await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list) except openai.AuthenticationError as err: - _LOGGER.error("Invalid API key: %s", err) + LOGGER.error("Invalid API key: %s", err) return False except openai.OpenAIError as err: raise ConfigEntryNotReady(err) from err hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client - conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry)) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload OpenAI.""" + if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + return False + hass.data[DOMAIN].pop(entry.entry_id) conversation.async_unset_agent(hass, entry) return True - - -class OpenAIAgent(conversation.AbstractConversationAgent): - """OpenAI conversation agent.""" - - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: - """Initialize the agent.""" - self.hass = hass - self.entry = entry - self.history: dict[str, list[dict]] = {} - - @property - def supported_languages(self) -> list[str] | Literal["*"]: - """Return a list of supported languages.""" - return MATCH_ALL - - async def async_process( - self, user_input: conversation.ConversationInput - ) -> conversation.ConversationResult: - """Process a sentence.""" - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) - top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) - temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) - - if user_input.conversation_id in self.history: - conversation_id = user_input.conversation_id - messages = self.history[conversation_id] - else: - conversation_id = ulid.ulid_now() - try: - prompt = self._async_generate_prompt(raw_prompt) - except TemplateError as err: - _LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - messages = [{"role": "system", "content": prompt}] - - messages.append({"role": "user", "content": user_input.text}) - - _LOGGER.debug("Prompt for %s: %s", model, messages) - - client = self.hass.data[DOMAIN][self.entry.entry_id] - - try: - result = await client.chat.completions.create( - model=model, - messages=messages, - max_tokens=max_tokens, - top_p=top_p, - temperature=temperature, - user=conversation_id, - ) - except openai.OpenAIError as err: - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to OpenAI: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - _LOGGER.debug("Response %s", result) - response = result.choices[0].message.model_dump(include={"role", "content"}) - messages.append(response) - self.history[conversation_id] = messages - - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(response["content"]) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - def _async_generate_prompt(self, raw_prompt: str) -> str: - """Generate a prompt for the user.""" - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index 46f8603c5f1..ee4a107c241 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -1,6 +1,9 @@ """Constants for the OpenAI Conversation integration.""" +import logging + DOMAIN = "openai_conversation" +LOGGER = logging.getLogger(__name__) CONF_PROMPT = "prompt" DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py new file mode 100644 index 00000000000..158b155c75d --- /dev/null +++ b/homeassistant/components/openai_conversation/conversation.py @@ -0,0 +1,145 @@ +"""Conversation support for OpenAI.""" + +from typing import Literal + +import openai + +from homeassistant.components import assist_pipeline, conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import TemplateError +from homeassistant.helpers import intent, template +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.util import ulid + +from .const import ( + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_P, + DEFAULT_CHAT_MODEL, + DEFAULT_MAX_TOKENS, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_P, + DOMAIN, + LOGGER, +) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up conversation entities.""" + agent = OpenAIConversationEntity(hass, config_entry) + async_add_entities([agent]) + + +class OpenAIConversationEntity( + conversation.ConversationEntity, conversation.AbstractConversationAgent +): + """OpenAI conversation agent.""" + + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + """Initialize the agent.""" + self.hass = hass + self.entry = entry + self.history: dict[str, list[dict]] = {} + self._attr_name = entry.title + self._attr_unique_id = entry.entry_id + + @property + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + return MATCH_ALL + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + assist_pipeline.async_migrate_engine( + self.hass, "conversation", self.entry.entry_id, self.entity_id + ) + conversation.async_set_agent(self.hass, self.entry, self) + + async def async_will_remove_from_hass(self) -> None: + """When entity will be removed from Home Assistant.""" + conversation.async_unset_agent(self.hass, self.entry) + await super().async_will_remove_from_hass() + + async def async_process( + self, user_input: conversation.ConversationInput + ) -> conversation.ConversationResult: + """Process a sentence.""" + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) + top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) + temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) + + if user_input.conversation_id in self.history: + conversation_id = user_input.conversation_id + messages = self.history[conversation_id] + else: + conversation_id = ulid.ulid_now() + try: + prompt = self._async_generate_prompt(raw_prompt) + except TemplateError as err: + LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem with my template: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + messages = [{"role": "system", "content": prompt}] + + messages.append({"role": "user", "content": user_input.text}) + + LOGGER.debug("Prompt for %s: %s", model, messages) + + client = self.hass.data[DOMAIN][self.entry.entry_id] + + try: + result = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + top_p=top_p, + temperature=temperature, + user=conversation_id, + ) + except openai.OpenAIError as err: + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to OpenAI: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + LOGGER.debug("Response %s", result) + response = result.choices[0].message.model_dump(include={"role", "content"}) + messages.append(response) + self.history[conversation_id] = messages + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(response["content"]) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + def _async_generate_prompt(self, raw_prompt: str) -> str: + """Generate a prompt for the user.""" + return template.Template(raw_prompt, self.hass).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ) diff --git a/homeassistant/components/openai_conversation/manifest.json b/homeassistant/components/openai_conversation/manifest.json index 5138be96b55..b71c84e2081 100644 --- a/homeassistant/components/openai_conversation/manifest.json +++ b/homeassistant/components/openai_conversation/manifest.json @@ -1,6 +1,7 @@ { "domain": "openai_conversation", "name": "OpenAI Conversation", + "after_dependencies": ["assist_pipeline"], "codeowners": ["@balloob"], "config_flow": true, "dependencies": ["conversation"], diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index 1597fa79d0a..272c23a9510 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -14,6 +14,7 @@ from tests.common import MockConfigEntry def mock_config_entry(hass): """Mock a config entry.""" entry = MockConfigEntry( + title="OpenAI", domain="openai_conversation", data={ "api_key": "bla", diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr new file mode 100644 index 00000000000..1a488bb948c --- /dev/null +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -0,0 +1,67 @@ +# serializer version: 1 +# name: test_default_prompt[None] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + dict({ + 'content': 'Hello, how can I help you?', + 'role': 'assistant', + }), + ]) +# --- +# name: test_default_prompt[conversation.openai] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + dict({ + 'content': 'Hello, how can I help you?', + 'role': 'assistant', + }), + ]) +# --- diff --git a/tests/components/openai_conversation/snapshots/test_init.ambr b/tests/components/openai_conversation/snapshots/test_init.ambr deleted file mode 100644 index bc06f51f416..00000000000 --- a/tests/components/openai_conversation/snapshots/test_init.ambr +++ /dev/null @@ -1,34 +0,0 @@ -# serializer version: 1 -# name: test_default_prompt - list([ - dict({ - 'content': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Answer the user's questions about the world truthfully. - - If the user wants to control a device, reject the request and suggest using the Home Assistant app. - ''', - 'role': 'system', - }), - dict({ - 'content': 'hello', - 'role': 'user', - }), - dict({ - 'content': 'Hello, how can I help you?', - 'role': 'assistant', - }), - ]) -# --- diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py new file mode 100644 index 00000000000..9e50204cdde --- /dev/null +++ b/tests/components/openai_conversation/test_conversation.py @@ -0,0 +1,196 @@ +"""Tests for the OpenAI integration.""" + +from unittest.mock import AsyncMock, patch + +from httpx import Response +from openai import RateLimitError +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.completion_usage import CompletionUsage +import pytest +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components import conversation +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import area_registry as ar, device_registry as dr, intent + +from tests.common import MockConfigEntry + + +@pytest.mark.parametrize("agent_id", [None, "conversation.openai"]) +async def test_default_prompt( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, + agent_id: str, +) -> None: + """Test that the default prompt works.""" + entry = MockConfigEntry(title=None) + entry.add_to_hass(hass) + for i in range(3): + area_registry.async_create(f"{i}Empty Area") + + if agent_id is None: + agent_id = mock_config_entry.entry_id + + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "1234")}, + name="Test Device", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + ) + for i in range(3): + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", f"{i}abcd")}, + name="Test Service", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + entry_type=dr.DeviceEntryType.SERVICE, + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "5678")}, + name="Test Device 2", + manufacturer="Test Manufacturer 2", + model="Device 2", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "qwer")}, + name="Test Device 4", + suggested_area="Test Area 2", + ) + device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-disabled")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_update_device( + device.id, disabled_by=dr.DeviceEntryDisabler.USER + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-no-name")}, + manufacturer="Test Manufacturer NoName", + model="Test Model NoName", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-integer-values")}, + name=1, + manufacturer=2, + model=3, + suggested_area="Test Area 2", + ) + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + return_value=ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="Hello, how can I help you?", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="gpt-3.5-turbo-0613", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ), + ) as mock_create: + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=agent_id + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert mock_create.mock_calls[0][2]["messages"] == snapshot + + +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that the default prompt works.""" + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + side_effect=RateLimitError( + response=Response(status_code=None, request=""), body=None, message=None + ), + ): + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with ( + patch( + "openai.resources.models.AsyncModels.list", + ), + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ), + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test OpenAIAgent.""" + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == "*" diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 2702b749a64..773ba3bca06 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -1,6 +1,6 @@ """Tests for the OpenAI integration.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import patch from httpx import Response from openai import ( @@ -9,197 +9,17 @@ from openai import ( BadRequestError, RateLimitError, ) -from openai.types.chat.chat_completion import ChatCompletion, Choice -from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.completion_usage import CompletionUsage from openai.types.image import Image from openai.types.images_response import ImagesResponse import pytest -from syrupy.assertion import SnapshotAssertion -from homeassistant.components import conversation -from homeassistant.core import Context, HomeAssistant +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry -async def test_default_prompt( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - snapshot: SnapshotAssertion, -) -> None: - """Test that the default prompt works.""" - entry = MockConfigEntry(title=None) - entry.add_to_hass(hass) - for i in range(3): - area_registry.async_create(f"{i}Empty Area") - - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - ) - for i in range(3): - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", f"{i}abcd")}, - name="Test Service", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - entry_type=dr.DeviceEntryType.SERVICE, - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "5678")}, - name="Test Device 2", - manufacturer="Test Manufacturer 2", - model="Device 2", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "qwer")}, - name="Test Device 4", - suggested_area="Test Area 2", - ) - device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-disabled")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_update_device( - device.id, disabled_by=dr.DeviceEntryDisabler.USER - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-no-name")}, - manufacturer="Test Manufacturer NoName", - model="Test Model NoName", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-integer-values")}, - name=1, - manufacturer=2, - model=3, - suggested_area="Test Area 2", - ) - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - return_value=ChatCompletion( - id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - content="Hello, how can I help you?", - role="assistant", - function_call=None, - tool_calls=None, - ), - ) - ], - created=1700000000, - model="gpt-3.5-turbo-0613", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ), - ) as mock_create: - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert mock_create.mock_calls[0][2]["messages"] == snapshot - - -async def test_error_handling( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that the default prompt works.""" - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - side_effect=RateLimitError( - response=Response(status_code=None, request=""), body=None, message=None - ), - ): - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_template_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template error handling works.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", - }, - ) - with ( - patch( - "openai.resources.models.AsyncModels.list", - ), - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ), - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_conversation_agent( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, -) -> None: - """Test OpenAIAgent.""" - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert agent.supported_languages == "*" - - @pytest.mark.parametrize( ("service_data", "expected_args"), [