core/homeassistant/components/openai_conversation/conversation.py

147 lines
5.2 KiB
Python
Raw Normal View History

"""Conversation support for OpenAI."""
from typing import Literal
import openai
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register one OpenAI conversation entity for the given config entry."""
    # A single entity is built from the config entry and handed to the
    # platform callback; ``hass`` itself is not used here.
    async_add_entities([OpenAIConversationEntity(config_entry)])
class OpenAIConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """OpenAI conversation agent.

    Keeps an in-memory list of chat messages per conversation id and sends
    each user sentence, together with that list, to the chat completions
    API of the client stored under this config entry.
    """

    _attr_has_entity_name = True

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize the agent from its config entry."""
        self.entry = entry
        # conversation_id -> messages exchanged so far (system prompt first).
        # NOTE(review): entries are never evicted, so this grows for the
        # lifetime of the entity.
        self.history: dict[str, list[dict]] = {}
        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages (MATCH_ALL: no restriction)."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
        # Presumably re-points pipelines that referenced the config entry id
        # at this entity's id -- confirm against assist_pipeline.
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
        # Unregister the agent before the base class tears the entity down.
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence.

        If ``user_input.conversation_id`` is already in ``self.history`` the
        stored messages are continued; otherwise a new conversation is
        started under a fresh ULID with the rendered prompt as the system
        message. Template and OpenAI errors are logged and returned as
        error responses instead of being raised.
        """
        options = self.entry.options
        raw_prompt = options.get(CONF_PROMPT, DEFAULT_PROMPT)
        model = options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
        max_tokens = options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        top_p = options.get(CONF_TOP_P, DEFAULT_TOP_P)
        temperature = options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)

        if user_input.conversation_id in self.history:
            conversation_id = user_input.conversation_id
            # Work on a copy: the stored history must only change once the
            # request succeeds, otherwise a failed call would leave an
            # unanswered user message behind that is re-sent next time.
            messages = list(self.history[conversation_id])
        else:
            conversation_id = ulid.ulid_now()
            try:
                prompt = self._async_generate_prompt(raw_prompt)
            except TemplateError as err:
                LOGGER.error("Error rendering prompt: %s", err)
                return self._error_result(
                    user_input.language,
                    conversation_id,
                    f"Sorry, I had a problem with my template: {err}",
                )
            messages = [{"role": "system", "content": prompt}]

        messages.append({"role": "user", "content": user_input.text})

        LOGGER.debug("Prompt for %s: %s", model, messages)

        # Presumably the OpenAI client created during integration setup.
        client = self.hass.data[DOMAIN][self.entry.entry_id]

        try:
            result = await client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                top_p=top_p,
                temperature=temperature,
                user=conversation_id,
            )
        except openai.OpenAIError as err:
            # Logged for parity with the template-error path above.
            LOGGER.error("Error talking to OpenAI: %s", err)
            return self._error_result(
                user_input.language,
                conversation_id,
                f"Sorry, I had a problem talking to OpenAI: {err}",
            )

        LOGGER.debug("Response %s", result)

        # Keep only role/content so the reply can be sent back as a message.
        response = result.choices[0].message.model_dump(include={"role", "content"})
        messages.append(response)
        self.history[conversation_id] = messages

        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    @staticmethod
    def _error_result(
        language: str, conversation_id: str, message: str
    ) -> conversation.ConversationResult:
        """Build a conversation result carrying an UNKNOWN error with message."""
        intent_response = intent.IntentResponse(language=language)
        intent_response.async_set_error(
            intent.IntentResponseErrorCode.UNKNOWN,
            message,
        )
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _async_generate_prompt(self, raw_prompt: str) -> str:
        """Generate a prompt for the user.

        Renders ``raw_prompt`` as a template with ``ha_name`` set to the
        configured location name. A TemplateError from rendering is not
        handled here; ``async_process`` catches it.
        """
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
            },
            parse_result=False,
        )