Tweak Assist LLM API prompt (#118343)

pull/118339/head^2
Paulus Schoutsen 2024-05-28 22:43:22 -04:00 committed by GitHub
parent d223e1f2ac
commit c097a05ed4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 26 deletions

View File

@ -227,12 +227,13 @@ class AssistAPI(API):
return APIInstance(
api=self,
api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities),
api_prompt=self._async_get_api_prompt(tool_context, exposed_entities),
tool_context=tool_context,
tools=self._async_get_tools(tool_context, exposed_entities),
)
async def _async_get_api_prompt(
@callback
def _async_get_api_prompt(
self, tool_context: ToolContext, exposed_entities: dict | None
) -> str:
"""Return the prompt for the API."""
@ -269,15 +270,10 @@ class AssistAPI(API):
prompt.append(f"You are in area {area.name} {extra}")
else:
prompt.append(
"Reject all generic commands like 'turn on the lights' because we "
"don't know in what area this conversation is happening."
"When a user asks to turn on all devices of a specific type, "
"ask user to specify an area."
)
if tool_context.context and tool_context.context.user_id:
user = await self.hass.auth.async_get_user(tool_context.context.user_id)
if user:
prompt.append(f"The user name is {user.name}.")
if not tool_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id
):

View File

@ -1,6 +1,6 @@
"""Tests for the llm helpers."""
from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest
import voluptuous as vol
@ -430,8 +430,8 @@ async def test_assist_api_prompt(
no_timer_prompt = "This device does not support timers."
area_prompt = (
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
"this conversation is happening."
"When a user asks to turn on all devices of a specific type, "
"ask user to specify an area."
)
api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
@ -478,19 +478,5 @@ async def test_assist_api_prompt(
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
{exposed_entities_prompt}"""
)
# Add user
context.user_id = "12345"
mock_user = Mock()
mock_user.id = "12345"
mock_user.name = "Test User"
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
The user name is Test User.
{exposed_entities_prompt}"""
)