Avoid exceptions when Gemini responses are blocked (#116847)
* Bump google-generativeai to v0.5.2 * Avoid exceptions when Gemini responses are blocked * pytest --snapshot-update * set error response * add test * ruffpull/116912/head
parent
4fce99edb5
commit
5c4afe55fd
|
@ -182,11 +182,11 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
|
|||
conversation_id = ulid.ulid_now()
|
||||
messages = [{}, {}]
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
try:
|
||||
prompt = self._async_generate_prompt(raw_prompt)
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
|
@ -210,7 +210,6 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
|
|||
genai_types.StopCandidateException,
|
||||
) as err:
|
||||
_LOGGER.error("Error sending message: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to Google Generative AI: {err}",
|
||||
|
@ -220,9 +219,15 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
|
|||
)
|
||||
|
||||
_LOGGER.debug("Response: %s", chat_response.parts)
|
||||
if not chat_response.parts:
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
self.history[conversation_id] = chat.history
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(chat_response.text)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
|
|
|
@ -95,29 +95,59 @@ async def test_default_prompt(
|
|||
suggested_area="Test Area 2",
|
||||
)
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_model.return_value.start_chat.return_value = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
chat_response.parts = ["Hi there!"]
|
||||
chat_response.text = "Hi there!"
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
||||
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||
|
||||
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
"""Test that client errors are caught."""
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
mock_chat.send_message_async.side_effect = ClientError("")
|
||||
mock_chat.send_message_async.side_effect = ClientError("some error")
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Sorry, I had a problem talking to Google Generative AI: None some error"
|
||||
)
|
||||
|
||||
|
||||
async def test_blocked_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test response was blocked."""
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
chat_response.parts = []
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
|
||||
)
|
||||
|
||||
|
||||
async def test_template_error(
|
||||
|
|
Loading…
Reference in New Issue