Use model list to check anthropic API key (#139307)

Anthropic model list
pull/139311/head
Denis Shulyaka 2025-03-02 00:28:48 +03:00 committed by GitHub
parent 3588784f1e
commit 1786bb9903
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 20 additions and 31 deletions

View File

@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv
from .const import DOMAIN, LOGGER
from .const import CONF_CHAT_MODEL, DOMAIN, LOGGER, RECOMMENDED_CHAT_MODEL
PLATFORMS = (Platform.CONVERSATION,)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@ -26,12 +26,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: AnthropicConfigEntry) ->
partial(anthropic.AsyncAnthropic, api_key=entry.data[CONF_API_KEY])
)
try:
await client.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1,
messages=[{"role": "user", "content": "Hi"}],
timeout=10.0,
)
model_id = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
model = await client.models.retrieve(model_id=model_id, timeout=10.0)
LOGGER.debug("Anthropic model: %s", model.display_name)
except anthropic.AuthenticationError as err:
LOGGER.error("Invalid API key: %s", err)
return False

View File

@ -63,12 +63,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
client = await hass.async_add_executor_job(
partial(anthropic.AsyncAnthropic, api_key=data[CONF_API_KEY])
)
await client.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1,
messages=[{"role": "user", "content": "Hi"}],
timeout=10.0,
)
await client.models.list(timeout=10.0)
class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN):

View File

@ -1,7 +1,7 @@
"""Tests helpers."""
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch
from unittest.mock import patch
import pytest
@ -43,9 +43,7 @@ async def mock_init_component(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> AsyncGenerator[None]:
"""Initialize integration."""
with patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
):
with patch("anthropic.resources.models.AsyncModels.retrieve"):
assert await async_setup_component(hass, "anthropic", {})
await hass.async_block_till_done()
yield

View File

@ -49,7 +49,7 @@ async def test_form(hass: HomeAssistant) -> None:
with (
patch(
"homeassistant.components.anthropic.config_flow.anthropic.resources.messages.AsyncMessages.create",
"homeassistant.components.anthropic.config_flow.anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock,
),
patch(
@ -151,7 +151,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
)
with patch(
"homeassistant.components.anthropic.config_flow.anthropic.resources.messages.AsyncMessages.create",
"homeassistant.components.anthropic.config_flow.anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock,
side_effect=side_effect,
):

View File

@ -127,9 +127,7 @@ async def test_entity(
CONF_LLM_HASS_API: "assist",
},
)
with patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
):
with patch("anthropic.resources.models.AsyncModels.retrieve"):
await hass.config_entries.async_reload(mock_config_entry.entry_id)
state = hass.states.get("conversation.claude")
@ -173,8 +171,11 @@ async def test_template_error(
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
with patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
with (
patch("anthropic.resources.models.AsyncModels.retrieve"),
patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
@ -205,6 +206,7 @@ async def test_template_variables(
},
)
with (
patch("anthropic.resources.models.AsyncModels.retrieve"),
patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
) as mock_create,
@ -230,8 +232,8 @@ async def test_template_variables(
result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you."
)
assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"]
assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"]
assert "The user name is Test User." in mock_create.call_args.kwargs["system"]
assert "The user id is 12345." in mock_create.call_args.kwargs["system"]
async def test_conversation_agent(
@ -497,9 +499,7 @@ async def test_unknown_hass_api(
assert result == snapshot
@patch("anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock)
async def test_conversation_id(
mock_create,
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,

View File

@ -1,6 +1,6 @@
"""Tests for the Anthropic integration."""
from unittest.mock import AsyncMock, patch
from unittest.mock import patch
from anthropic import (
APIConnectionError,
@ -55,8 +55,7 @@ async def test_init_error(
) -> None:
"""Test initialization errors."""
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
"anthropic.resources.models.AsyncModels.retrieve",
side_effect=side_effect,
):
assert await async_setup_component(hass, "anthropic", {})