81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
"""AI Task integration for Ollama."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from json import JSONDecodeError
|
|
import logging
|
|
|
|
from homeassistant.components import ai_task, conversation
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
|
from homeassistant.util.json import json_loads
|
|
|
|
from .entity import OllamaBaseLLMEntity
|
|
|
|
# Logger named after this module; the AI Task entity logs JSON parse failures to it.
_LOGGER = logging.getLogger(name=__name__)
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Create an Ollama AI Task entity for each ``ai_task_data`` subentry.

    Every subentry of *config_entry* whose ``subentry_type`` is
    ``"ai_task_data"`` gets exactly one ``OllamaTaskEntity``. Each entity is
    registered separately so it is attached to its own subentry through
    ``config_subentry_id``. Subentries of any other type are ignored.
    """
    task_subentries = (
        entry
        for entry in config_entry.subentries.values()
        if entry.subentry_type == "ai_task_data"
    )
    for entry in task_subentries:
        # One add call per subentry so the entity is linked to that subentry.
        async_add_entities(
            [OllamaTaskEntity(config_entry, entry)],
            config_subentry_id=entry.subentry_id,
        )
|
|
|
|
|
|
class OllamaTaskEntity(
    ai_task.AITaskEntity,
    OllamaBaseLLMEntity,
):
    """AI Task entity that generates data through an Ollama model."""

    # Advertise data generation plus support for task attachments.
    _attr_supported_features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Produce the result of a generate-data task from the chat log.

        Delegates to ``_async_handle_chat_log`` from the base entity,
        passing ``task.structure``. That call presumably queries the Ollama
        model and appends its reply to *chat_log*; verify against
        ``OllamaBaseLLMEntity``. The final chat-log entry is then read as
        the answer.

        If the task has no structure, the reply text is returned unchanged,
        with an empty string when the reply has no content. If a structure
        is set, the text is decoded as JSON and the decoded value is
        returned.

        Raises:
            HomeAssistantError: if the last chat-log entry is not
                ``AssistantContent``, or if a structured reply is not valid
                JSON.
        """
        await self._async_handle_chat_log(chat_log, task.structure)

        last_entry = chat_log.content[-1]
        if not isinstance(last_entry, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        # Treat a missing reply body as empty text.
        response_text = last_entry.content or ""

        if task.structure:
            # A structured task expects the model to answer with JSON.
            try:
                result_data = json_loads(response_text)
            except JSONDecodeError as err:
                _LOGGER.error(
                    "Failed to parse JSON response: %s. Response: %s",
                    err,
                    response_text,
                )
                raise HomeAssistantError(
                    "Error with Ollama structured response"
                ) from err
        else:
            result_data = response_text

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=result_data,
        )
|