Add AI Task prefs
parent
7aa002fdc0
commit
ad3852a4f3
|
@ -3,12 +3,12 @@
|
|||
import logging
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, storage
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
from .const import DATA_COMPONENT, DOMAIN
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .task import GenTextTask, GenTextTaskResult, GenTextTaskType, async_generate_text
|
||||
|
@ -34,6 +34,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
"""Register the process service."""
|
||||
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
|
||||
hass.data[DATA_COMPONENT] = entity_component
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_conversation_http(hass)
|
||||
return True
|
||||
|
||||
|
@ -46,3 +48,61 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||
|
||||
|
||||
class AITaskPreferences:
|
||||
"""AI Task preferences."""
|
||||
|
||||
gen_text_summary_entity_id: str | None = None
|
||||
gen_text_generate_entity_id: str | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the preferences."""
|
||||
self._store: storage.Store[dict[str, str | None]] = storage.Store(
|
||||
hass, 1, DOMAIN
|
||||
)
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the data from the store."""
|
||||
data = await self._store.async_load()
|
||||
if data is None:
|
||||
return
|
||||
self.gen_text_summary_entity_id = data.get("gen_text_summary_entity_id")
|
||||
self.gen_text_generate_entity_id = data.get("gen_text_generate_entity_id")
|
||||
|
||||
@callback
|
||||
def async_set_preferences(
|
||||
self,
|
||||
*,
|
||||
gen_text_summary_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
gen_text_generate_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Set the preferences."""
|
||||
changed = False
|
||||
for key, value in (
|
||||
("gen_text_summary_entity_id", gen_text_summary_entity_id),
|
||||
("gen_text_generate_entity_id", gen_text_generate_entity_id),
|
||||
):
|
||||
if value is not UNDEFINED:
|
||||
if getattr(self, key) != value:
|
||||
setattr(self, key, value)
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
return
|
||||
|
||||
self._store.async_delay_save(
|
||||
lambda: {
|
||||
"gen_text_summary_entity_id": self.gen_text_summary_entity_id,
|
||||
"gen_text_generate_entity_id": self.gen_text_generate_entity_id,
|
||||
},
|
||||
10,
|
||||
)
|
||||
|
||||
@callback
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Get the current preferences."""
|
||||
return {
|
||||
"gen_text_summary_entity_id": self.gen_text_summary_entity_id,
|
||||
"gen_text_generate_entity_id": self.gen_text_generate_entity_id,
|
||||
}
|
||||
|
|
|
@ -10,10 +10,12 @@ from homeassistant.util.hass_dict import HassKey
|
|||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from . import AITaskPreferences
|
||||
from .entity import AITaskEntity
|
||||
|
||||
DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a Home Assistant expert and help users with their tasks."
|
||||
|
|
|
@ -7,6 +7,7 @@ import voluptuous as vol
|
|||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .const import DATA_PREFERENCES
|
||||
from .task import GenTextTaskType, async_generate_text
|
||||
|
||||
|
||||
|
@ -14,13 +15,15 @@ from .task import GenTextTaskType, async_generate_text
|
|||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the HTTP API for the conversation integration."""
|
||||
websocket_api.async_register_command(hass, websocket_generate_text)
|
||||
websocket_api.async_register_command(hass, websocket_get_preferences)
|
||||
websocket_api.async_register_command(hass, websocket_set_preferences)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/generate_text",
|
||||
vol.Required("task_name"): str,
|
||||
vol.Required("entity_id"): str,
|
||||
vol.Optional("entity_id"): str,
|
||||
vol.Required("task_type"): (lambda v: GenTextTaskType(v)), # pylint: disable=unnecessary-lambda
|
||||
vol.Required("instructions"): str,
|
||||
}
|
||||
|
@ -37,3 +40,41 @@ async def websocket_generate_text(
|
|||
msg_id = msg.pop("id")
|
||||
result = await async_generate_text(hass=hass, **msg)
|
||||
connection.send_result(msg_id, result.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/get",
|
||||
}
|
||||
)
|
||||
@callback
|
||||
def websocket_get_preferences(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Get AI task preferences."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
connection.send_result(msg["id"], preferences.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/set",
|
||||
vol.Optional("gen_text_summary_entity_id"): vol.Any(str, None),
|
||||
vol.Optional("gen_text_generate_entity_id"): vol.Any(str, None),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@callback
|
||||
def websocket_set_preferences(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set AI task preferences."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
msg.pop("type")
|
||||
msg_id = msg.pop("id")
|
||||
preferences.async_set_preferences(**msg)
|
||||
connection.send_result(msg_id, preferences.as_dict())
|
||||
|
|
|
@ -6,18 +6,30 @@ from dataclasses import dataclass
|
|||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DATA_COMPONENT, GenTextTaskType
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, GenTextTaskType
|
||||
|
||||
|
||||
async def async_generate_text(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str,
|
||||
entity_id: str | None = None,
|
||||
task_type: GenTextTaskType,
|
||||
instructions: str,
|
||||
) -> GenTextTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
if task_type == GenTextTaskType.SUMMARY:
|
||||
entity_id = preferences.gen_text_summary_entity_id
|
||||
elif task_type == GenTextTaskType.GENERATE:
|
||||
entity_id = preferences.gen_text_generate_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise ValueError(
|
||||
"No entity_id provided and no preferred entity set for this task type"
|
||||
)
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise ValueError(f"AI Task entity {entity_id} not found")
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
"""Test the HTTP API for AI Task integration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.ai_task import DATA_PREFERENCES, GenTextTaskType
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
@ -8,10 +11,14 @@ from .conftest import TEST_ENTITY_ID
|
|||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
|
||||
)
|
||||
async def test_ws_generate_text(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
task_type: GenTextTaskType,
|
||||
) -> None:
|
||||
"""Test running a task via the WebSocket API."""
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
|
@ -25,7 +32,7 @@ async def test_ws_generate_text(
|
|||
"type": "ai_task/generate_text",
|
||||
"task_name": "Test Task",
|
||||
"entity_id": TEST_ENTITY_ID,
|
||||
"task_type": "summary",
|
||||
"task_type": task_type.value,
|
||||
"instructions": "Test prompt",
|
||||
}
|
||||
)
|
||||
|
@ -37,3 +44,129 @@ async def test_ws_generate_text(
|
|||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity.state != STATE_UNKNOWN
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
|
||||
)
|
||||
async def test_ws_run_task_preferred_entity(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
task_type: GenTextTaskType,
|
||||
) -> None:
|
||||
"""Test running a task via the WebSocket API."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences_key = f"gen_text_{task_type.value}_entity_id"
|
||||
preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID})
|
||||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity is not None
|
||||
assert entity.state == STATE_UNKNOWN
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/generate_text",
|
||||
"task_name": "Test Task",
|
||||
"task_type": task_type.value,
|
||||
"instructions": "Test prompt",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"]["result"] == "Mock result"
|
||||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity.state != STATE_UNKNOWN
|
||||
|
||||
|
||||
async def test_ws_preferences(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test preferences via the WebSocket API."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
# Get initial preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": None,
|
||||
"gen_text_generate_entity_id": None,
|
||||
}
|
||||
|
||||
# Set preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_summary_entity_id": "ai_task.summary_1",
|
||||
"gen_text_generate_entity_id": "ai_task.generate_1",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_1",
|
||||
"gen_text_generate_entity_id": "ai_task.generate_1",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_1",
|
||||
"gen_text_generate_entity_id": "ai_task.generate_1",
|
||||
}
|
||||
|
||||
# Set only one preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_summary_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_2",
|
||||
"gen_text_generate_entity_id": "ai_task.generate_1",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_2",
|
||||
"gen_text_generate_entity_id": "ai_task.generate_1",
|
||||
}
|
||||
|
||||
# Clear a preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_generate_entity_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_2",
|
||||
"gen_text_generate_entity_id": None,
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_summary_entity_id": "ai_task.summary_2",
|
||||
"gen_text_generate_entity_id": None,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
|
||||
async def test_preferences_storage_load(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test that AITaskPreferences are stored and loaded correctly."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
|
||||
# Initial state should be None for entity IDs
|
||||
assert preferences.gen_text_summary_entity_id is None
|
||||
assert preferences.gen_text_generate_entity_id is None
|
||||
|
||||
summary_id_1 = "sensor.summary_one"
|
||||
generate_id_1 = "sensor.generate_one"
|
||||
|
||||
preferences.async_set_preferences(
|
||||
gen_text_summary_entity_id=summary_id_1,
|
||||
gen_text_generate_entity_id=generate_id_1,
|
||||
)
|
||||
|
||||
# Verify that current preferences object is updated
|
||||
assert preferences.gen_text_summary_entity_id == summary_id_1
|
||||
assert preferences.gen_text_generate_entity_id == generate_id_1
|
||||
|
||||
await flush_store(preferences._store)
|
||||
|
||||
# Create a new preferences instance to test loading from store
|
||||
new_preferences_instance = AITaskPreferences(hass)
|
||||
await new_preferences_instance.async_load()
|
||||
|
||||
assert new_preferences_instance.gen_text_summary_entity_id == summary_id_1
|
||||
assert new_preferences_instance.gen_text_generate_entity_id == generate_id_1
|
||||
|
||||
# Test updating one preference and setting another to None
|
||||
summary_id_2 = "sensor.summary_two"
|
||||
preferences.async_set_preferences(
|
||||
gen_text_summary_entity_id=summary_id_2, gen_text_generate_entity_id=None
|
||||
)
|
||||
|
||||
# Verify that current preferences object is updated
|
||||
assert preferences.gen_text_summary_entity_id == summary_id_2
|
||||
assert preferences.gen_text_generate_entity_id is None
|
||||
|
||||
await flush_store(preferences._store)
|
||||
|
||||
# Create another new preferences instance to confirm persistence of the update
|
||||
another_new_preferences_instance = AITaskPreferences(hass)
|
||||
await another_new_preferences_instance.async_load()
|
||||
|
||||
assert another_new_preferences_instance.gen_text_summary_entity_id == summary_id_2
|
||||
assert another_new_preferences_instance.gen_text_generate_entity_id is None
|
|
@ -4,14 +4,69 @@ from freezegun import freeze_time
|
|||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.ai_task import GenTextTaskType, async_generate_text
|
||||
from homeassistant.components.ai_task import (
|
||||
DATA_PREFERENCES,
|
||||
GenTextTaskType,
|
||||
async_generate_text,
|
||||
)
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import chat_session
|
||||
|
||||
from .conftest import TEST_ENTITY_ID
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE]
|
||||
)
|
||||
async def test_run_task_preferred_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
task_type: GenTextTaskType,
|
||||
) -> None:
|
||||
"""Test running a task with an unknown entity."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences_key = f"gen_text_{task_type.value}_entity_id"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="No entity_id provided and no preferred entity set for this task type",
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
task_type=task_type,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
preferences.async_set_preferences(**{preferences_key: "ai_task.unknown"})
|
||||
|
||||
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
task_type=task_type,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID})
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
task_type=task_type,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.result == "Mock result"
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
|
||||
|
||||
async def test_run_task_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
|
|
Loading…
Reference in New Issue