"""The Ollama integration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Literal
|
|
|
|
import httpx
|
|
import ollama
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import CONF_URL, MATCH_ALL
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
|
from homeassistant.helpers import (
|
|
area_registry as ar,
|
|
config_validation as cv,
|
|
device_registry as dr,
|
|
entity_registry as er,
|
|
intent,
|
|
template,
|
|
)
|
|
from homeassistant.util import ulid
|
|
|
|
from .const import (
|
|
CONF_MAX_HISTORY,
|
|
CONF_MODEL,
|
|
CONF_PROMPT,
|
|
DEFAULT_MAX_HISTORY,
|
|
DEFAULT_PROMPT,
|
|
DEFAULT_TIMEOUT,
|
|
DOMAIN,
|
|
KEEP_ALIVE_FOREVER,
|
|
MAX_HISTORY_SECONDS,
|
|
)
|
|
from .models import ExposedEntity, MessageHistory, MessageRole
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
__all__ = [
    "CONF_URL",
    "CONF_PROMPT",
    "CONF_MODEL",
    "CONF_MAX_HISTORY",
    "DOMAIN",
]

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up Ollama from a config entry."""
    settings = {**entry.data, **entry.options}
    client = ollama.AsyncClient(host=settings[CONF_URL])
    try:
        async with asyncio.timeout(DEFAULT_TIMEOUT):
            await client.list()
    except (TimeoutError, httpx.ConnectError) as err:
        raise ConfigEntryNotReady(err) from err

    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client

    conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry))
    return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload Ollama."""
    hass.data[DOMAIN].pop(entry.entry_id)
    conversation.async_unset_agent(hass, entry)
    return True


class OllamaAgent(conversation.AbstractConversationAgent):
    """Ollama conversation agent."""

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.hass = hass
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]

        # Look up message history
        message_history: MessageHistory | None = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
            try:
                prompt = self._generate_prompt(raw_prompt)
                _LOGGER.debug("Prompt: %s", prompt)
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        # Get response
        try:
            response = await client.chat(
                model=model,
                # Make a copy of the messages because we mutate the list later
                messages=list(message_history.messages),
                stream=False,
                # Keep the model loaded between requests
                keep_alive=KEEP_ALIVE_FOREVER,
            )
        except (ollama.RequestError, ollama.ResponseError) as err:
            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, I had a problem talking to the Ollama server: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )

        response_message = response["message"]
        message_history.messages.append(
            ollama.Message(
                role=response_message["role"], content=response_message["content"]
            )
        )

        # Create intent response
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove old message histories."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(
        self, message_history: MessageHistory, max_messages: int
    ) -> None:
        """Trim excess messages from a single history."""
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects.
            num_keep = 2 * max_messages
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0]
            ] + message_history.messages[drop_index:]

    def _generate_prompt(self, raw_prompt: str) -> str:
        """Generate a prompt for the user."""
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
                "ha_language": self.hass.config.language,
                "exposed_entities": self._get_exposed_entities(),
            },
            parse_result=False,
        )

    def _get_exposed_entities(self) -> list[ExposedEntity]:
        """Get state list of exposed entities."""
        area_registry = ar.async_get(self.hass)
        entity_registry = er.async_get(self.hass)
        device_registry = dr.async_get(self.hass)

        exposed_entities = []
        exposed_states = [
            state
            for state in self.hass.states.async_all()
            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
        ]

        for state in exposed_states:
            entity = entity_registry.async_get(state.entity_id)
            names = [state.name]
            area_names = []

            if entity is not None:
                # Add aliases
                names.extend(entity.aliases)
                if entity.area_id and (
                    area := area_registry.async_get_area(entity.area_id)
                ):
                    # Entity is in area
                    area_names.append(area.name)
                    area_names.extend(area.aliases)
                elif entity.device_id and (
                    device := device_registry.async_get(entity.device_id)
                ):
                    # Check device area
                    if device.area_id and (
                        area := area_registry.async_get_area(device.area_id)
                    ):
                        area_names.append(area.name)
                        area_names.extend(area.aliases)

            exposed_entities.append(
                ExposedEntity(
                    entity_id=state.entity_id,
                    state=state,
                    names=names,
                    area_names=area_names,
                )
            )

        return exposed_entities