"""Conversation support for OpenAI."""
|
|
|
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
import openai
|
|
|
|
|
|
|
|
from homeassistant.components import assist_pipeline, conversation
|
|
|
|
from homeassistant.config_entries import ConfigEntry
|
|
|
|
from homeassistant.const import MATCH_ALL
|
|
|
|
from homeassistant.core import HomeAssistant
|
|
|
|
from homeassistant.exceptions import TemplateError
|
|
|
|
from homeassistant.helpers import intent, template
|
|
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
|
|
from homeassistant.util import ulid
|
|
|
|
|
|
|
|
from .const import (
|
|
|
|
CONF_CHAT_MODEL,
|
|
|
|
CONF_MAX_TOKENS,
|
|
|
|
CONF_PROMPT,
|
|
|
|
CONF_TEMPERATURE,
|
|
|
|
CONF_TOP_P,
|
|
|
|
DEFAULT_CHAT_MODEL,
|
|
|
|
DEFAULT_MAX_TOKENS,
|
|
|
|
DEFAULT_PROMPT,
|
|
|
|
DEFAULT_TEMPERATURE,
|
|
|
|
DEFAULT_TOP_P,
|
|
|
|
DOMAIN,
|
|
|
|
LOGGER,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def async_setup_entry(
|
|
|
|
hass: HomeAssistant,
|
|
|
|
config_entry: ConfigEntry,
|
|
|
|
async_add_entities: AddEntitiesCallback,
|
|
|
|
) -> None:
|
|
|
|
"""Set up conversation entities."""
|
2024-05-03 14:59:08 +00:00
|
|
|
agent = OpenAIConversationEntity(config_entry)
|
2024-04-09 15:10:03 +00:00
|
|
|
async_add_entities([agent])
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIConversationEntity(
|
|
|
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
|
|
|
):
|
|
|
|
"""OpenAI conversation agent."""
|
|
|
|
|
2024-05-03 00:38:12 +00:00
|
|
|
_attr_has_entity_name = True
|
|
|
|
|
2024-05-03 14:59:08 +00:00
|
|
|
def __init__(self, entry: ConfigEntry) -> None:
|
2024-04-09 15:10:03 +00:00
|
|
|
"""Initialize the agent."""
|
|
|
|
self.entry = entry
|
|
|
|
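        # Per-conversation message history, keyed by conversation id, so that
        # follow-up turns can be sent back to OpenAI with full context.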
        self.history: dict[str, list[dict]] = {}
        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
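        # Point any assist pipelines that still reference the config entry id
        # at this entity, then register it as the active conversation agent.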
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
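        # All tunables come from the config entry options, falling back to the
        # integration defaults when an option has not been set.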
        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
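        # Reuse the stored history when the caller passes a known conversation
        # id; otherwise start a fresh conversation under a new ULID.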
        if user_input.conversation_id in self.history:
            conversation_id = user_input.conversation_id
            messages = self.history[conversation_id]
        else:
            conversation_id = ulid.ulid_now()
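            # Render the configured prompt template; template errors are
            # reported back to the user instead of raising.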
            try:
                prompt = self._async_generate_prompt(raw_prompt)
            except TemplateError as err:
                LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem with my template: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )
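            # Seed the new conversation with the rendered system prompt.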
            messages = [{"role": "system", "content": prompt}]

        messages.append({"role": "user", "content": user_input.text})

        LOGGER.debug("Prompt for %s: %s", model, messages)
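        # The shared OpenAI client (an openai.AsyncOpenAI instance created
        # during integration setup) is stored in hass.data under this entry.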
        client = self.hass.data[DOMAIN][self.entry.entry_id]
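        # Send the whole transcript to the Chat Completions API with the
        # user-configured sampling options; the conversation id is forwarded
        # as the OpenAI `user` field.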
        try:
            result = await client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                top_p=top_p,
                temperature=temperature,
                user=conversation_id,
            )
        except openai.OpenAIError as err:
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, I had a problem talking to OpenAI: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )
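        # Keep only the role and content of the reply, append it to the
        # transcript, and store the updated history for follow-up turns.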
        LOGGER.debug("Response %s", result)
        response = result.choices[0].message.model_dump(include={"role", "content"})
        messages.append(response)
        self.history[conversation_id] = messages
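        # Return the assistant's reply as the conversational response.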
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _async_generate_prompt(self, raw_prompt: str) -> str:
        """Generate a prompt for the user."""
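        # Render the prompt as a Home Assistant template; `ha_name` is exposed
        # as a template variable and parse_result=False keeps the rendered
        # output a plain string.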
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
            },
            parse_result=False,
        )