Ollama implement CONTROL supported feature (#123049)

pull/123066/head
Paulus Schoutsen 2024-08-02 12:31:31 +02:00 committed by GitHub
parent ad26db7dc8
commit 4a06e20318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 1 deletions

View File

@ -106,6 +106,10 @@ class OllamaConversationEntity(
self._history: dict[str, MessageHistory] = {}
self._attr_name = entry.title
self._attr_unique_id = entry.entry_id
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
@ -114,6 +118,9 @@ class OllamaConversationEntity(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
conversation.async_set_agent(self.hass, self.entry, self)
self.entry.async_on_unload(
self.entry.add_update_listener(self._async_entry_update_listener)
)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
@ -334,3 +341,14 @@ class OllamaConversationEntity(
message_history.messages = [
message_history.messages[0]
] + message_history.messages[drop_index:]
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:
"""Handle options update."""
if entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
else:
self._attr_supported_features = conversation.ConversationEntityFeature(0)

View File

@ -10,7 +10,7 @@ import voluptuous as vol
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
@ -554,3 +554,26 @@ async def test_conversation_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title")
assert state
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
async def test_conversation_agent_with_assist(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test OllamaConversationEntity."""
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry_with_assist.entry_id
)
assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title")
assert state
assert (
state.attributes[ATTR_SUPPORTED_FEATURES]
== conversation.ConversationEntityFeature.CONTROL
)