Migrate openai_conversation to `entry.runtime_data` (#118535)
* switch to entry.runtime_data * check for missing config entry * Update homeassistant/components/openai_conversation/__init__.py --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>pull/118845/head
parent
a59c890779
commit
4998fe5e6d
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, cast
|
||||
|
||||
import openai
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -13,7 +15,11 @@ from homeassistant.core import (
|
|||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
)
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryNotReady,
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
issue_registry as ir,
|
||||
|
@ -27,13 +33,25 @@ SERVICE_GENERATE_IMAGE = "generate_image"
|
|||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up OpenAI Conversation."""
|
||||
|
||||
async def render_image(call: ServiceCall) -> ServiceResponse:
|
||||
"""Render an image with dall-e."""
|
||||
client = hass.data[DOMAIN][call.data["config_entry"]]
|
||||
entry_id = call.data["config_entry"]
|
||||
entry = hass.config_entries.async_get_entry(entry_id)
|
||||
|
||||
if entry is None or entry.domain != DOMAIN:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="invalid_config_entry",
|
||||
translation_placeholders={"config_entry": entry_id},
|
||||
)
|
||||
|
||||
client: openai.AsyncClient = entry.runtime_data
|
||||
|
||||
if call.data["size"] in ("256", "512", "1024"):
|
||||
ir.async_create_issue(
|
||||
|
@ -51,6 +69,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
else:
|
||||
size = call.data["size"]
|
||||
|
||||
size = cast(
|
||||
Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
size,
|
||||
) # size is selector, so no need to check further
|
||||
|
||||
try:
|
||||
response = await client.images.generate(
|
||||
model="dall-e-3",
|
||||
|
@ -90,7 +113,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
|
||||
"""Set up OpenAI Conversation from a config entry."""
|
||||
client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY])
|
||||
try:
|
||||
|
@ -101,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
except openai.OpenAIError as err:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
||||
entry.runtime_data = client
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
||||
|
@ -110,8 +133,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload OpenAI."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
return False
|
||||
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
return True
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
|
|
@ -22,7 +22,6 @@ from voluptuous_openapi import convert
|
|||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
|
@ -30,6 +29,7 @@ from homeassistant.helpers import device_registry as dr, intent, llm, template
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from . import OpenAIConfigEntry
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
|
@ -50,7 +50,7 @@ MAX_TOOL_ITERATIONS = 10
|
|||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
config_entry: OpenAIConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up conversation entities."""
|
||||
|
@ -74,7 +74,7 @@ class OpenAIConversationEntity(
|
|||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
|
||||
def __init__(self, entry: ConfigEntry) -> None:
|
||||
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.entry = entry
|
||||
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
||||
|
@ -203,7 +203,7 @@ class OpenAIConversationEntity(
|
|||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||
)
|
||||
|
||||
client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||
client = self.entry.runtime_data
|
||||
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
|
|
|
@ -60,6 +60,11 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"invalid_config_entry": {
|
||||
"message": "Invalid config entry provided. Got {config_entry}"
|
||||
}
|
||||
},
|
||||
"issues": {
|
||||
"image_size_deprecated_format": {
|
||||
"title": "Deprecated size format for image generation service",
|
||||
|
|
|
@ -14,7 +14,7 @@ from openai.types.images_response import ImagesResponse
|
|||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
@ -160,6 +160,28 @@ async def test_generate_image_service_error(
|
|||
)
|
||||
|
||||
|
||||
async def test_invalid_config_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Assert exception when invalid config entry is provided."""
|
||||
service_data = {
|
||||
"prompt": "Picture of a dog",
|
||||
"config_entry": "invalid_entry",
|
||||
}
|
||||
with pytest.raises(
|
||||
ServiceValidationError, match="Invalid config entry provided. Got invalid_entry"
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"openai_conversation",
|
||||
"generate_image",
|
||||
service_data,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue