"""Conversation support for OpenAI.""" from typing import Literal import openai from homeassistant.components import assist_pipeline, conversation from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import TemplateError from homeassistant.helpers import intent, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_P, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, DEFAULT_PROMPT, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DOMAIN, LOGGER, ) async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" agent = OpenAIConversationEntity(config_entry) async_add_entities([agent]) class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): """OpenAI conversation agent.""" _attr_has_entity_name = True def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" self.entry = entry self.history: dict[str, list[dict]] = {} self._attr_name = entry.title self._attr_unique_id = entry.entry_id @property def supported_languages(self) -> list[str] | Literal["*"]: """Return a list of supported languages.""" return MATCH_ALL async def async_added_to_hass(self) -> None: """When entity is added to Home Assistant.""" await super().async_added_to_hass() assist_pipeline.async_migrate_engine( self.hass, "conversation", self.entry.entry_id, self.entity_id ) conversation.async_set_agent(self.hass, self.entry, self) async def async_will_remove_from_hass(self) -> None: """When entity will be removed from Home Assistant.""" conversation.async_unset_agent(self.hass, self.entry) await super().async_will_remove_from_hass() async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id messages = self.history[conversation_id] else: conversation_id = ulid.ulid_now() try: prompt = self._async_generate_prompt(raw_prompt) except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, I had a problem with my template: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) messages = [{"role": "system", "content": prompt}] messages.append({"role": "user", "content": user_input.text}) LOGGER.debug("Prompt for %s: %s", model, messages) client = self.hass.data[DOMAIN][self.entry.entry_id] try: result = await client.chat.completions.create( model=model, messages=messages, max_tokens=max_tokens, top_p=top_p, temperature=temperature, user=conversation_id, ) except openai.OpenAIError as err: intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, I had a problem talking to OpenAI: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) LOGGER.debug("Response %s", result) response = result.choices[0].message.model_dump(include={"role", "content"}) messages.append(response) self.history[conversation_id] = messages intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response["content"]) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) def _async_generate_prompt(self, raw_prompt: str) -> str: """Generate a prompt for the user.""" return template.Template(raw_prompt, self.hass).async_render( { "ha_name": self.hass.config.location_name, }, parse_result=False, )