core/homeassistant/components/open_router/entity.py

"""Base entity for Open Router."""
from __future__ import annotations
from collections.abc import AsyncGenerator, Callable
import json
from typing import TYPE_CHECKING, Any, Literal
import openai
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionToolParam,
ChatCompletionMessage,
ChatCompletionMessageFunctionToolCallParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_function_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema
from openai.types.shared_params.response_format_json_schema import JSONSchema
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_MODEL
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OpenRouterConfigEntry
from .const import DOMAIN, LOGGER
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10


def _adjust_schema(schema: dict[str, Any]) -> None:
    """Adjust the schema to be compatible with OpenRouter API."""
    if schema["type"] == "object":
        if "properties" not in schema:
            return

        if "required" not in schema:
            schema["required"] = []

        # Ensure all properties are required
        for prop, prop_info in schema["properties"].items():
            _adjust_schema(prop_info)
            if prop not in schema["required"]:
                prop_info["type"] = [prop_info["type"], "null"]
                schema["required"].append(prop)

    elif schema["type"] == "array":
        if "items" not in schema:
            return

        _adjust_schema(schema["items"])
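
# Example (illustrative, not part of the original module): given a converted
# schema such as
#   {"type": "object", "properties": {"name": {"type": "string"}}}
# _adjust_schema mutates it in place to
#   {"type": "object",
#    "properties": {"name": {"type": ["string", "null"]}},
#    "required": ["name"]}
# Strict structured output generally requires every property to be listed in
# "required"; widening each type to allow null lets the model still leave a
# value unfilled by returning null.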


def _format_structured_output(
    name: str, schema: vol.Schema, llm_api: llm.APIInstance | None
) -> JSONSchema:
    """Format the schema to be compatible with OpenRouter API."""
    result: JSONSchema = {
        "name": name,
        "strict": True,
    }
    result_schema = convert(
        schema,
        custom_serializer=(
            llm_api.custom_serializer if llm_api else llm.selector_serializer
        ),
    )
    _adjust_schema(result_schema)
    result["schema"] = result_schema
    return result
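
# Example (illustrative; "weather" is a hypothetical structure name):
#   _format_structured_output("weather", schema, None)
# returns a payload shaped like
#   {"name": "weather", "strict": True, "schema": {...adjusted schema...}}
# which _async_handle_chat_log below wraps in a ResponseFormatJSONSchema.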


def _format_tool(
    tool: llm.Tool,
    custom_serializer: Callable[[Any], Any] | None,
) -> ChatCompletionFunctionToolParam:
    """Format tool specification."""
    tool_spec = FunctionDefinition(
        name=tool.name,
        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
    )
    if tool.description:
        tool_spec["description"] = tool.description
    return ChatCompletionFunctionToolParam(type="function", function=tool_spec)
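
# Example (illustrative): a HassTurnOn-style tool from the LLM API would be
# formatted into the Chat Completions function-tool shape:
#   {"type": "function",
#    "function": {"name": "HassTurnOn",
#                 "description": "...",
#                 "parameters": {"type": "object", "properties": {...}}}}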


def _convert_content_to_chat_message(
    content: conversation.Content,
) -> ChatCompletionMessageParam | None:
    """Convert any native chat message for this agent to the native format."""
    LOGGER.debug("_convert_content_to_chat_message=%s", content)
    if isinstance(content, conversation.ToolResultContent):
        return ChatCompletionToolMessageParam(
            role="tool",
            tool_call_id=content.tool_call_id,
            content=json.dumps(content.tool_result),
        )

    role: Literal["user", "assistant", "system"] = content.role
    if role == "system" and content.content:
        return ChatCompletionSystemMessageParam(role="system", content=content.content)

    if role == "user" and content.content:
        return ChatCompletionUserMessageParam(role="user", content=content.content)

    if role == "assistant":
        param = ChatCompletionAssistantMessageParam(
            role="assistant",
            content=content.content,
        )
        if isinstance(content, conversation.AssistantContent) and content.tool_calls:
            param["tool_calls"] = [
                ChatCompletionMessageFunctionToolCallParam(
                    type="function",
                    id=tool_call.id,
                    function=Function(
                        arguments=json.dumps(tool_call.tool_args),
                        name=tool_call.tool_name,
                    ),
                )
                for tool_call in content.tool_calls
            ]
        return param

    LOGGER.warning("Could not convert message to Completions API: %s", content)
    return None
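
# Example (illustrative): a ToolResultContent with tool_call_id="call_1" and
# tool_result={"success": True} becomes
#   {"role": "tool", "tool_call_id": "call_1", "content": '{"success": true}'}
# Content with no usable payload (e.g. an empty system prompt) returns None,
# so it is filtered out when the message list is built below.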


def _decode_tool_arguments(arguments: str) -> Any:
    """Decode tool call arguments."""
    try:
        return json.loads(arguments)
    except json.JSONDecodeError as err:
        raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
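
# Example (illustrative): _decode_tool_arguments('{"brightness": 50}') returns
# {"brightness": 50}; malformed JSON from the model raises HomeAssistantError
# instead of letting a bare JSONDecodeError propagate.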


async def _transform_response(
    message: ChatCompletionMessage,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the OpenRouter message to a ChatLog format."""
    data: conversation.AssistantContentDeltaDict = {
        "role": message.role,
        "content": message.content,
    }
    if message.tool_calls:
        data["tool_calls"] = [
            llm.ToolInput(
                id=tool_call.id,
                tool_name=tool_call.function.name,
                tool_args=_decode_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in message.tool_calls
            if tool_call.type == "function"
        ]
    yield data
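
# Example (illustrative): a message with content "Done" and one function tool
# call yields a single delta such as
#   {"role": "assistant", "content": "Done",
#    "tool_calls": [ToolInput(id="call_1", tool_name="HassTurnOn",
#                             tool_args={"name": "kitchen light"})]}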


class OpenRouterEntity(Entity):
    """Base entity for Open Router."""

    _attr_has_entity_name = True

    def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity."""
        self.entry = entry
        self.subentry = subentry
        self.model = subentry.data[CONF_MODEL]
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, subentry.subentry_id)},
            name=subentry.title,
            entry_type=dr.DeviceEntryType.SERVICE,
        )

    async def _async_handle_chat_log(
        self,
        chat_log: conversation.ChatLog,
        structure_name: str | None = None,
        structure: vol.Schema | None = None,
    ) -> None:
        """Generate an answer for the chat log."""
        model_args = {
            "model": self.model,
            "user": chat_log.conversation_id,
            "extra_headers": {
                "X-Title": "Home Assistant",
                "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
            },
            "extra_body": {"require_parameters": True},
        }

        tools: list[ChatCompletionFunctionToolParam] | None = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]

        if tools:
            model_args["tools"] = tools

        model_args["messages"] = [
            m
            for content in chat_log.content
            if (m := _convert_content_to_chat_message(content))
        ]

        if structure:
            if TYPE_CHECKING:
                assert structure_name is not None
            model_args["response_format"] = ResponseFormatJSONSchema(
                type="json_schema",
                json_schema=_format_structured_output(
                    structure_name, structure, chat_log.llm_api
                ),
            )

        client = self.entry.runtime_data

        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                result = await client.chat.completions.create(**model_args)
            except openai.OpenAIError as err:
                LOGGER.error("Error talking to API: %s", err)
                raise HomeAssistantError("Error talking to API") from err

            result_message = result.choices[0].message

            model_args["messages"].extend(
                [
                    msg
                    async for content in chat_log.async_add_delta_content_stream(
                        self.entity_id, _transform_response(result_message)
                    )
                    if (msg := _convert_content_to_chat_message(content))
                ]
            )

            if not chat_log.unresponded_tool_results:
                break
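

# Usage sketch (illustrative, not part of this module): a conversation entity
# subclass would combine this base with conversation.ConversationEntity and
# delegate its turn handling to _async_handle_chat_log, roughly:
#
#   class MyConversationEntity(conversation.ConversationEntity, OpenRouterEntity):
#       async def _async_handle_message(
#           self,
#           user_input: conversation.ConversationInput,
#           chat_log: conversation.ChatLog,
#       ) -> conversation.ConversationResult:
#           await self._async_handle_chat_log(chat_log)
#           return conversation.async_get_result_from_chat_log(user_input, chat_log)
#
# The loop above keeps re-sending the growing message list until the model
# stops requesting tools (no unresponded tool results) or MAX_TOOL_ITERATIONS
# is reached.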