fix(backend): Fix `credentials` cost filter not able to filter the block cost (#8837)
We've started enabling cost based on the *partial value* of the `credentials` field. And this logic has never been supported. ### Changes 🏗️ * Add partial object matching on the input data filter for evaluating the block cost. * Add missing credentials for `ExtractWebsiteContentBlock` * Removed fallback cost on LLM blocks. ### Checklist 📋 #### For code changes: - [ ] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [ ] ... <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> #### For configuration changes: - [ ] `.env.example` is updated or already compatible with my changes - [ ] `docker-compose.yml` is updated or already compatible with my changes - [ ] I have included a list of my configuration changes in the PR description (under **Changes**) <details> <summary>Examples of configuration changes</summary> - Changing ports - Adding new services that need to communicate with each other - Secrets or environment variable changes - New or infrastructure changes such as databases </details>pull/8836/head^2
parent
520b1d7940
commit
eeb5b4aa46
|
@ -55,3 +55,53 @@ class SearchTheWebBlock(Block, GetRequest):
|
|||
|
||||
# Output the search results
|
||||
yield "results", results
|
||||
|
||||
|
||||
class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
credentials: JinaCredentialsInput = JinaCredentialsField()
|
||||
url: str = SchemaField(description="The URL to scrape the content from")
|
||||
raw_content: bool = SchemaField(
|
||||
default=False,
|
||||
title="Raw Content",
|
||||
description="Whether to do a raw scrape of the content or use Jina-ai Reader to scrape the content",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
content: str = SchemaField(description="The scraped content from the given URL")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the content cannot be retrieved"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="436c3984-57fd-4b85-8e9a-459b356883bd",
|
||||
description="This block scrapes the content from the given web URL.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExtractWebsiteContentBlock.Input,
|
||||
output_schema=ExtractWebsiteContentBlock.Output,
|
||||
test_input={
|
||||
"url": "https://en.wikipedia.org/wiki/Artificial_intelligence",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=("content", "scraped content"),
|
||||
test_mock={"get_request": lambda *args, **kwargs: "scraped content"},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
url = input_data.url
|
||||
headers = {}
|
||||
else:
|
||||
url = f"https://r.jina.ai/{input_data.url}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
||||
}
|
||||
|
||||
content = self.get_request(url, json=False, headers=headers)
|
||||
yield "content", content
|
||||
|
|
|
@ -23,13 +23,6 @@ from backend.util.settings import BehaveAs, Settings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LlmApiKeys = {
|
||||
# "openai": BlockSecret("openai_api_key"),
|
||||
# "anthropic": BlockSecret("anthropic_api_key"),
|
||||
# "groq": BlockSecret("groq_api_key"),
|
||||
# "ollama": BlockSecret(value=""),
|
||||
# }
|
||||
|
||||
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
|
|
|
@ -40,44 +40,6 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
|||
yield "summary", response["extract"]
|
||||
|
||||
|
||||
class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
url: str = SchemaField(description="The URL to scrape the content from")
|
||||
raw_content: bool = SchemaField(
|
||||
default=False,
|
||||
title="Raw Content",
|
||||
description="Whether to do a raw scrape of the content or use Jina-ai Reader to scrape the content",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
content: str = SchemaField(description="The scraped content from the given URL")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the content cannot be retrieved"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="436c3984-57fd-4b85-8e9a-459b356883bd",
|
||||
description="This block scrapes the content from the given web URL.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExtractWebsiteContentBlock.Input,
|
||||
output_schema=ExtractWebsiteContentBlock.Output,
|
||||
test_input={"url": "https://en.wikipedia.org/wiki/Artificial_intelligence"},
|
||||
test_output=("content", "scraped content"),
|
||||
test_mock={"get_request": lambda url, json: "scraped content"},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
url = input_data.url
|
||||
else:
|
||||
url = f"https://r.jina.ai/{input_data.url}"
|
||||
|
||||
content = self.get_request(url, json=False)
|
||||
yield "content", content
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="openweathermap",
|
||||
|
|
|
@ -17,7 +17,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
|||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.search import SearchTheWebBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
|
@ -28,7 +28,6 @@ from backend.blocks.llm import (
|
|||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.search import ExtractWebsiteContentBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
|
@ -72,18 +71,8 @@ for model in LlmModel:
|
|||
|
||||
|
||||
LLM_COST = (
|
||||
# Anthropic Models
|
||||
[
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"api_key": None, # Running LLM with user own API key is free.
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
]
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
|
@ -99,6 +88,7 @@ LLM_COST = (
|
|||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "anthropic"
|
||||
]
|
||||
# OpenAI Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
|
@ -115,6 +105,7 @@ LLM_COST = (
|
|||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "openai"
|
||||
]
|
||||
# Groq Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
|
@ -127,13 +118,6 @@ LLM_COST = (
|
|||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "groq"
|
||||
]
|
||||
+ [
|
||||
BlockCost(
|
||||
# Default cost is running LlmModel.GPT4O.
|
||||
cost_amount=MODEL_COST[LlmModel.GPT4O],
|
||||
cost_filter={"api_key": None},
|
||||
),
|
||||
]
|
||||
# Open Router Models
|
||||
+ [
|
||||
BlockCost(
|
||||
|
@ -186,7 +170,17 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
|||
)
|
||||
],
|
||||
ExtractWebsiteContentBlock: [
|
||||
BlockCost(cost_amount=1, cost_filter={"raw_content": False})
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"raw_content": False,
|
||||
"credentials": {
|
||||
"id": jina_credentials.id,
|
||||
"provider": jina_credentials.provider,
|
||||
"type": jina_credentials.type,
|
||||
},
|
||||
},
|
||||
)
|
||||
],
|
||||
IdeogramModelBlock: [
|
||||
BlockCost(
|
||||
|
|
|
@ -107,8 +107,8 @@ class UserCredit(UserCreditBase):
|
|||
def time_now():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
def _block_usage_cost(
|
||||
self,
|
||||
block: Block,
|
||||
input_data: BlockInput,
|
||||
data_size: float,
|
||||
|
@ -119,28 +119,44 @@ class UserCredit(UserCreditBase):
|
|||
return 0, {}
|
||||
|
||||
for block_cost in block_costs:
|
||||
if all(
|
||||
# None, [], {}, "", are considered the same value.
|
||||
input_data.get(k) == b or (not input_data.get(k) and not b)
|
||||
for k, b in block_cost.cost_filter.items()
|
||||
):
|
||||
if block_cost.cost_type == BlockCostType.RUN:
|
||||
return block_cost.cost_amount, block_cost.cost_filter
|
||||
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
|
||||
continue
|
||||
|
||||
if block_cost.cost_type == BlockCostType.SECOND:
|
||||
return (
|
||||
int(run_time * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
if block_cost.cost_type == BlockCostType.RUN:
|
||||
return block_cost.cost_amount, block_cost.cost_filter
|
||||
|
||||
if block_cost.cost_type == BlockCostType.BYTE:
|
||||
return (
|
||||
int(data_size * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
if block_cost.cost_type == BlockCostType.SECOND:
|
||||
return (
|
||||
int(run_time * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
|
||||
if block_cost.cost_type == BlockCostType.BYTE:
|
||||
return (
|
||||
int(data_size * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
|
||||
return 0, {}
|
||||
|
||||
def _is_cost_filter_match(
|
||||
self, cost_filter: BlockInput, input_data: BlockInput
|
||||
) -> bool:
|
||||
"""
|
||||
Filter rules:
|
||||
- If costFilter is an object, then check if costFilter is the subset of inputValues
|
||||
- Otherwise, check if costFilter is equal to inputValues.
|
||||
- Undefined, null, and empty string are considered as equal.
|
||||
"""
|
||||
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
|
||||
return cost_filter == input_data
|
||||
|
||||
return all(
|
||||
(not input_data.get(k) and not v)
|
||||
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
|
||||
for k, v in cost_filter.items()
|
||||
)
|
||||
|
||||
async def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from autogpt_libs.supabase_integration_credentials_store.store import openai_credentials
|
||||
from prisma.models import UserBlockCredit
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
@ -20,7 +21,14 @@ async def test_block_credit_usage(server: SpinTestServer):
|
|||
DEFAULT_USER_ID,
|
||||
current_credit,
|
||||
AITextGeneratorBlock().id,
|
||||
{"model": "gpt-4-turbo"},
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
"provider": openai_credentials.provider,
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
0.0,
|
||||
0.0,
|
||||
validate_balance=False,
|
||||
|
|
|
@ -491,14 +491,26 @@ export function CustomNode({
|
|||
});
|
||||
|
||||
const inputValues = data.hardcodedValues;
|
||||
|
||||
const isCostFilterMatch = (costFilter: any, inputValues: any): boolean => {
|
||||
/*
|
||||
Filter rules:
|
||||
- If costFilter is an object, then check if costFilter is the subset of inputValues
|
||||
- Otherwise, check if costFilter is equal to inputValues.
|
||||
- Undefined, null, and empty string are considered as equal.
|
||||
*/
|
||||
return typeof costFilter === "object" && typeof inputValues === "object"
|
||||
? Object.entries(costFilter).every(
|
||||
([k, v]) =>
|
||||
(!v && !inputValues[k]) || isCostFilterMatch(v, inputValues[k]),
|
||||
)
|
||||
: costFilter === inputValues;
|
||||
};
|
||||
|
||||
const blockCost =
|
||||
data.blockCosts &&
|
||||
data.blockCosts.find((cost) =>
|
||||
Object.entries(cost.cost_filter).every(
|
||||
// Undefined, null, or empty values are considered equal
|
||||
([key, value]) =>
|
||||
value === inputValues[key] || (!value && !inputValues[key]),
|
||||
),
|
||||
isCostFilterMatch(cost.cost_filter, inputValues),
|
||||
);
|
||||
|
||||
const LineSeparator = () => (
|
||||
|
|
Loading…
Reference in New Issue