2023-04-18 23:38:31 +00:00
|
|
|
import hashlib
|
2023-04-19 00:24:13 +00:00
|
|
|
import os
|
2023-04-23 19:36:04 +00:00
|
|
|
import shutil
|
2023-04-19 00:26:18 +00:00
|
|
|
import unittest
|
2023-04-23 19:36:04 +00:00
|
|
|
from pathlib import Path
|
2023-04-19 00:26:18 +00:00
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
2023-04-18 23:38:31 +00:00
|
|
|
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
|
2023-04-19 00:26:18 +00:00
|
|
|
from autogpt.config import Config
|
2023-04-23 19:36:04 +00:00
|
|
|
from autogpt.workspace import Workspace
|
2023-04-22 19:48:47 +00:00
|
|
|
from tests.utils import requires_api_key
|
2023-04-18 23:38:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
def lst(txt):
    """Extract the file path from the output of `generate_image()`.

    Assumes the generator reports its result as ``"<message>:<path>"``.
    Only the FIRST colon is treated as the separator, so paths that
    themselves contain colons (e.g. Windows drive letters such as
    ``C:\\images\\x.jpg``) are kept intact.

    Args:
        txt: The string returned by an image generation command.

    Returns:
        The extracted path as a ``Path``, with surrounding whitespace stripped.

    Raises:
        ValueError: If ``txt`` contains no ``":"`` separator.
    """
    _message, sep, raw_path = txt.partition(":")
    if not sep:
        raise ValueError(f"Unexpected image generation output: {txt!r}")
    return Path(raw_path.strip())
|
2023-04-18 23:38:31 +00:00
|
|
|
|
|
|
|
|
2023-04-23 19:36:04 +00:00
|
|
|
@unittest.skip("Skipping image generation tests")
class TestImageGen(unittest.TestCase):
    """Integration tests for the image generation commands.

    These tests call real image providers, so the whole class is skipped by
    default. Some tests additionally require provider API credentials.
    """

    def setUp(self):
        """Create a workspace next to this file and point the config at it."""
        self.config = Config()
        workspace_path = os.path.join(os.path.dirname(__file__), "workspace")
        self.workspace_path = Workspace.make_workspace(workspace_path)
        self.config.workspace_path = workspace_path
        self.workspace = Workspace(workspace_path, restrict_to_workspace=True)

    def tearDown(self) -> None:
        """Remove the workspace directory created in ``setUp``."""
        shutil.rmtree(self.workspace_path)

    def _check_image(self, image_path, size):
        """Assert ``image_path`` exists and is a ``size`` x ``size`` image.

        The file is deleted after the checks. Returns the MD5 hex digest of
        the decoded pixel data so callers can compare two generated images.
        """
        self.assertTrue(image_path.exists())
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (size, size))
            digest = hashlib.md5(img.tobytes()).hexdigest()
        image_path.unlink()
        return digest

    @requires_api_key("OPENAI_API_KEY")
    def test_dalle(self):
        """Test DALL-E image generation."""
        self.config.image_provider = "dalle"

        # Test using size 256
        self._check_image(lst(generate_image("astronaut riding a horse", 256)), 256)

        # Test using size 512
        self._check_image(lst(generate_image("astronaut riding a horse", 512)), 512)

    @requires_api_key("HUGGINGFACE_API_TOKEN")
    def test_huggingface(self):
        """Test HuggingFace image generation."""
        self.config.image_provider = "huggingface"

        # Test using SD 1.4 model and size 512
        self.config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
        self._check_image(lst(generate_image("astronaut riding a horse", 512)), 512)

        # Test using SD 2.1 768 model and size 768
        self.config.huggingface_image_model = "stabilityai/stable-diffusion-2-1"
        self._check_image(lst(generate_image("astronaut riding a horse", 768)), 768)

    def test_sd_webui(self):
        """Test SD WebUI image generation."""
        self.config.image_provider = "sd_webui"
        # Disabled: report the test as skipped instead of silently "passing"
        # through an early return. The body below is kept for re-enabling.
        self.skipTest("SD WebUI image generation test is disabled")

        # NOTE(review): confirm generate_image_with_sd_webui's positional
        # signature before re-enabling (it may expect other arguments
        # before ``size``).

        # Test using size 128
        self._check_image(
            lst(generate_image_with_sd_webui("astronaut riding a horse", 128)), 128
        )

        # Test using size 64 and negative prompt
        neg_image_hash = self._check_image(
            lst(
                generate_image_with_sd_webui(
                    "astronaut riding a horse",
                    negative_prompt="horse",
                    size=64,
                    extra={"seed": 123},
                )
            ),
            64,
        )

        # Same test as above but without the negative prompt
        image_hash = self._check_image(
            lst(
                generate_image_with_sd_webui(
                    "astronaut riding a horse", size=64, extra={"seed": 123}
                )
            ),
            64,
        )

        # With the same seed, the negative prompt must change the output.
        self.assertNotEqual(image_hash, neg_image_hash)
|
|
|
|
|
|
|
|
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|