Address late feedback Google LLM (#117873)

pull/117878/head
Paulus Schoutsen 2024-05-21 14:11:18 -04:00 committed by GitHub
parent 2a9b31261c
commit f21226dd0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 16 deletions

View File

@ -19,7 +19,10 @@ from .singleton import singleton
LLM_API_ASSIST = "assist"
PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant."
PROMPT_NO_API_CONFIGURED = (
"If the user wants to control a device, tell them to edit the AI configuration and "
"allow access to Home Assistant."
)
@singleton("llm")

View File

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_default_prompt[False-None]
# name: test_default_prompt[config_entry_options0-None]
list([
tuple(
'',
@ -58,7 +58,7 @@
),
])
# ---
# name: test_default_prompt[False-conversation.google_generative_ai_conversation]
# name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation]
list([
tuple(
'',
@ -117,7 +117,7 @@
),
])
# ---
# name: test_default_prompt[True-None]
# name: test_default_prompt[config_entry_options1-None]
list([
tuple(
'',
@ -176,7 +176,7 @@
),
])
# ---
# name: test_default_prompt[True-conversation.google_generative_ai_conversation]
# name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation]
list([
tuple(
'',

View File

@ -24,7 +24,13 @@ from tests.common import MockConfigEntry
@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@pytest.mark.parametrize("allow_hass_access", [False, True])
@pytest.mark.parametrize(
"config_entry_options",
[
{},
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
],
)
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@ -33,7 +39,7 @@ async def test_default_prompt(
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
agent_id: str | None,
allow_hass_access: bool,
config_entry_options: {},
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
@ -44,14 +50,10 @@ async def test_default_prompt(
if agent_id is None:
agent_id = mock_config_entry.entry_id
if allow_hass_access:
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, **config_entry_options},
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
@ -145,7 +147,7 @@ async def test_default_prompt(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == allow_hass_access
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
@patch(