core/homeassistant/components/openai_conversation/conversation.py

248 lines
9.3 KiB
Python

"""Conversation support for OpenAI."""
import json
from typing import Any, Literal
import openai
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import device_registry as dr, intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_PROMPT,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up conversation entities."""
agent = OpenAIConversationEntity(config_entry)
async_add_entities([agent])
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {"name": tool.name}
if tool.description:
tool_spec["description"] = tool.description
tool_spec["parameters"] = convert(tool.parameters)
return {"type": "function", "function": tool_spec}
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
"""OpenAI conversation agent."""
_attr_has_entity_name = True
_attr_name = None
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self.history: dict[str, list[dict]] = {}
self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
name=entry.title,
manufacturer="OpenAI",
model="ChatGPT",
entry_type=dr.DeviceEntryType.SERVICE,
)
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
assist_pipeline.async_migrate_engine(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
conversation.async_set_agent(self.hass, self.entry, self)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None
if options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
try:
if llm_api:
empty_tool_input = llm.ToolInput(
tool_name="",
tool_args={},
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
else:
api_prompt = llm.async_render_no_api_prompt(self.hass)
prompt = "\n".join(
(
template.Template(
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
)
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = [{"role": "system", "content": prompt}]
messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client = self.hass.data[DOMAIN][self.entry.entry_id]
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=conversation_id,
)
except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(response)
tool_calls = response.tool_calls
if not tool_calls or not llm_api:
break
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call.function.name,
"content": json.dumps(tool_response),
}
)
self.history[conversation_id] = messages
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)