106 lines
3.0 KiB
Python
106 lines
3.0 KiB
Python
import functools
|
|
import hashlib
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from PIL import Image
|
|
|
|
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
|
|
from tests.utils import requires_api_key
|
|
|
|
|
|
@pytest.fixture(params=[256, 512, 1024])
def image_size(request):
    """Supply one of the parametrized image sizes to the requesting test."""
    size = request.param
    return size
|
|
|
|
|
|
@pytest.mark.xfail(
    reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
)
@requires_api_key("OPENAI_API_KEY")
def test_dalle(config, workspace, image_size):
    """Generate an image through the "dalle" provider and validate the result."""
    # All checks live in the shared helper; this test only selects the provider.
    generate_and_validate(
        config, workspace, image_size=image_size, image_provider="dalle"
    )
|
|
|
|
|
|
@pytest.mark.xfail(
    reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
)
@requires_api_key("HUGGINGFACE_API_TOKEN")
@pytest.mark.parametrize(
    "image_model",
    ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
)
def test_huggingface(config, workspace, image_size, image_model):
    """Generate an image through the "huggingface" provider and validate it.

    Runs once per combination of parametrized image size and image model.
    """
    # The model name is forwarded so the helper can store it on the config.
    generate_and_validate(
        config,
        workspace,
        image_size=image_size,
        image_provider="huggingface",
        hugging_face_image_model=image_model,
    )
|
|
|
|
|
|
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(config, workspace, image_size):
    """Generate an image through the "sd_webui" provider and validate it."""
    # All checks live in the shared helper; this test only selects the provider.
    generate_and_validate(
        config, workspace, image_size=image_size, image_provider="sd_webui"
    )
|
|
|
|
|
|
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(config, workspace, image_size):
    """Check that adding a negative prompt changes the SD WebUI output.

    Two images are requested with the same prompt, size and ``extra``
    settings; only the first call passes ``negative_prompt``. The MD5
    digests of their decoded pixel data must differ.
    """
    # Arguments common to both calls are bound once.
    render = functools.partial(
        generate_image_with_sd_webui,
        prompt="astronaut riding a horse",
        size=image_size,
        extra={"seed": 123},
    )

    def pixel_digest(**call_kwargs):
        # Generate the image, then fingerprint its pixel bytes.
        saved_file = lst(render(**call_kwargs))
        with Image.open(saved_file) as picture:
            return hashlib.md5(picture.tobytes()).hexdigest()

    # Image produced with a negative prompt.
    with_negative = pixel_digest(negative_prompt="horse", filename="negative.jpg")

    # Image produced without a negative prompt.
    without_negative = pixel_digest(filename="positive.jpg")

    assert without_negative != with_negative
|
|
|
|
|
|
def lst(txt):
    """Extract the file path from the output of `generate_image()`.

    The text is expected to look like ``<message>:<path>``. Everything after
    the first colon is treated as the path, so paths that contain colons
    themselves (for example a Windows drive letter) are kept intact.

    Args:
        txt: The status string to parse.

    Returns:
        A ``Path`` built from the whitespace-stripped text after the first
        colon.

    Raises:
        IndexError: If ``txt`` contains no colon.
    """
    # maxsplit=1: only the first colon separates the message from the path.
    return Path(txt.split(":", 1)[1].strip())
|
|
|
|
|
|
def generate_and_validate(
    config,
    workspace,
    image_size,
    image_provider,
    hugging_face_image_model=None,
    **kwargs,
):
    """Generate an image for a fixed prompt and check the saved file.

    Sets ``image_provider`` and ``huggingface_image_model`` on ``config``,
    calls ``generate_image`` with the prompt, ``image_size`` and any extra
    ``kwargs``, then asserts that the reported file exists and that the
    image's size is ``(image_size, image_size)``.
    """
    # Select the backend (and model, if any) on the config.
    config.image_provider = image_provider
    config.huggingface_image_model = hugging_face_image_model

    output = generate_image("astronaut riding a horse", image_size, **kwargs)
    saved_file = lst(output)
    assert saved_file.exists()

    expected_dimensions = (image_size, image_size)
    with Image.open(saved_file) as picture:
        assert picture.size == expected_dimensions
|