Add error handling for OpenAI (#86671)

* Add error handling for OpenAI

* Simplify area filtering

* better prompt
pull/86676/head
Paulus Schoutsen 2023-01-25 22:17:19 -05:00 committed by GitHub
parent c395698ea2
commit 28a3b4a32c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 39 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from functools import partial from functools import partial
import logging import logging
from typing import cast
import openai import openai
from openai import error from openai import error
@ -13,7 +12,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import area_registry, device_registry, intent, template from homeassistant.helpers import area_registry, intent, template
from homeassistant.util import ulid from homeassistant.util import ulid
from .const import DEFAULT_MODEL, DEFAULT_PROMPT from .const import DEFAULT_MODEL, DEFAULT_PROMPT
@ -97,15 +96,26 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
_LOGGER.debug("Prompt for %s: %s", model, prompt) _LOGGER.debug("Prompt for %s: %s", model, prompt)
result = await self.hass.async_add_executor_job( try:
partial( result = await self.hass.async_add_executor_job(
openai.Completion.create, partial(
engine=model, openai.Completion.create,
prompt=prompt, engine=model,
max_tokens=150, prompt=prompt,
user=conversation_id, max_tokens=150,
user=conversation_id,
)
) )
) except error.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Response %s", result) _LOGGER.debug("Response %s", result)
response = result["choices"][0]["text"].strip() response = result["choices"][0]["text"].strip()
self.history[conversation_id] = prompt + response self.history[conversation_id] = prompt + response
@ -122,20 +132,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
def _async_generate_prompt(self) -> str: def _async_generate_prompt(self) -> str:
"""Generate a prompt for the user.""" """Generate a prompt for the user."""
dev_reg = device_registry.async_get(self.hass)
return template.Template(DEFAULT_PROMPT, self.hass).async_render( return template.Template(DEFAULT_PROMPT, self.hass).async_render(
{ {
"ha_name": self.hass.config.location_name, "ha_name": self.hass.config.location_name,
"areas": [ "areas": list(area_registry.async_get(self.hass).areas.values()),
area
for area in area_registry.async_get(self.hass).areas.values()
# Filter out areas without devices
if any(
not dev.disabled_by
for dev in device_registry.async_entries_for_area(
dev_reg, cast(str, area.id)
)
)
],
} }
) )

View File

@ -3,19 +3,26 @@
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_MODEL = "text-davinci-003" DEFAULT_MODEL = "text-davinci-003"
DEFAULT_PROMPT = """ DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
You are a conversational AI for a smart home named {{ ha_name }}.
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
An overview of the areas and the devices in this smart home: An overview of the areas and the devices in this smart home:
{% for area in areas %} {%- for area in areas %}
{%- set area_info = namespace(printed=false) %}
{%- for device in area_devices(area.name) -%}
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") %}
{%- if not area_info.printed %}
{{ area.name }}: {{ area.name }}:
{% for device in area_devices(area.name) -%} {%- set area_info.printed = true %}
{%- if not device_attr(device, "disabled_by") %} {%- endif %}
- {{ device_attr(device, "name") }} ({{ device_attr(device, "model") }} by {{ device_attr(device, "manufacturer") }}) - {{ device_attr(device, "name") }}{% if device_attr(device, "model") not in device_attr(device, "name") %} ({{ device_attr(device, "model") }}){% endif %}
{%- endif %} {%- endif %}
{%- endfor %} {%- endfor %}
{% endfor %} {%- endfor %}
Answer the users questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
Now finish this conversation: Now finish this conversation:

View File

@ -1,14 +1,20 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from unittest.mock import patch from unittest.mock import patch
from openai import error
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.core import Context from homeassistant.core import Context
from homeassistant.helpers import device_registry from homeassistant.helpers import area_registry, device_registry, intent
async def test_default_prompt(hass, mock_init_component): async def test_default_prompt(hass, mock_init_component):
"""Test that the default prompt works.""" """Test that the default prompt works."""
device_reg = device_registry.async_get(hass) device_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
for i in range(3):
area_reg.async_create(f"{i}Empty Area")
device_reg.async_get_or_create( device_reg.async_get_or_create(
config_entry_id="1234", config_entry_id="1234",
@ -18,12 +24,22 @@ async def test_default_prompt(hass, mock_init_component):
model="Test Model", model="Test Model",
suggested_area="Test Area", suggested_area="Test Area",
) )
for i in range(3):
device_reg.async_get_or_create(
config_entry_id="1234",
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=device_registry.DeviceEntryType.SERVICE,
)
device_reg.async_get_or_create( device_reg.async_get_or_create(
config_entry_id="1234", config_entry_id="1234",
connections={("test", "5678")}, connections={("test", "5678")},
name="Test Device 2", name="Test Device 2",
manufacturer="Test Manufacturer 2", manufacturer="Test Manufacturer 2",
model="Test Model 2", model="Device 2",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
device_reg.async_get_or_create( device_reg.async_get_or_create(
@ -31,7 +47,7 @@ async def test_default_prompt(hass, mock_init_component):
connections={("test", "9876")}, connections={("test", "9876")},
name="Test Device 3", name="Test Device 3",
manufacturer="Test Manufacturer 3", manufacturer="Test Manufacturer 3",
model="Test Model 3", model="Test Model 3A",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
@ -40,20 +56,20 @@ async def test_default_prompt(hass, mock_init_component):
assert ( assert (
mock_create.mock_calls[0][2]["prompt"] mock_create.mock_calls[0][2]["prompt"]
== """You are a conversational AI for a smart home named test home. == """This smart home is controlled by Home Assistant.
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
An overview of the areas and the devices in this smart home: An overview of the areas and the devices in this smart home:
Test Area: Test Area:
- Test Device (Test Model)
- Test Device (Test Model by Test Manufacturer)
Test Area 2: Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 2 (Test Model 2 by Test Manufacturer 2) Answer the users questions about the world truthfully.
- Test Device 3 (Test Model 3 by Test Manufacturer 3)
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
Now finish this conversation: Now finish this conversation:
@ -61,3 +77,12 @@ Smart home: How can I assist?
User: hello User: hello
Smart home: """ Smart home: """
) )
async def test_error_handling(hass, mock_init_component):
"""Test that the default prompt works."""
with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError):
result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result