# core/homeassistant/components/openai_conversation/config_flow.py
# (182 lines, 5.2 KiB, Python)

"""Config flow for OpenAI Conversation integration."""
from __future__ import annotations
import logging
from types import MappingProxyType
from typing import Any
import openai
import voluptuous as vol
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
TemplateSelector,
)
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
)
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Form schema for the initial "user" step: a single required API key string.
STEP_USER_DATA_SCHEMA = vol.Schema({vol.Required(CONF_API_KEY): str})
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
    """Validate the user input allows us to connect.

    Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.

    Args:
        hass: The Home Assistant instance. Kept for interface compatibility;
            the request below runs directly on the event loop.
        data: Mapping that must contain ``CONF_API_KEY``.

    Raises:
        Whatever the OpenAI client raises for the models request, e.g.
        ``openai.APIConnectionError`` or ``openai.AuthenticationError``.
    """
    client = openai.AsyncOpenAI(api_key=data[CONF_API_KEY])
    # On the async client, ``models.list()`` only returns an awaitable
    # paginator; the HTTP request is sent when that object is awaited.
    # Merely calling it in an executor thread (as was done before) never
    # contacts the API, so an invalid key would pass validation silently.
    await client.with_options(timeout=10.0).models.list()
class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
    """Handle a config flow for OpenAI Conversation."""

    VERSION = 1

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the initial step.

        With no input the form for STEP_USER_DATA_SCHEMA is shown. Otherwise
        the input is passed to ``validate_input``; on success an entry titled
        "ChatGPT" is created, and on failure the form is shown again with an
        error key under ``"base"``.
        """
        if user_input is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA
            )

        errors: dict[str, str] = {}

        try:
            await validate_input(self.hass, user_input)
        except openai.APIConnectionError:
            errors["base"] = "cannot_connect"
        except openai.AuthenticationError:
            errors["base"] = "invalid_auth"
        except Exception:
            # Flow boundary: log the traceback and report a generic error
            # to the form instead of letting the flow crash.
            _LOGGER.exception("Unexpected exception")
            errors["base"] = "unknown"
        else:
            # New entries start with the Assist LLM API selected in options.
            return self.async_create_entry(
                title="ChatGPT",
                data=user_input,
                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
            )

        return self.async_show_form(
            step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
        )

    @staticmethod
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> OptionsFlow:
        """Create the options flow."""
        return OpenAIOptionsFlow(config_entry)
class OpenAIOptionsFlow(OptionsFlow):
    """Options flow handler for an OpenAI Conversation config entry."""

    def __init__(self, config_entry: ConfigEntry) -> None:
        """Store the config entry whose options are being edited."""
        self.config_entry = config_entry

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Show the options form, or save the submitted options.

        When the submitted LLM API choice is the "none" placeholder, the key
        is removed from the input before the entry is created.
        """
        if user_input is None:
            # Nothing submitted yet: render the form from the stored options.
            option_fields = openai_config_option_schema(
                self.hass, self.config_entry.options
            )
            return self.async_show_form(
                step_id="init",
                data_schema=vol.Schema(option_fields),
            )

        if user_input[CONF_LLM_HASS_API] == "none":
            del user_input[CONF_LLM_HASS_API]

        return self.async_create_entry(title="", data=user_input)
def openai_config_option_schema(
    hass: HomeAssistant,
    options: MappingProxyType[str, Any],
) -> dict:
    """Build the schema mapping for the OpenAI completion options form.

    Each field is a ``vol.Optional`` marker whose suggested value is read from
    ``options`` and whose default is the matching ``DEFAULT_*`` constant
    (``"none"`` for the LLM API selector).

    Args:
        hass: Home Assistant instance, passed to ``llm.async_get_apis``.
        options: Currently stored options used for the suggested values.

    Returns:
        A dict mapping schema markers to validators/selectors, in form order.
    """

    def _suggest(key: str) -> dict[str, Any]:
        """Wrap the stored value for *key* as a suggested-value description."""
        return {"suggested_value": options.get(key)}

    # The "No control" placeholder comes first, followed by one choice for
    # every API returned by llm.async_get_apis(hass).
    api_choices: list[SelectOptionDict] = [
        SelectOptionDict(label="No control", value="none"),
        *(
            SelectOptionDict(label=api.name, value=api.id)
            for api in llm.async_get_apis(hass)
        ),
    ]

    prompt_field = vol.Optional(
        CONF_PROMPT, description=_suggest(CONF_PROMPT), default=DEFAULT_PROMPT
    )
    llm_api_field = vol.Optional(
        CONF_LLM_HASS_API, description=_suggest(CONF_LLM_HASS_API), default="none"
    )
    # New key in HA 2023.4
    chat_model_field = vol.Optional(
        CONF_CHAT_MODEL,
        description=_suggest(CONF_CHAT_MODEL),
        default=DEFAULT_CHAT_MODEL,
    )
    max_tokens_field = vol.Optional(
        CONF_MAX_TOKENS,
        description=_suggest(CONF_MAX_TOKENS),
        default=DEFAULT_MAX_TOKENS,
    )
    top_p_field = vol.Optional(
        CONF_TOP_P, description=_suggest(CONF_TOP_P), default=DEFAULT_TOP_P
    )
    temperature_field = vol.Optional(
        CONF_TEMPERATURE,
        description=_suggest(CONF_TEMPERATURE),
        default=DEFAULT_TEMPERATURE,
    )

    # Insertion order is kept the same as the field order of the form.
    fields: dict = {}
    fields[prompt_field] = TemplateSelector()
    fields[llm_api_field] = SelectSelector(SelectSelectorConfig(options=api_choices))
    fields[chat_model_field] = str
    fields[max_tokens_field] = int
    fields[top_p_field] = NumberSelector(
        NumberSelectorConfig(min=0, max=1, step=0.05)
    )
    fields[temperature_field] = NumberSelector(
        NumberSelectorConfig(min=0, max=2, step=0.05)
    )
    return fields