Add AI Task prefs

pull/146734/head
Paulus Schoutsen 2025-06-13 09:15:20 -04:00
parent 7aa002fdc0
commit ad3852a4f3
7 changed files with 374 additions and 9 deletions

View File

@ -3,12 +3,12 @@
import logging
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, storage
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import DATA_COMPONENT, DOMAIN
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN
from .entity import AITaskEntity
from .http import async_setup as async_setup_conversation_http
from .task import GenTextTask, GenTextTaskResult, GenTextTaskType, async_generate_text
@ -34,6 +34,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
await hass.data[DATA_PREFERENCES].async_load()
async_setup_conversation_http(hass)
return True
@ -46,3 +48,61 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
class AITaskPreferences:
"""AI Task preferences."""
gen_text_summary_entity_id: str | None = None
gen_text_generate_entity_id: str | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the preferences."""
self._store: storage.Store[dict[str, str | None]] = storage.Store(
hass, 1, DOMAIN
)
async def async_load(self) -> None:
"""Load the data from the store."""
data = await self._store.async_load()
if data is None:
return
self.gen_text_summary_entity_id = data.get("gen_text_summary_entity_id")
self.gen_text_generate_entity_id = data.get("gen_text_generate_entity_id")
@callback
def async_set_preferences(
self,
*,
gen_text_summary_entity_id: str | None | UndefinedType = UNDEFINED,
gen_text_generate_entity_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Set the preferences."""
changed = False
for key, value in (
("gen_text_summary_entity_id", gen_text_summary_entity_id),
("gen_text_generate_entity_id", gen_text_generate_entity_id),
):
if value is not UNDEFINED:
if getattr(self, key) != value:
setattr(self, key, value)
changed = True
if not changed:
return
self._store.async_delay_save(
lambda: {
"gen_text_summary_entity_id": self.gen_text_summary_entity_id,
"gen_text_generate_entity_id": self.gen_text_generate_entity_id,
},
10,
)
@callback
def as_dict(self) -> dict[str, str | None]:
"""Get the current preferences."""
return {
"gen_text_summary_entity_id": self.gen_text_summary_entity_id,
"gen_text_generate_entity_id": self.gen_text_generate_entity_id,
}

View File

@ -10,10 +10,12 @@ from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent
from . import AITaskPreferences
from .entity import AITaskEntity
DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks."

View File

@ -7,6 +7,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from .const import DATA_PREFERENCES
from .task import GenTextTaskType, async_generate_text
@ -14,13 +15,15 @@ from .task import GenTextTaskType, async_generate_text
def async_setup(hass: HomeAssistant) -> None:
"""Set up the HTTP API for the conversation integration."""
websocket_api.async_register_command(hass, websocket_generate_text)
websocket_api.async_register_command(hass, websocket_get_preferences)
websocket_api.async_register_command(hass, websocket_set_preferences)
@websocket_api.websocket_command(
{
vol.Required("type"): "ai_task/generate_text",
vol.Required("task_name"): str,
vol.Required("entity_id"): str,
vol.Optional("entity_id"): str,
vol.Required("task_type"): (lambda v: GenTextTaskType(v)), # pylint: disable=unnecessary-lambda
vol.Required("instructions"): str,
}
@ -37,3 +40,41 @@ async def websocket_generate_text(
msg_id = msg.pop("id")
result = await async_generate_text(hass=hass, **msg)
connection.send_result(msg_id, result.as_dict())
@websocket_api.websocket_command(
{
vol.Required("type"): "ai_task/preferences/get",
}
)
@callback
def websocket_get_preferences(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get AI task preferences."""
preferences = hass.data[DATA_PREFERENCES]
connection.send_result(msg["id"], preferences.as_dict())
@websocket_api.websocket_command(
{
vol.Required("type"): "ai_task/preferences/set",
vol.Optional("gen_text_summary_entity_id"): vol.Any(str, None),
vol.Optional("gen_text_generate_entity_id"): vol.Any(str, None),
}
)
@websocket_api.require_admin
@callback
def websocket_set_preferences(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set AI task preferences."""
preferences = hass.data[DATA_PREFERENCES]
msg.pop("type")
msg_id = msg.pop("id")
preferences.async_set_preferences(**msg)
connection.send_result(msg_id, preferences.as_dict())

View File

@ -6,18 +6,30 @@ from dataclasses import dataclass
from homeassistant.core import HomeAssistant
from .const import DATA_COMPONENT, GenTextTaskType
from .const import DATA_COMPONENT, DATA_PREFERENCES, GenTextTaskType
async def async_generate_text(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str,
entity_id: str | None = None,
task_type: GenTextTaskType,
instructions: str,
) -> GenTextTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
preferences = hass.data[DATA_PREFERENCES]
if task_type == GenTextTaskType.SUMMARY:
entity_id = preferences.gen_text_summary_entity_id
elif task_type == GenTextTaskType.GENERATE:
entity_id = preferences.gen_text_generate_entity_id
if entity_id is None:
raise ValueError(
"No entity_id provided and no preferred entity set for this task type"
)
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise ValueError(f"AI Task entity {entity_id} not found")

View File

@ -1,5 +1,8 @@
"""Test the HTTP API for AI Task integration."""
import pytest
from homeassistant.components.ai_task import DATA_PREFERENCES, GenTextTaskType
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
@ -8,10 +11,14 @@ from .conftest import TEST_ENTITY_ID
from tests.typing import WebSocketGenerator
@pytest.mark.parametrize(
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
)
async def test_ws_generate_text(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components: None,
task_type: GenTextTaskType,
) -> None:
"""Test running a task via the WebSocket API."""
entity = hass.states.get(TEST_ENTITY_ID)
@ -25,7 +32,7 @@ async def test_ws_generate_text(
"type": "ai_task/generate_text",
"task_name": "Test Task",
"entity_id": TEST_ENTITY_ID,
"task_type": "summary",
"task_type": task_type.value,
"instructions": "Test prompt",
}
)
@ -37,3 +44,129 @@ async def test_ws_generate_text(
entity = hass.states.get(TEST_ENTITY_ID)
assert entity.state != STATE_UNKNOWN
@pytest.mark.parametrize(
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
)
async def test_ws_run_task_preferred_entity(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components: None,
task_type: GenTextTaskType,
) -> None:
"""Test running a task via the WebSocket API."""
preferences = hass.data[DATA_PREFERENCES]
preferences_key = f"gen_text_{task_type.value}_entity_id"
preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID})
entity = hass.states.get(TEST_ENTITY_ID)
assert entity is not None
assert entity.state == STATE_UNKNOWN
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "ai_task/generate_text",
"task_name": "Test Task",
"task_type": task_type.value,
"instructions": "Test prompt",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"]["result"] == "Mock result"
entity = hass.states.get(TEST_ENTITY_ID)
assert entity.state != STATE_UNKNOWN
async def test_ws_preferences(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components: None,
) -> None:
"""Test preferences via the WebSocket API."""
client = await hass_ws_client(hass)
# Get initial preferences
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": None,
"gen_text_generate_entity_id": None,
}
# Set preferences
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_summary_entity_id": "ai_task.summary_1",
"gen_text_generate_entity_id": "ai_task.generate_1",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_1",
"gen_text_generate_entity_id": "ai_task.generate_1",
}
# Get updated preferences
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_1",
"gen_text_generate_entity_id": "ai_task.generate_1",
}
# Set only one preference
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_summary_entity_id": "ai_task.summary_2",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_2",
"gen_text_generate_entity_id": "ai_task.generate_1",
}
# Get updated preferences
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_2",
"gen_text_generate_entity_id": "ai_task.generate_1",
}
# Clear a preference
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_generate_entity_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_2",
"gen_text_generate_entity_id": None,
}
# Get updated preferences
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_summary_entity_id": "ai_task.summary_2",
"gen_text_generate_entity_id": None,
}

View File

@ -0,0 +1,62 @@
"""Test initialization of the AI Task component."""
from freezegun.api import FrozenDateTimeFactory
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant
from tests.common import flush_store
async def test_preferences_storage_load(
hass: HomeAssistant,
init_components: None,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test that AITaskPreferences are stored and loaded correctly."""
preferences = hass.data[DATA_PREFERENCES]
# Initial state should be None for entity IDs
assert preferences.gen_text_summary_entity_id is None
assert preferences.gen_text_generate_entity_id is None
summary_id_1 = "sensor.summary_one"
generate_id_1 = "sensor.generate_one"
preferences.async_set_preferences(
gen_text_summary_entity_id=summary_id_1,
gen_text_generate_entity_id=generate_id_1,
)
# Verify that current preferences object is updated
assert preferences.gen_text_summary_entity_id == summary_id_1
assert preferences.gen_text_generate_entity_id == generate_id_1
await flush_store(preferences._store)
# Create a new preferences instance to test loading from store
new_preferences_instance = AITaskPreferences(hass)
await new_preferences_instance.async_load()
assert new_preferences_instance.gen_text_summary_entity_id == summary_id_1
assert new_preferences_instance.gen_text_generate_entity_id == generate_id_1
# Test updating one preference and setting another to None
summary_id_2 = "sensor.summary_two"
preferences.async_set_preferences(
gen_text_summary_entity_id=summary_id_2, gen_text_generate_entity_id=None
)
# Verify that current preferences object is updated
assert preferences.gen_text_summary_entity_id == summary_id_2
assert preferences.gen_text_generate_entity_id is None
await flush_store(preferences._store)
# Create another new preferences instance to confirm persistence of the update
another_new_preferences_instance = AITaskPreferences(hass)
await another_new_preferences_instance.async_load()
assert another_new_preferences_instance.gen_text_summary_entity_id == summary_id_2
assert another_new_preferences_instance.gen_text_generate_entity_id is None

View File

@ -4,14 +4,69 @@ from freezegun import freeze_time
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.ai_task import GenTextTaskType, async_generate_text
from homeassistant.components.ai_task import (
DATA_PREFERENCES,
GenTextTaskType,
async_generate_text,
)
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session
from .conftest import TEST_ENTITY_ID
@pytest.mark.parametrize(
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
)
async def test_run_task_preferred_entity(
hass: HomeAssistant,
init_components: None,
task_type: GenTextTaskType,
) -> None:
"""Test running a task with an unknown entity."""
preferences = hass.data[DATA_PREFERENCES]
preferences_key = f"gen_text_{task_type.value}_entity_id"
with pytest.raises(
ValueError,
match="No entity_id provided and no preferred entity set for this task type",
):
await async_generate_text(
hass,
task_name="Test Task",
task_type=task_type,
instructions="Test prompt",
)
preferences.async_set_preferences(**{preferences_key: "ai_task.unknown"})
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
await async_generate_text(
hass,
task_name="Test Task",
task_type=task_type,
instructions="Test prompt",
)
preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID})
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_text(
hass,
task_name="Test Task",
task_type=task_type,
instructions="Test prompt",
)
assert result.result == "Mock result"
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state != STATE_UNKNOWN
async def test_run_task_unknown_entity(
hass: HomeAssistant,
init_components: None,