"""The conversation platform for the Ollama integration."""
from __future__ import annotations
from collections.abc import Callable
import json
import logging
import time
from typing import Any, Literal
import ollama
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_NUM_CTX,
DOMAIN,
MAX_HISTORY_SECONDS,
)
from .models import MessageHistory, MessageRole
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up conversation entities."""
    agent = OllamaConversationEntity(config_entry)
    async_add_entities([agent])


def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Format tool specification."""
    tool_spec = {
        "name": tool.name,
        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
    }
    if tool.description:
        tool_spec["description"] = tool.description
    return {"type": "function", "function": tool_spec}


def _fix_invalid_arguments(value: Any) -> Any:
    """Attempt to repair incorrectly formatted json function arguments.

    Small models (for example llama3.1 8B) may produce invalid argument values
    which we attempt to repair here.
    """
    if not isinstance(value, str):
        return value
    if (value.startswith("[") and value.endswith("]")) or (
        value.startswith("{") and value.endswith("}")
    ):
        try:
            return json.loads(value)
        except json.decoder.JSONDecodeError:
            pass
    return value


def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Rewrite ollama tool arguments.

    This function improves tool use quality by fixing common mistakes made by
    small local tool use models. This will repair invalid json arguments and
    omit unnecessary arguments with empty values that will fail intent parsing.
    """
    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}


class OllamaConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """Ollama conversation agent."""

    _attr_has_entity_name = True

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}

        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id
        if self.entry.options.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
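        # Migrate assist pipelines that still reference this config entry id
        # as their conversation engine over to the new entity id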
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]

        intent_response = intent.IntentResponse(language=user_input.language)
        llm_api: llm.APIInstance | None = None
        tools: list[dict[str, Any]] | None = None
        user_name: str | None = None
        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=user_input.context,
            user_prompt=user_input.text,
            language=user_input.language,
            assistant=conversation.DOMAIN,
            device_id=user_input.device_id,
        )
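
        # When a Home Assistant LLM API is selected in the options, load it
        # and expose its tools to the model in Ollama's tool format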
        if settings.get(CONF_LLM_HASS_API):
            try:
                llm_api = await llm.async_get_api(
                    self.hass,
                    settings[CONF_LLM_HASS_API],
                    llm_context,
                )
            except HomeAssistantError as err:
                _LOGGER.error("Error getting LLM API: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Error preparing LLM API: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=user_input.conversation_id
                )
            tools = [
                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
            ]

        if (
            user_input.context
            and user_input.context.user_id
            and (
                user := await self.hass.auth.async_get_user(user_input.context.user_id)
            )
        ):
            user_name = user.name

        # Look up message history
        message_history: MessageHistory | None = None
        message_history = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            try:
                prompt_parts = [
                    template.Template(
                        llm.BASE_PROMPT
                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
                        self.hass,
                    ).async_render(
                        {
                            "ha_name": self.hass.config.location_name,
                            "user_name": user_name,
                            "llm_context": llm_context,
                        },
                        parse_result=False,
                    )
                ]
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            if llm_api:
                prompt_parts.append(llm_api.api_prompt)

            prompt = "\n".join(prompt_parts)
            _LOGGER.debug("Prompt: %s", prompt)
            _LOGGER.debug("Tools: %s", tools)

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL,
            {"messages": message_history.messages},
        )

        # Get response
        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                response = await client.chat(
                    model=model,
                    # Make a copy of the messages because we mutate the list later
                    messages=list(message_history.messages),
                    tools=tools,
                    stream=False,
                    # keep_alive requires specifying unit. In this case, seconds
                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
                )
            except (ollama.RequestError, ollama.ResponseError) as err:
                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem talking to the Ollama server: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )
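            # With stream=False the response carries a single "message" dict.
            # An assistant turn that requests tools looks roughly like this
            # (illustrative values):
            #   {"role": "assistant", "content": "",
            #    "tool_calls": [{"function": {"name": "...", "arguments": {...}}}]}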
            response_message = response["message"]
            message_history.messages.append(
                ollama.Message(
                    role=response_message["role"],
                    content=response_message.get("content"),
                    tool_calls=response_message.get("tool_calls"),
                )
            )

            tool_calls = response_message.get("tool_calls")
            if not tool_calls or not llm_api:
                break

            for tool_call in tool_calls:
                tool_input = llm.ToolInput(
                    tool_name=tool_call["function"]["name"],
                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                )
                _LOGGER.debug(
                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
                )

                try:
                    tool_response = await llm_api.async_call_tool(tool_input)
                except (HomeAssistantError, vol.Invalid) as e:
                    tool_response = {"error": type(e).__name__}
                    if str(e):
                        tool_response["error_text"] = str(e)

                _LOGGER.debug("Tool response: %s", tool_response)
                message_history.messages.append(
                    ollama.Message(
                        role=MessageRole.TOOL.value,
                        content=json.dumps(tool_response),
                    )
                )

        # Create intent response
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove old message histories."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trims excess messages from a single history."""
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects.
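            #
            # Worked example (illustrative): with max_messages=2 and messages
            # [system, user1, asst1, user2, asst2, user3, asst3], num_keep is
            # 4 and drop_index is 3, leaving
            # [system, user2, asst2, user3, asst3].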
            num_keep = 2 * max_messages
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0]
            ] + message_history.messages[drop_index:]

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Handle options update."""
        # Reload as we update device info + entity name + supported features
        await hass.config_entries.async_reload(entry.entry_id)