feat(block): Add AI video generator block with Fal txt 2 vid (#8528)
### Background Implements an AI Video Generator Block for text to image models hosted on Fal ![image](https://github.com/user-attachments/assets/9cb70015-4174-4419-8c1a-4144f324442f) --------- Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com> Co-authored-by: Aarushi <aarushik93@gmail.com>pull/8787/head^2
parent
75f9b072a6
commit
4aa5f53710
|
@ -0,0 +1,36 @@
|
|||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
|
||||
FalCredentials = APIKeyCredentials
|
||||
FalCredentialsInput = CredentialsMetaInput[
|
||||
Literal["fal"],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="fal",
|
||||
api_key=SecretStr("mock-fal-api-key"),
|
||||
title="Mock FAL API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def FalCredentialsField() -> FalCredentialsInput:
|
||||
"""
|
||||
Creates a FAL credentials input on a block.
|
||||
"""
|
||||
return CredentialsField(
|
||||
provider="fal",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The FAL integration can be used with an API Key.",
|
||||
)
|
|
@ -0,0 +1,199 @@
|
|||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.blocks.fal._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
FalCredentials,
|
||||
FalCredentialsField,
|
||||
FalCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FalModel(str, Enum):
|
||||
MOCHI = "fal-ai/mochi-v1"
|
||||
LUMA = "fal-ai/luma-dream-machine"
|
||||
|
||||
|
||||
class AIVideoGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="Description of the video to generate.",
|
||||
placeholder="A dog running in a field.",
|
||||
)
|
||||
model: FalModel = SchemaField(
|
||||
title="FAL Model",
|
||||
default=FalModel.MOCHI,
|
||||
description="The FAL model to use for video generation.",
|
||||
)
|
||||
credentials: FalCredentialsInput = FalCredentialsField()
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="The URL of the generated video.")
|
||||
error: str = SchemaField(
|
||||
description="Error message if video generation failed."
|
||||
)
|
||||
logs: list[str] = SchemaField(
|
||||
description="Generation progress logs.", optional=True
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
description="Generate videos using FAL AI models.",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"prompt": "A dog running in a field.",
|
||||
"model": FalModel.MOCHI,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
|
||||
test_mock={
|
||||
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
|
||||
},
|
||||
)
|
||||
|
||||
def _get_headers(self, api_key: str) -> Dict[str, str]:
|
||||
"""Get headers for FAL API requests."""
|
||||
return {
|
||||
"Authorization": f"Key {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _submit_request(
|
||||
self, url: str, headers: Dict[str, str], data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Submit a request to the FAL API."""
|
||||
try:
|
||||
response = httpx.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"FAL API request failed: {str(e)}")
|
||||
raise RuntimeError(f"Failed to submit request: {str(e)}")
|
||||
|
||||
def _poll_status(self, status_url: str, headers: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Poll the status endpoint until completion or failure."""
|
||||
try:
|
||||
response = httpx.get(status_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to get status: {str(e)}")
|
||||
raise RuntimeError(f"Failed to get status: {str(e)}")
|
||||
|
||||
def generate_video(self, input_data: Input, credentials: FalCredentials) -> str:
|
||||
"""Generate video using the specified FAL model."""
|
||||
base_url = "https://queue.fal.run"
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
headers = self._get_headers(api_key)
|
||||
|
||||
# Submit generation request
|
||||
submit_url = f"{base_url}/{input_data.model.value}"
|
||||
submit_data = {"prompt": input_data.prompt}
|
||||
|
||||
seen_logs = set()
|
||||
|
||||
try:
|
||||
# Submit request to queue
|
||||
submit_response = httpx.post(submit_url, headers=headers, json=submit_data)
|
||||
submit_response.raise_for_status()
|
||||
request_data = submit_response.json()
|
||||
|
||||
# Get request_id and urls from initial response
|
||||
request_id = request_data.get("request_id")
|
||||
status_url = request_data.get("status_url")
|
||||
result_url = request_data.get("response_url")
|
||||
|
||||
if not all([request_id, status_url, result_url]):
|
||||
raise ValueError("Missing required data in submission response")
|
||||
|
||||
# Poll for status with exponential backoff
|
||||
max_attempts = 30
|
||||
attempt = 0
|
||||
base_wait_time = 5
|
||||
|
||||
while attempt < max_attempts:
|
||||
status_response = httpx.get(f"{status_url}?logs=1", headers=headers)
|
||||
status_response.raise_for_status()
|
||||
status_data = status_response.json()
|
||||
|
||||
# Process new logs only
|
||||
logs = status_data.get("logs", [])
|
||||
if logs and isinstance(logs, list):
|
||||
for log in logs:
|
||||
if isinstance(log, dict):
|
||||
# Create a unique key for this log entry
|
||||
log_key = (
|
||||
f"{log.get('timestamp', '')}-{log.get('message', '')}"
|
||||
)
|
||||
if log_key not in seen_logs:
|
||||
seen_logs.add(log_key)
|
||||
message = log.get("message", "")
|
||||
if message:
|
||||
logger.debug(
|
||||
f"[FAL Generation] [{log.get('level', 'INFO')}] [{log.get('source', '')}] [{log.get('timestamp', '')}] {message}"
|
||||
)
|
||||
|
||||
status = status_data.get("status")
|
||||
if status == "COMPLETED":
|
||||
# Get the final result
|
||||
result_response = httpx.get(result_url, headers=headers)
|
||||
result_response.raise_for_status()
|
||||
result_data = result_response.json()
|
||||
|
||||
if "video" not in result_data or not isinstance(
|
||||
result_data["video"], dict
|
||||
):
|
||||
raise ValueError("Invalid response format - missing video data")
|
||||
|
||||
video_url = result_data["video"].get("url")
|
||||
if not video_url:
|
||||
raise ValueError("No video URL in response")
|
||||
|
||||
return video_url
|
||||
|
||||
elif status == "FAILED":
|
||||
error_msg = status_data.get("error", "No error details provided")
|
||||
raise RuntimeError(f"Video generation failed: {error_msg}")
|
||||
elif status == "IN_QUEUE":
|
||||
position = status_data.get("queue_position", "unknown")
|
||||
logger.debug(
|
||||
f"[FAL Generation] Status: In queue, position: {position}"
|
||||
)
|
||||
elif status == "IN_PROGRESS":
|
||||
logger.debug(
|
||||
"[FAL Generation] Status: Request is being processed..."
|
||||
)
|
||||
else:
|
||||
logger.info(f"[FAL Generation] Status: Unknown status: {status}")
|
||||
|
||||
wait_time = min(base_wait_time * (2**attempt), 60) # Cap at 60 seconds
|
||||
time.sleep(wait_time)
|
||||
attempt += 1
|
||||
|
||||
raise RuntimeError("Maximum polling attempts reached")
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
raise RuntimeError(f"API request failed: {str(e)}")
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: FalCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
video_url = self.generate_video(input_data, credentials)
|
||||
yield "video_url", video_url
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
yield "error", error_message
|
|
@ -290,6 +290,8 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||
jina_api_key: str = Field(default="", description="Jina API Key")
|
||||
unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")
|
||||
|
||||
fal_key: str = Field(default="", description="FAL API key")
|
||||
|
||||
# Add more secret fields as needed
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
|
|
|
@ -64,6 +64,7 @@ export const providerIcons: Record<
|
|||
open_router: fallbackIcon,
|
||||
pinecone: fallbackIcon,
|
||||
replicate: fallbackIcon,
|
||||
fal: fallbackIcon,
|
||||
revid: fallbackIcon,
|
||||
unreal_speech: fallbackIcon,
|
||||
hubspot: fallbackIcon,
|
||||
|
|
|
@ -38,6 +38,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
|||
open_router: "Open Router",
|
||||
pinecone: "Pinecone",
|
||||
replicate: "Replicate",
|
||||
fal: "FAL",
|
||||
revid: "Rev.ID",
|
||||
unreal_speech: "Unreal Speech",
|
||||
hubspot: "Hubspot",
|
||||
|
|
|
@ -116,6 +116,7 @@ export const PROVIDER_NAMES = {
|
|||
OPEN_ROUTER: "open_router",
|
||||
PINECONE: "pinecone",
|
||||
REPLICATE: "replicate",
|
||||
FAL: "fal",
|
||||
REVID: "revid",
|
||||
UNREAL_SPEECH: "unreal_speech",
|
||||
HUBSPOT: "hubspot",
|
||||
|
|
Loading…
Reference in New Issue