Merge pull request #1477 from Tymec/feature/more-image-gen
Image generation improvements pull/1243/head
commit
fdaa55a452
|
@ -107,14 +107,22 @@ MILVUS_COLLECTION=autogpt
|
|||
### OPEN AI
|
||||
# IMAGE_PROVIDER - Image provider (Options: dalle, huggingface, sdwebui. Example: dalle)
|
||||
IMAGE_PROVIDER=dalle
|
||||
# IMAGE_SIZE - Image size (Example: 256)
|
||||
# DALLE: 256, 512, 1024
|
||||
IMAGE_SIZE=256
|
||||
|
||||
### HUGGINGFACE
|
||||
# STABLE DIFFUSION
|
||||
# (Default URL: https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4)
|
||||
# Set in image_gen.py)
|
||||
# HUGGINGFACE_IMAGE_MODEL - Text-to-image model from Huggingface (Default: CompVis/stable-diffusion-v1-4)
|
||||
HUGGINGFACE_IMAGE_MODEL=CompVis/stable-diffusion-v1-4
|
||||
# HUGGINGFACE_API_TOKEN - HuggingFace API token (Example: my-huggingface-api-token)
|
||||
HUGGINGFACE_API_TOKEN=your-huggingface-api-token
|
||||
|
||||
### STABLE DIFFUSION WEBUI
|
||||
# SD_WEBUI_URL - Stable diffusion webui API URL (Example: http://127.0.0.1:7860)
|
||||
SD_WEBUI_URL=http://127.0.0.1:7860
|
||||
# SD_WEBUI_AUTH - Stable diffusion webui username:password pair (Example: username:password)
|
||||
SD_WEBUI_AUTH=
|
||||
|
||||
################################################################################
|
||||
### AUDIO TO TEXT PROVIDER
|
||||
################################################################################
|
||||
|
|
|
@ -14,11 +14,12 @@ from autogpt.workspace import path_in_workspace
|
|||
CFG = Config()
|
||||
|
||||
|
||||
def generate_image(prompt: str) -> str:
|
||||
def generate_image(prompt: str, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
|
@ -27,11 +28,14 @@ def generate_image(prompt: str) -> str:
|
|||
|
||||
# DALL-E
|
||||
if CFG.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename)
|
||||
elif CFG.image_provider == "sd":
|
||||
return generate_image_with_dalle(prompt, filename, size)
|
||||
# HuggingFace
|
||||
elif CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
else:
|
||||
return "No Image Provider Set"
|
||||
# SD WebUI
|
||||
elif CFG.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
|
||||
def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
|
@ -45,13 +49,16 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
|||
str: The filename of the image
|
||||
"""
|
||||
API_URL = (
|
||||
"https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
|
||||
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
|
||||
)
|
||||
if CFG.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {CFG.huggingface_api_token}"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {CFG.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
|
@ -81,10 +88,18 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
|
|||
"""
|
||||
openai.api_key = CFG.openai_api_key
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
print(
|
||||
f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size="256x256",
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
|
@ -96,3 +111,53 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
|
|||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
    prompt: str,
    filename: str,
    size: int = 512,
    negative_prompt: str = "",
    extra: dict = None,
) -> str:
    """Generate an image with Stable Diffusion webui.

    Posts a txt2img request to ``CFG.sd_webui_url`` and stores the first
    returned image in the workspace under ``filename``.

    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to
        size (int, optional): The width and height of the image. Defaults to 512.
        negative_prompt (str, optional): The negative prompt to use. Defaults to "".
        extra (dict, optional): Extra parameters to pass to the API; they
            override the built-in request parameters on key collision.
            Defaults to None (no extra parameters).

    Returns:
        str: A "Saved to disk:<filename>" status message
    """
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "sampler_index": "DDIM",
        "steps": 20,
        "cfg_scale": 7.0,
        "width": size,
        "height": size,
        "n_iter": 1,
        # Unpacked last so caller-supplied keys win over the defaults above.
        **(extra or {}),
    }

    # The context manager closes the session's connections once we are done.
    with requests.Session() as session:
        if CFG.sd_webui_auth:
            # Split on the first colon only so that passwords containing ":"
            # survive; a value without any colon yields an empty password.
            username, _, password = CFG.sd_webui_auth.partition(":")
            session.auth = (username, password)

        # Send the request through the session; a module-level
        # requests.post() would silently drop the basic-auth credentials.
        response = session.post(
            f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
            json=payload,
        )

    print(f"Image Generated for prompt:{prompt}")

    # Save the image to disk
    result = response.json()
    image_bytes = b64decode(result["images"][0].split(",", 1)[0])
    image = Image.open(io.BytesIO(image_bytes))
    image.save(path_in_workspace(filename))

    return f"Saved to disk:{filename}"
|
||||
|
|
|
@ -85,10 +85,16 @@ class Config(metaclass=Singleton):
|
|||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "autogpt")
|
||||
|
||||
self.image_provider = os.getenv("IMAGE_PROVIDER")
|
||||
self.image_size = int(os.getenv("IMAGE_SIZE", 256))
|
||||
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
|
||||
self.huggingface_image_model = os.getenv(
|
||||
"HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
|
||||
)
|
||||
self.huggingface_audio_to_text_model = os.getenv(
|
||||
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
|
||||
)
|
||||
self.sd_webui_url = os.getenv("SD_WEBUI_URL", "http://localhost:7860")
|
||||
self.sd_webui_auth = os.getenv("SD_WEBUI_AUTH")
|
||||
|
||||
# Selenium browser settings
|
||||
self.selenium_web_browser = os.getenv("USE_WEB_BROWSER", "chrome")
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
import hashlib
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
|
||||
from autogpt.config import Config
|
||||
from autogpt.workspace import path_in_workspace
|
||||
|
||||
|
||||
def lst(txt):
    """Return the text after the first colon of a status message, stripped.

    Used to pull the filename out of a "Saved to disk:<filename>" result.
    Only the first colon is treated as the separator, so a filename that
    itself contains a colon is returned whole.

    Raises:
        IndexError: If ``txt`` contains no colon.
    """
    return txt.split(":", 1)[1].strip()
|
||||
|
||||
|
||||
@unittest.skipIf(os.getenv("CI"), "Skipping image generation tests")
class TestImageGen(unittest.TestCase):
    """Integration tests for the image generation back ends.

    Every test calls a real provider, so the whole class is skipped when the
    ``CI`` environment variable is set.
    """

    def setUp(self):
        # Tests pick a provider by assigning attributes on this object.
        # NOTE(review): this relies on Config() handing back the same shared
        # instance that image_gen reads — confirm.
        self.config = Config()

    def _check_image(self, result, size):
        """Assert the generated image exists and is ``size`` x ``size``.

        The image file is removed afterwards, even when an assertion fails,
        so a failing run does not leave files behind in the workspace.

        Args:
            result (str): Filename of the generated image, relative to the workspace
            size (int): Expected width and height in pixels

        Returns:
            str: MD5 hex digest of the decoded pixel data
        """
        image_path = path_in_workspace(result)
        self.assertTrue(image_path.exists())
        try:
            with Image.open(image_path) as img:
                self.assertEqual(img.size, (size, size))
                return hashlib.md5(img.tobytes()).hexdigest()
        finally:
            image_path.unlink()

    def test_dalle(self):
        self.config.image_provider = "dalle"

        # Test using size 256
        result = lst(generate_image("astronaut riding a horse", 256))
        self._check_image(result, 256)

        # Test using size 512
        result = lst(generate_image("astronaut riding a horse", 512))
        self._check_image(result, 512)

    def test_huggingface(self):
        self.config.image_provider = "huggingface"

        # Test using SD 1.4 model and size 512
        self.config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
        result = lst(generate_image("astronaut riding a horse", 512))
        self._check_image(result, 512)

        # Test using SD 2.1 768 model and size 768
        self.config.huggingface_image_model = "stabilityai/stable-diffusion-2-1"
        result = lst(generate_image("astronaut riding a horse", 768))
        self._check_image(result, 768)

    # The original test bailed out with a bare ``return`` and so always
    # "passed"; an explicit skip keeps it disabled but reports it honestly.
    # NOTE(review): presumably it needs a reachable WebUI at SD_WEBUI_URL —
    # remove the decorator once one is available.
    @unittest.skip("Stable Diffusion WebUI test is disabled")
    def test_sd_webui(self):
        # Must match the provider name checked by generate_image ("sdwebui").
        self.config.image_provider = "sdwebui"

        # Test using size 128
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse", "test_sd_webui_128.jpg", size=128
            )
        )
        self._check_image(result, 128)

        # Test using size 64 and negative prompt
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse",
                "test_sd_webui_64_negative.jpg",
                negative_prompt="horse",
                size=64,
                extra={"seed": 123},
            )
        )
        neg_image_hash = self._check_image(result, 64)

        # Same test as above but without the negative prompt
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse",
                "test_sd_webui_64.jpg",
                size=64,
                extra={"seed": 123},
            )
        )
        image_hash = self._check_image(result, 64)

        # Same seed, so only the negative prompt can explain a difference.
        self.assertNotEqual(image_hash, neg_image_hash)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue