Add error handling for OpenAI (#86671)

* Add error handling for OpenAI

* Simplify area filtering

* better prompt
pull/86676/head
Paulus Schoutsen 2023-01-25 22:17:19 -05:00 committed by GitHub
parent c395698ea2
commit 28a3b4a32c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 39 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from functools import partial
import logging
from typing import cast
import openai
from openai import error
@ -13,7 +12,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import area_registry, device_registry, intent, template
from homeassistant.helpers import area_registry, intent, template
from homeassistant.util import ulid
from .const import DEFAULT_MODEL, DEFAULT_PROMPT
@ -97,15 +96,26 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
_LOGGER.debug("Prompt for %s: %s", model, prompt)
result = await self.hass.async_add_executor_job(
partial(
openai.Completion.create,
engine=model,
prompt=prompt,
max_tokens=150,
user=conversation_id,
try:
result = await self.hass.async_add_executor_job(
partial(
openai.Completion.create,
engine=model,
prompt=prompt,
max_tokens=150,
user=conversation_id,
)
)
)
except error.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Response %s", result)
response = result["choices"][0]["text"].strip()
self.history[conversation_id] = prompt + response
@ -122,20 +132,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
def _async_generate_prompt(self) -> str:
"""Generate a prompt for the user."""
dev_reg = device_registry.async_get(self.hass)
return template.Template(DEFAULT_PROMPT, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"areas": [
area
for area in area_registry.async_get(self.hass).areas.values()
# Filter out areas without devices
if any(
not dev.disabled_by
for dev in device_registry.async_entries_for_area(
dev_reg, cast(str, area.id)
)
)
],
"areas": list(area_registry.async_get(self.hass).areas.values()),
}
)

View File

@ -3,19 +3,26 @@
DOMAIN = "openai_conversation"
CONF_PROMPT = "prompt"
DEFAULT_MODEL = "text-davinci-003"
DEFAULT_PROMPT = """
You are a conversational AI for a smart home named {{ ha_name }}.
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
{% for area in areas %}
{%- for area in areas %}
{%- set area_info = namespace(printed=false) %}
{%- for device in area_devices(area.name) -%}
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") %}
{%- if not area_info.printed %}
{{ area.name }}:
{% for device in area_devices(area.name) -%}
{%- if not device_attr(device, "disabled_by") %}
- {{ device_attr(device, "name") }} ({{ device_attr(device, "model") }} by {{ device_attr(device, "manufacturer") }})
{%- set area_info.printed = true %}
{%- endif %}
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") not in device_attr(device, "name") %} ({{ device_attr(device, "model") }}){% endif %}
{%- endif %}
{%- endfor %}
{% endfor %}
{%- endfor %}
Answer the users questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
Now finish this conversation:

View File

@ -1,14 +1,20 @@
"""Tests for the OpenAI integration."""
from unittest.mock import patch
from openai import error
from homeassistant.components import conversation
from homeassistant.core import Context
from homeassistant.helpers import device_registry
from homeassistant.helpers import area_registry, device_registry, intent
async def test_default_prompt(hass, mock_init_component):
"""Test that the default prompt works."""
device_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
for i in range(3):
area_reg.async_create(f"{i}Empty Area")
device_reg.async_get_or_create(
config_entry_id="1234",
@ -18,12 +24,22 @@ async def test_default_prompt(hass, mock_init_component):
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_reg.async_get_or_create(
config_entry_id="1234",
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=device_registry.DeviceEntryType.SERVICE,
)
device_reg.async_get_or_create(
config_entry_id="1234",
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Test Model 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_reg.async_get_or_create(
@ -31,7 +47,7 @@ async def test_default_prompt(hass, mock_init_component):
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
@ -40,20 +56,20 @@ async def test_default_prompt(hass, mock_init_component):
assert (
mock_create.mock_calls[0][2]["prompt"]
== """You are a conversational AI for a smart home named test home.
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
== """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model by Test Manufacturer)
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 2 (Test Model 2 by Test Manufacturer 2)
- Test Device 3 (Test Model 3 by Test Manufacturer 3)
Answer the users questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
Now finish this conversation:
@ -61,3 +77,12 @@ Smart home: How can I assist?
User: hello
Smart home: """
)
async def test_error_handling(hass, mock_init_component):
"""Test that the default prompt works."""
with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError):
result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result