From 0cf574dc42f44d1ad980b585bad1b5dffa7bf2f1 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 28 May 2024 21:21:28 -0400 Subject: [PATCH] Update the recommended model for Google Gen AI (#118323) --- .../google_generative_ai_conversation/const.py | 2 +- .../snapshots/test_conversation.ambr | 12 ++++++------ .../snapshots/test_diagnostics.ambr | 2 +- .../test_config_flow.py | 8 +++++++- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index bd60e8d94c1..94e974d379d 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -8,7 +8,7 @@ CONF_PROMPT = "prompt" CONF_RECOMMENDED = "recommended" CONF_CHAT_MODEL = "chat_model" -RECOMMENDED_CHAT_MODEL = "models/gemini-1.5-flash-latest" +RECOMMENDED_CHAT_MODEL = "models/gemini-1.5-pro-latest" CONF_TEMPERATURE = "temperature" RECOMMENDED_TEMPERATURE = 1.0 CONF_TOP_P = "top_p" diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 40ff556af1c..9c108371bee 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -12,7 +12,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', @@ -64,7 +64,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', @@ -128,7 +128,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', @@ -184,7 +184,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', @@ -240,7 +240,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', @@ -296,7 +296,7 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.5-pro-latest', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_diagnostics.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_diagnostics.ambr index 316bf74b72a..ca18b0ad25c 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_diagnostics.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_diagnostics.ambr @@ -5,7 +5,7 @@ 'api_key': '**REDACTED**', }), 'options': dict({ - 'chat_model': 'models/gemini-1.5-flash-latest', + 'chat_model': 'models/gemini-1.5-pro-latest', 'dangerous_block_threshold': 'BLOCK_MEDIUM_AND_ABOVE', 'harassment_block_threshold': 'BLOCK_MEDIUM_AND_ABOVE', 'hate_block_threshold': 'BLOCK_MEDIUM_AND_ABOVE', diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 41b1dbeb32e..24ed06a408f 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -45,6 +45,12 @@ def mock_models(): ) model_15_flash.name = "models/gemini-1.5-flash-latest" + model_15_pro = Mock( + display_name="Gemini 1.5 Pro", + supported_generation_methods=["generateContent"], + ) + model_15_pro.name = "models/gemini-1.5-pro-latest" + model_10_pro = Mock( display_name="Gemini 1.0 Pro", supported_generation_methods=["generateContent"], @@ -52,7 +58,7 @@ def mock_models(): model_10_pro.name = "models/gemini-pro" with patch( "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", - return_value=iter([model_15_flash, model_10_pro]), + return_value=iter([model_15_flash, model_15_pro, model_10_pro]), ): yield