Add error handling for OpenAI (#86671)
* Add error handling for OpenAI * Simplify area filtering * better promptpull/86676/head
parent
c395698ea2
commit
28a3b4a32c
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import openai
|
||||
from openai import error
|
||||
|
@ -13,7 +12,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
||||
from homeassistant.helpers import area_registry, device_registry, intent, template
|
||||
from homeassistant.helpers import area_registry, intent, template
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import DEFAULT_MODEL, DEFAULT_PROMPT
|
||||
|
@ -97,15 +96,26 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
|
||||
_LOGGER.debug("Prompt for %s: %s", model, prompt)
|
||||
|
||||
result = await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
openai.Completion.create,
|
||||
engine=model,
|
||||
prompt=prompt,
|
||||
max_tokens=150,
|
||||
user=conversation_id,
|
||||
try:
|
||||
result = await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
openai.Completion.create,
|
||||
engine=model,
|
||||
prompt=prompt,
|
||||
max_tokens=150,
|
||||
user=conversation_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
except error.OpenAIError as err:
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to OpenAI: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
_LOGGER.debug("Response %s", result)
|
||||
response = result["choices"][0]["text"].strip()
|
||||
self.history[conversation_id] = prompt + response
|
||||
|
@ -122,20 +132,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
|
||||
def _async_generate_prompt(self) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
dev_reg = device_registry.async_get(self.hass)
|
||||
return template.Template(DEFAULT_PROMPT, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"areas": [
|
||||
area
|
||||
for area in area_registry.async_get(self.hass).areas.values()
|
||||
# Filter out areas without devices
|
||||
if any(
|
||||
not dev.disabled_by
|
||||
for dev in device_registry.async_entries_for_area(
|
||||
dev_reg, cast(str, area.id)
|
||||
)
|
||||
)
|
||||
],
|
||||
"areas": list(area_registry.async_get(self.hass).areas.values()),
|
||||
}
|
||||
)
|
||||
|
|
|
@ -3,19 +3,26 @@
|
|||
DOMAIN = "openai_conversation"
|
||||
CONF_PROMPT = "prompt"
|
||||
DEFAULT_MODEL = "text-davinci-003"
|
||||
DEFAULT_PROMPT = """
|
||||
You are a conversational AI for a smart home named {{ ha_name }}.
|
||||
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
|
||||
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
{% for area in areas %}
|
||||
{%- for area in areas %}
|
||||
{%- set area_info = namespace(printed=false) %}
|
||||
{%- for device in area_devices(area.name) -%}
|
||||
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") %}
|
||||
{%- if not area_info.printed %}
|
||||
|
||||
{{ area.name }}:
|
||||
{% for device in area_devices(area.name) -%}
|
||||
{%- if not device_attr(device, "disabled_by") %}
|
||||
- {{ device_attr(device, "name") }} ({{ device_attr(device, "model") }} by {{ device_attr(device, "manufacturer") }})
|
||||
{%- set area_info.printed = true %}
|
||||
{%- endif %}
|
||||
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") not in device_attr(device, "name") %} ({{ device_attr(device, "model") }}){% endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
Answer the users questions about the world truthfully.
|
||||
|
||||
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
|
||||
|
||||
Now finish this conversation:
|
||||
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
"""Tests for the OpenAI integration."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openai import error
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import device_registry
|
||||
from homeassistant.helpers import area_registry, device_registry, intent
|
||||
|
||||
|
||||
async def test_default_prompt(hass, mock_init_component):
|
||||
"""Test that the default prompt works."""
|
||||
device_reg = device_registry.async_get(hass)
|
||||
area_reg = area_registry.async_get(hass)
|
||||
|
||||
for i in range(3):
|
||||
area_reg.async_create(f"{i}Empty Area")
|
||||
|
||||
device_reg.async_get_or_create(
|
||||
config_entry_id="1234",
|
||||
|
@ -18,12 +24,22 @@ async def test_default_prompt(hass, mock_init_component):
|
|||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
)
|
||||
for i in range(3):
|
||||
device_reg.async_get_or_create(
|
||||
config_entry_id="1234",
|
||||
connections={("test", f"{i}abcd")},
|
||||
name="Test Service",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
entry_type=device_registry.DeviceEntryType.SERVICE,
|
||||
)
|
||||
device_reg.async_get_or_create(
|
||||
config_entry_id="1234",
|
||||
connections={("test", "5678")},
|
||||
name="Test Device 2",
|
||||
manufacturer="Test Manufacturer 2",
|
||||
model="Test Model 2",
|
||||
model="Device 2",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_reg.async_get_or_create(
|
||||
|
@ -31,7 +47,7 @@ async def test_default_prompt(hass, mock_init_component):
|
|||
connections={("test", "9876")},
|
||||
name="Test Device 3",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
model="Test Model 3",
|
||||
model="Test Model 3A",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
|
||||
|
@ -40,20 +56,20 @@ async def test_default_prompt(hass, mock_init_component):
|
|||
|
||||
assert (
|
||||
mock_create.mock_calls[0][2]["prompt"]
|
||||
== """You are a conversational AI for a smart home named test home.
|
||||
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
|
||||
== """This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
|
||||
- Test Device (Test Model by Test Manufacturer)
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
|
||||
- Test Device 2 (Test Model 2 by Test Manufacturer 2)
|
||||
- Test Device 3 (Test Model 3 by Test Manufacturer 3)
|
||||
Answer the users questions about the world truthfully.
|
||||
|
||||
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
|
||||
|
||||
Now finish this conversation:
|
||||
|
||||
|
@ -61,3 +77,12 @@ Smart home: How can I assist?
|
|||
User: hello
|
||||
Smart home: """
|
||||
)
|
||||
|
||||
|
||||
async def test_error_handling(hass, mock_init_component):
|
||||
"""Test that the default prompt works."""
|
||||
with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError):
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
|
Loading…
Reference in New Issue