Avoid exceptions when Gemini responses are blocked (#116847)

* Bump google-generativeai to v0.5.2

* Avoid exceptions when Gemini responses are blocked

* pytest --snapshot-update

* set error response

* add test

* ruff
pull/116912/head
tronikos 2024-05-06 01:22:22 -07:00 committed by GitHub
parent 4fce99edb5
commit 5c4afe55fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 7 deletions

View File

@ -182,11 +182,11 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
conversation_id = ulid.ulid_now()
messages = [{}, {}]
intent_response = intent.IntentResponse(language=user_input.language)
try:
prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
@ -210,7 +210,6 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
genai_types.StopCandidateException,
) as err:
_LOGGER.error("Error sending message: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
@ -220,9 +219,15 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
)
_LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(chat_response.text)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id

View File

@ -95,29 +95,59 @@ async def test_default_prompt(
suggested_area="Test Area 2",
)
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_model.return_value.start_chat.return_value = AsyncMock()
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = ["Hi there!"]
chat_response.text = "Hi there!"
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that the default prompt works."""
"""Test that client errors are caught."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("")
mock_chat.send_message_async.side_effect = ClientError("some error")
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: None some error"
)
async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test response was blocked."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = []
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
)
async def test_template_error(