diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 77b319c071f..65753718eec 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -3,12 +3,12 @@ import logging from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_validation as cv +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import config_validation as cv, storage from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType -from .const import DATA_COMPONENT, DOMAIN +from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN from .entity import AITaskEntity from .http import async_setup as async_setup_conversation_http from .task import GenTextTask, GenTextTaskResult, GenTextTaskType, async_generate_text @@ -34,6 +34,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) hass.data[DATA_COMPONENT] = entity_component + hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) + await hass.data[DATA_PREFERENCES].async_load() async_setup_conversation_http(hass) return True @@ -46,3 +48,61 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" return await hass.data[DATA_COMPONENT].async_unload_entry(entry) + + +class AITaskPreferences: + """AI Task preferences.""" + + gen_text_summary_entity_id: str | None = None + gen_text_generate_entity_id: str | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the preferences.""" + self._store: storage.Store[dict[str, str | None]] = storage.Store( + hass, 1, DOMAIN + ) + + async def async_load(self) -> None: + """Load the data from the store.""" + data = await self._store.async_load() + if data is None: + return + self.gen_text_summary_entity_id = data.get("gen_text_summary_entity_id") + self.gen_text_generate_entity_id = data.get("gen_text_generate_entity_id") + + @callback + def async_set_preferences( + self, + *, + gen_text_summary_entity_id: str | None | UndefinedType = UNDEFINED, + gen_text_generate_entity_id: str | None | UndefinedType = UNDEFINED, + ) -> None: + """Set the preferences.""" + changed = False + for key, value in ( + ("gen_text_summary_entity_id", gen_text_summary_entity_id), + ("gen_text_generate_entity_id", gen_text_generate_entity_id), + ): + if value is not UNDEFINED: + if getattr(self, key) != value: + setattr(self, key, value) + changed = True + + if not changed: + return + + self._store.async_delay_save( + lambda: { + "gen_text_summary_entity_id": self.gen_text_summary_entity_id, + "gen_text_generate_entity_id": self.gen_text_generate_entity_id, + }, + 10, + ) + + @callback + def as_dict(self) -> dict[str, str | None]: + """Get the current preferences.""" + return { + "gen_text_summary_entity_id": self.gen_text_summary_entity_id, + "gen_text_generate_entity_id": self.gen_text_generate_entity_id, + } diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index 9d580ab39f5..ffd22755aac 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -10,10 +10,12 @@ from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: from homeassistant.helpers.entity_component import EntityComponent + from . import AITaskPreferences from .entity import AITaskEntity DOMAIN = "ai_task" DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) +DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") DEFAULT_SYSTEM_PROMPT = ( "You are a Home Assistant expert and help users with their tasks." diff --git a/homeassistant/components/ai_task/http.py b/homeassistant/components/ai_task/http.py index c79694194aa..9c024a5ac91 100644 --- a/homeassistant/components/ai_task/http.py +++ b/homeassistant/components/ai_task/http.py @@ -7,6 +7,7 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback +from .const import DATA_PREFERENCES from .task import GenTextTaskType, async_generate_text @@ -14,13 +15,15 @@ from .task import GenTextTaskType, async_generate_text def async_setup(hass: HomeAssistant) -> None: """Set up the HTTP API for the conversation integration.""" websocket_api.async_register_command(hass, websocket_generate_text) + websocket_api.async_register_command(hass, websocket_get_preferences) + websocket_api.async_register_command(hass, websocket_set_preferences) @websocket_api.websocket_command( { vol.Required("type"): "ai_task/generate_text", vol.Required("task_name"): str, - vol.Required("entity_id"): str, + vol.Optional("entity_id"): str, vol.Required("task_type"): (lambda v: GenTextTaskType(v)), # pylint: disable=unnecessary-lambda vol.Required("instructions"): str, } @@ -37,3 +40,41 @@ async def websocket_generate_text( msg_id = msg.pop("id") result = await async_generate_text(hass=hass, **msg) connection.send_result(msg_id, result.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/get", + } +) +@callback +def websocket_get_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + connection.send_result(msg["id"], preferences.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/set", + vol.Optional("gen_text_summary_entity_id"): vol.Any(str, None), + vol.Optional("gen_text_generate_entity_id"): vol.Any(str, None), + } +) +@websocket_api.require_admin +@callback +def websocket_set_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Set AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + msg.pop("type") + msg_id = msg.pop("id") + preferences.async_set_preferences(**msg) + connection.send_result(msg_id, preferences.as_dict()) diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 2f6c901dce8..ec486aa483b 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -6,18 +6,30 @@ from dataclasses import dataclass from homeassistant.core import HomeAssistant -from .const import DATA_COMPONENT, GenTextTaskType +from .const import DATA_COMPONENT, DATA_PREFERENCES, GenTextTaskType async def async_generate_text( hass: HomeAssistant, *, task_name: str, - entity_id: str, + entity_id: str | None = None, task_type: GenTextTaskType, instructions: str, ) -> GenTextTaskResult: """Run a task in the AI Task integration.""" + if entity_id is None: + preferences = hass.data[DATA_PREFERENCES] + if task_type == GenTextTaskType.SUMMARY: + entity_id = preferences.gen_text_summary_entity_id + elif task_type == GenTextTaskType.GENERATE: + entity_id = preferences.gen_text_generate_entity_id + + if entity_id is None: + raise ValueError( + "No entity_id provided and no preferred entity set for this task type" + ) + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) if entity is None: raise ValueError(f"AI Task entity {entity_id} not found") diff --git a/tests/components/ai_task/test_http.py b/tests/components/ai_task/test_http.py index c2cb430a443..cb67787754a 100644 --- a/tests/components/ai_task/test_http.py +++ b/tests/components/ai_task/test_http.py @@ -1,5 +1,8 @@ """Test the HTTP API for AI Task integration.""" +import pytest + +from homeassistant.components.ai_task import DATA_PREFERENCES, GenTextTaskType from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant @@ -8,10 +11,14 @@ from .conftest import TEST_ENTITY_ID from tests.typing import WebSocketGenerator +@pytest.mark.parametrize( + "task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE] +) async def test_ws_generate_text( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components: None, + task_type: GenTextTaskType, ) -> None: """Test running a task via the WebSocket API.""" entity = hass.states.get(TEST_ENTITY_ID) @@ -25,7 +32,7 @@ async def test_ws_generate_text( "type": "ai_task/generate_text", "task_name": "Test Task", "entity_id": TEST_ENTITY_ID, - "task_type": "summary", + "task_type": task_type.value, "instructions": "Test prompt", } ) @@ -37,3 +44,129 @@ async def test_ws_generate_text( entity = hass.states.get(TEST_ENTITY_ID) assert entity.state != STATE_UNKNOWN + + +@pytest.mark.parametrize( + "task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE] +) +async def test_ws_run_task_preferred_entity( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components: None, + task_type: GenTextTaskType, +) -> None: + """Test running a task via the WebSocket API.""" + preferences = hass.data[DATA_PREFERENCES] + preferences_key = f"gen_text_{task_type.value}_entity_id" + preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID}) + + entity = hass.states.get(TEST_ENTITY_ID) + assert entity is not None + assert entity.state == STATE_UNKNOWN + + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "ai_task/generate_text", + "task_name": "Test Task", + "task_type": task_type.value, + "instructions": "Test prompt", + } + ) + + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"]["result"] == "Mock result" + + entity = hass.states.get(TEST_ENTITY_ID) + assert entity.state != STATE_UNKNOWN + + +async def test_ws_preferences( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components: None, +) -> None: + """Test preferences via the WebSocket API.""" + client = await hass_ws_client(hass) + + # Get initial preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": None, + "gen_text_generate_entity_id": None, + } + + # Set preferences + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_summary_entity_id": "ai_task.summary_1", + "gen_text_generate_entity_id": "ai_task.generate_1", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_1", + "gen_text_generate_entity_id": "ai_task.generate_1", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_1", + "gen_text_generate_entity_id": "ai_task.generate_1", + } + + # Set only one preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_summary_entity_id": "ai_task.summary_2", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_2", + "gen_text_generate_entity_id": "ai_task.generate_1", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_2", + "gen_text_generate_entity_id": "ai_task.generate_1", + } + + # Clear a preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_generate_entity_id": None, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_2", + "gen_text_generate_entity_id": None, + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_summary_entity_id": "ai_task.summary_2", + "gen_text_generate_entity_id": None, + } diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py new file mode 100644 index 00000000000..56a00d6e1e8 --- /dev/null +++ b/tests/components/ai_task/test_init.py @@ -0,0 +1,62 @@ +"""Test initialization of the AI Task component.""" + +from freezegun.api import FrozenDateTimeFactory + +from homeassistant.components.ai_task import AITaskPreferences +from homeassistant.components.ai_task.const import DATA_PREFERENCES +from homeassistant.core import HomeAssistant + +from tests.common import flush_store + + +async def test_preferences_storage_load( + hass: HomeAssistant, + init_components: None, + freezer: FrozenDateTimeFactory, +) -> None: + """Test that AITaskPreferences are stored and loaded correctly.""" + preferences = hass.data[DATA_PREFERENCES] + + # Initial state should be None for entity IDs + assert preferences.gen_text_summary_entity_id is None + assert preferences.gen_text_generate_entity_id is None + + summary_id_1 = "sensor.summary_one" + generate_id_1 = "sensor.generate_one" + + preferences.async_set_preferences( + gen_text_summary_entity_id=summary_id_1, + gen_text_generate_entity_id=generate_id_1, + ) + + # Verify that current preferences object is updated + assert preferences.gen_text_summary_entity_id == summary_id_1 + assert preferences.gen_text_generate_entity_id == generate_id_1 + + await flush_store(preferences._store) + + # Create a new preferences instance to test loading from store + new_preferences_instance = AITaskPreferences(hass) + await new_preferences_instance.async_load() + + assert new_preferences_instance.gen_text_summary_entity_id == summary_id_1 + assert new_preferences_instance.gen_text_generate_entity_id == generate_id_1 + + # Test updating one preference and setting another to None + summary_id_2 = "sensor.summary_two" + preferences.async_set_preferences( + gen_text_summary_entity_id=summary_id_2, gen_text_generate_entity_id=None + ) + + # Verify that current preferences object is updated + assert preferences.gen_text_summary_entity_id == summary_id_2 + assert preferences.gen_text_generate_entity_id is None + + await flush_store(preferences._store) + + # Create another new preferences instance to confirm persistence of the update + another_new_preferences_instance = AITaskPreferences(hass) + await another_new_preferences_instance.async_load() + + assert another_new_preferences_instance.gen_text_summary_entity_id == summary_id_2 + assert another_new_preferences_instance.gen_text_generate_entity_id is None diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index b1138e7459c..171fca41ff1 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -4,14 +4,69 @@ from freezegun import freeze_time import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components.ai_task import GenTextTaskType, async_generate_text +from homeassistant.components.ai_task import ( + DATA_PREFERENCES, + GenTextTaskType, + async_generate_text, +) from homeassistant.components.conversation import async_get_chat_log +from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant from homeassistant.helpers import chat_session from .conftest import TEST_ENTITY_ID +@pytest.mark.parametrize( + "task_type", [GenTextTaskType.SUMMARY, GenTextTaskType.GENERATE] +) +async def test_run_task_preferred_entity( + hass: HomeAssistant, + init_components: None, + task_type: GenTextTaskType, +) -> None: + """Test running a task with an unknown entity.""" + preferences = hass.data[DATA_PREFERENCES] + preferences_key = f"gen_text_{task_type.value}_entity_id" + + with pytest.raises( + ValueError, + match="No entity_id provided and no preferred entity set for this task type", + ): + await async_generate_text( + hass, + task_name="Test Task", + task_type=task_type, + instructions="Test prompt", + ) + + preferences.async_set_preferences(**{preferences_key: "ai_task.unknown"}) + + with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"): + await async_generate_text( + hass, + task_name="Test Task", + task_type=task_type, + instructions="Test prompt", + ) + + preferences.async_set_preferences(**{preferences_key: TEST_ENTITY_ID}) + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state == STATE_UNKNOWN + + result = await async_generate_text( + hass, + task_name="Test Task", + task_type=task_type, + instructions="Test prompt", + ) + assert result.result == "Mock result" + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state != STATE_UNKNOWN + + async def test_run_task_unknown_entity( hass: HomeAssistant, init_components: None,