Clean up conversation agent attribution (#96883)
* Clean up conversation agent attribution * Clean up google_generative_ai_conversation as wellpull/96890/head
parent
22d0f4ff0a
commit
f2bd122fde
|
@ -195,7 +195,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
hass.http.register_view(ConversationProcessView())
|
hass.http.register_view(ConversationProcessView())
|
||||||
websocket_api.async_register_command(hass, websocket_process)
|
websocket_api.async_register_command(hass, websocket_process)
|
||||||
websocket_api.async_register_command(hass, websocket_prepare)
|
websocket_api.async_register_command(hass, websocket_prepare)
|
||||||
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
|
||||||
websocket_api.async_register_command(hass, websocket_list_agents)
|
websocket_api.async_register_command(hass, websocket_list_agents)
|
||||||
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
|
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
|
||||||
|
|
||||||
|
@ -249,29 +248,6 @@ async def websocket_prepare(
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
|
||||||
{
|
|
||||||
vol.Required("type"): "conversation/agent/info",
|
|
||||||
vol.Optional("agent_id"): agent_id_validator,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@websocket_api.async_response
|
|
||||||
async def websocket_get_agent_info(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
connection: websocket_api.ActiveConnection,
|
|
||||||
msg: dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
"""Info about the agent in use."""
|
|
||||||
agent = await _get_agent_manager(hass).async_get_agent(msg.get("agent_id"))
|
|
||||||
|
|
||||||
connection.send_result(
|
|
||||||
msg["id"],
|
|
||||||
{
|
|
||||||
"attribution": agent.attribution,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "conversation/agent/list",
|
vol.Required("type"): "conversation/agent/list",
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal, TypedDict
|
from typing import Any, Literal
|
||||||
|
|
||||||
from homeassistant.core import Context
|
from homeassistant.core import Context
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
|
@ -35,21 +35,9 @@ class ConversationResult:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Attribution(TypedDict):
|
|
||||||
"""Attribution for a conversation agent."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractConversationAgent(ABC):
|
class AbstractConversationAgent(ABC):
|
||||||
"""Abstract conversation agent."""
|
"""Abstract conversation agent."""
|
||||||
|
|
||||||
@property
|
|
||||||
def attribution(self) -> Attribution | None:
|
|
||||||
"""Return the attribution."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
|
|
|
@ -128,14 +128,6 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||||
self.session: OAuth2Session | None = None
|
self.session: OAuth2Session | None = None
|
||||||
self.language: str | None = None
|
self.language: str | None = None
|
||||||
|
|
||||||
@property
|
|
||||||
def attribution(self):
|
|
||||||
"""Return the attribution."""
|
|
||||||
return {
|
|
||||||
"name": "Powered by Google Assistant SDK",
|
|
||||||
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
|
|
@ -69,14 +69,6 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, list[dict]] = {}
|
self.history: dict[str, list[dict]] = {}
|
||||||
|
|
||||||
@property
|
|
||||||
def attribution(self):
|
|
||||||
"""Return the attribution."""
|
|
||||||
return {
|
|
||||||
"name": "Powered by Google Generative AI",
|
|
||||||
"url": "https://developers.generativeai.google/",
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
|
|
@ -66,11 +66,6 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, list[dict]] = {}
|
self.history: dict[str, list[dict]] = {}
|
||||||
|
|
||||||
@property
|
|
||||||
def attribution(self):
|
|
||||||
"""Return the attribution."""
|
|
||||||
return {"name": "Powered by OpenAI", "url": "https://www.openai.com"}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
|
|
@ -24,11 +24,6 @@ class MockAgent(conversation.AbstractConversationAgent):
|
||||||
self.response = "Test response"
|
self.response = "Test response"
|
||||||
self._supported_languages = supported_languages
|
self._supported_languages = supported_languages
|
||||||
|
|
||||||
@property
|
|
||||||
def attribution(self) -> conversation.Attribution | None:
|
|
||||||
"""Return the attribution."""
|
|
||||||
return {"name": "Mock assistant", "url": "https://assist.me"}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
|
|
@ -1611,43 +1611,6 @@ async def test_get_agent_info(
|
||||||
assert agent_info == snapshot
|
assert agent_info == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_get_agent_info(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
init_components,
|
|
||||||
mock_agent,
|
|
||||||
hass_ws_client: WebSocketGenerator,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test get agent info."""
|
|
||||||
client = await hass_ws_client(hass)
|
|
||||||
|
|
||||||
await client.send_json_auto_id({"type": "conversation/agent/info"})
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert msg["result"] == snapshot
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{"type": "conversation/agent/info", "agent_id": "homeassistant"}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert msg["result"] == snapshot
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{"type": "conversation/agent/info", "agent_id": mock_agent.agent_id}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert msg["result"] == snapshot
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{"type": "conversation/agent/info", "agent_id": "not_exist"}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert not msg["success"]
|
|
||||||
assert msg["error"] == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_hass_agent_debug(
|
async def test_ws_hass_agent_debug(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components,
|
init_components,
|
||||||
|
|
|
@ -326,7 +326,6 @@ async def test_conversation_agent(
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
|
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||||
assert agent.attribution.keys() == {"name", "url"}
|
|
||||||
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
|
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
|
||||||
|
|
||||||
text1 = "tell me a joke"
|
text1 = "tell me a joke"
|
||||||
|
|
Loading…
Reference in New Issue