Migrate openai_conversation to `entry.runtime_data` (#118535)

* switch to entry.runtime_data

* check for missing config entry

* Update homeassistant/components/openai_conversation/__init__.py

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
pull/118845/head
Josef Zweck 2024-05-31 17:16:39 +02:00 committed by Paulus Schoutsen
parent a59c890779
commit 4998fe5e6d
4 changed files with 60 additions and 14 deletions

View File

@ -2,6 +2,8 @@
from __future__ import annotations
from typing import Literal, cast
import openai
import voluptuous as vol
@ -13,7 +15,11 @@ from homeassistant.core import (
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.exceptions import (
ConfigEntryNotReady,
HomeAssistantError,
ServiceValidationError,
)
from homeassistant.helpers import (
config_validation as cv,
issue_registry as ir,
@ -27,13 +33,25 @@ SERVICE_GENERATE_IMAGE = "generate_image"
PLATFORMS = (Platform.CONVERSATION,)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up OpenAI Conversation."""
async def render_image(call: ServiceCall) -> ServiceResponse:
"""Render an image with dall-e."""
client = hass.data[DOMAIN][call.data["config_entry"]]
entry_id = call.data["config_entry"]
entry = hass.config_entries.async_get_entry(entry_id)
if entry is None or entry.domain != DOMAIN:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="invalid_config_entry",
translation_placeholders={"config_entry": entry_id},
)
client: openai.AsyncClient = entry.runtime_data
if call.data["size"] in ("256", "512", "1024"):
ir.async_create_issue(
@ -51,6 +69,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
else:
size = call.data["size"]
size = cast(
Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
size,
) # size is selector, so no need to check further
try:
response = await client.images.generate(
model="dall-e-3",
@ -90,7 +113,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
"""Set up OpenAI Conversation from a config entry."""
client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY])
try:
@ -101,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
except openai.OpenAIError as err:
raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
entry.runtime_data = client
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
@ -110,8 +133,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload OpenAI."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id)
return True
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@ -22,7 +22,6 @@ from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
@ -30,6 +29,7 @@ from homeassistant.helpers import device_registry as dr, intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from . import OpenAIConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@ -50,7 +50,7 @@ MAX_TOOL_ITERATIONS = 10
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
config_entry: OpenAIConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up conversation entities."""
@ -74,7 +74,7 @@ class OpenAIConversationEntity(
_attr_has_entity_name = True
_attr_name = None
def __init__(self, entry: ConfigEntry) -> None:
def __init__(self, entry: OpenAIConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
@ -203,7 +203,7 @@ class OpenAIConversationEntity(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):

View File

@ -60,6 +60,11 @@
}
}
},
"exceptions": {
"invalid_config_entry": {
"message": "Invalid config entry provided. Got {config_entry}"
}
},
"issues": {
"image_size_deprecated_format": {
"title": "Deprecated size format for image generation service",

View File

@ -14,7 +14,7 @@ from openai.types.images_response import ImagesResponse
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -160,6 +160,28 @@ async def test_generate_image_service_error(
)
async def test_invalid_config_entry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Assert exception when invalid config entry is provided."""
service_data = {
"prompt": "Picture of a dog",
"config_entry": "invalid_entry",
}
with pytest.raises(
ServiceValidationError, match="Invalid config entry provided. Got invalid_entry"
):
await hass.services.async_call(
"openai_conversation",
"generate_image",
service_data,
blocking=True,
return_response=True,
)
@pytest.mark.parametrize(
("side_effect", "error"),
[