diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index c9d92f554ee..950e15f8e11 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -15,7 +15,18 @@ from homeassistant.exceptions import ConfigEntryNotReady, TemplateError from homeassistant.helpers import area_registry, intent, template from homeassistant.util import ulid -from .const import DEFAULT_MODEL, DEFAULT_PROMPT +from .const import ( + CONF_MAX_TOKENS, + CONF_MODEL, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_P, + DEFAULT_MAX_TOKENS, + DEFAULT_MODEL, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_P, +) _LOGGER = logging.getLogger(__name__) @@ -63,7 +74,11 @@ class OpenAIAgent(conversation.AbstractConversationAgent): self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - model = DEFAULT_MODEL + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL) + max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) + top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) + temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id @@ -71,7 +86,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent): else: conversation_id = ulid.ulid() try: - prompt = self._async_generate_prompt() + prompt = self._async_generate_prompt(raw_prompt) except TemplateError as err: _LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) @@ -98,14 +113,13 @@ class OpenAIAgent(conversation.AbstractConversationAgent): _LOGGER.debug("Prompt for %s: %s", model, prompt) try: - result = await self.hass.async_add_executor_job( - partial( - openai.Completion.create, - engine=model, - prompt=prompt, - max_tokens=150, - user=conversation_id, - ) + result = await openai.Completion.acreate( + engine=model, + prompt=prompt, + max_tokens=max_tokens, + top_p=top_p, + temperature=temperature, + user=conversation_id, ) except error.OpenAIError as err: intent_response = intent.IntentResponse(language=user_input.language) @@ -131,9 +145,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent): response=intent_response, conversation_id=conversation_id ) - def _async_generate_prompt(self) -> str: + def _async_generate_prompt(self, raw_prompt: str) -> str: """Generate a prompt for the user.""" - return template.Template(DEFAULT_PROMPT, self.hass).async_render( + return template.Template(raw_prompt, self.hass).async_render( { "ha_name": self.hass.config.location_name, "areas": list(area_registry.async_get(self.hass).areas.values()), diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index 88253d63a44..9aef77e37f7 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -3,6 +3,8 @@ from __future__ import annotations from functools import partial import logging +import types +from types import MappingProxyType from typing import Any import openai @@ -13,8 +15,26 @@ from homeassistant import config_entries from homeassistant.const import CONF_API_KEY from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + TextSelector, + TextSelectorConfig, +) -from .const import DOMAIN +from .const import ( + CONF_MAX_TOKENS, + CONF_MODEL, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_P, + DEFAULT_MAX_TOKENS, + DEFAULT_MODEL, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_P, + DOMAIN, +) _LOGGER = logging.getLogger(__name__) @@ -24,6 +44,16 @@ STEP_USER_DATA_SCHEMA = vol.Schema( } ) +DEFAULT_OPTIONS = types.MappingProxyType( + { + CONF_PROMPT: DEFAULT_PROMPT, + CONF_MODEL: DEFAULT_MODEL, + CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, + CONF_TOP_P: DEFAULT_TOP_P, + CONF_TEMPERATURE: DEFAULT_TEMPERATURE, + } +) + async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """Validate the user input allows us to connect. @@ -68,3 +98,49 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors ) + + @staticmethod + def async_get_options_flow( + config_entry: config_entries.ConfigEntry, + ) -> config_entries.OptionsFlow: + """Create the options flow.""" + return OptionsFlow(config_entry) + + +class OptionsFlow(config_entries.OptionsFlow): + """OpenAI config flow options handler.""" + + def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + """Initialize options flow.""" + self.config_entry = config_entry + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Manage the options.""" + if user_input is not None: + return self.async_create_entry(title="OpenAI Conversation", data=user_input) + schema = openai_config_option_schema(self.config_entry.options) + return self.async_show_form( + step_id="init", + data_schema=vol.Schema(schema), + ) + + +def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: + """Return a schema for OpenAI completion options.""" + if not options: + options = DEFAULT_OPTIONS + return { + vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector( + TextSelectorConfig(multiline=True) + ), + vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str, + vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int, + vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector( + NumberSelectorConfig(min=0, max=1, step=0.05) + ), + vol.Required( + CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE) + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + } diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index 378548173b0..ae6d2db6cc2 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -2,7 +2,6 @@ DOMAIN = "openai_conversation" CONF_PROMPT = "prompt" -DEFAULT_MODEL = "text-davinci-003" DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. An overview of the areas and the devices in this smart home: @@ -28,3 +27,11 @@ Now finish this conversation: Smart home: How can I assist? """ +CONF_MODEL = "model" +DEFAULT_MODEL = "text-davinci-003" +CONF_MAX_TOKENS = "max_tokens" +DEFAULT_MAX_TOKENS = 150 +CONF_TOP_P = "top_p" +DEFAULT_TOP_P = 1 +CONF_TEMPERATURE = "temperature" +DEFAULT_TEMPERATURE = 0.5 diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 9ebf1c64a21..f7af4618a9d 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -15,5 +15,18 @@ "abort": { "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]" } + }, + "options": { + "step": { + "init": { + "data": { + "prompt": "Prompt Template", + "model": "Completion Model", + "max_tokens": "Maximum tokens to return in response", + "temperature": "Temperature", + "top_p": "Top P" + } + } + } } } diff --git a/homeassistant/components/openai_conversation/translations/en.json b/homeassistant/components/openai_conversation/translations/en.json index 7665a5535ab..cf4122eab2a 100644 --- a/homeassistant/components/openai_conversation/translations/en.json +++ b/homeassistant/components/openai_conversation/translations/en.json @@ -15,5 +15,18 @@ } } } + }, + "options": { + "step": { + "init": { + "data": { + "prompt": "Prompt Template", + "model": "Completion Model", + "max_tokens": "Maximum tokens to return in response", + "temperature": "Temperature", + "top_p": "Top P" + } + } + } } } \ No newline at end of file diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index 1510b986b59..761b8268942 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -5,7 +5,11 @@ from openai.error import APIConnectionError, AuthenticationError, InvalidRequest import pytest from homeassistant import config_entries -from homeassistant.components.openai_conversation.const import DOMAIN +from homeassistant.components.openai_conversation.const import ( + CONF_MODEL, + DEFAULT_MODEL, + DOMAIN, +) from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -50,6 +54,27 @@ async def test_form(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_options( + hass: HomeAssistant, mock_config_entry, mock_init_component +) -> None: + """Test the options form.""" + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + options = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + { + "prompt": "Speak like a pirate", + "max_tokens": 200, + }, + ) + await hass.async_block_till_done() + assert options["type"] == FlowResultType.CREATE_ENTRY + assert options["data"]["prompt"] == "Speak like a pirate" + assert options["data"]["max_tokens"] == 200 + assert options["data"][CONF_MODEL] == DEFAULT_MODEL + + @pytest.mark.parametrize( "side_effect, error", [ diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 551d493df8e..759c5e2e200 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -68,11 +68,10 @@ async def test_default_prompt(hass, mock_init_component): device.id, disabled_by=device_registry.DeviceEntryDisabler.USER ) - with patch("openai.Completion.create") as mock_create: + with patch("openai.Completion.acreate") as mock_create: result = await conversation.async_converse(hass, "hello", None, Context()) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert ( mock_create.mock_calls[0][2]["prompt"] == """This smart home is controlled by Home Assistant. @@ -101,7 +100,26 @@ Smart home: """ async def test_error_handling(hass, mock_init_component): """Test that the default prompt works.""" - with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError): + with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError): + result = await conversation.async_converse(hass, "hello", None, Context()) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error(hass, mock_config_entry, mock_init_component): + """Test that template error handling works.""" + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + await hass.config_entries.options.async_configure( + options_flow["flow_id"], + { + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + await hass.async_block_till_done() + with patch("openai.Completion.acreate"): result = await conversation.async_converse(hass, "hello", None, Context()) assert result.response.response_type == intent.IntentResponseType.ERROR, result