Merge pull request #1477 from Tymec/feature/more-image-gen

Image generation improvements
Commit fdaa55a452 by BillSchumacher, 2023-04-18 19:29:24 -05:00 (committed by GitHub)
4 changed files with 192 additions and 11 deletions

.env.template

@@ -107,14 +107,22 @@ MILVUS_COLLECTION=autogpt
 ### OPEN AI
 # IMAGE_PROVIDER - Image provider (Example: dalle)
 IMAGE_PROVIDER=dalle
+# IMAGE_SIZE - Image size (Example: 256)
+# DALLE: 256, 512, 1024
+IMAGE_SIZE=256
 
 ### HUGGINGFACE
-# STABLE DIFFUSION
-# (Default URL: https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4)
-# Set in image_gen.py)
+# HUGGINGFACE_IMAGE_MODEL - Text-to-image model from Huggingface (Default: CompVis/stable-diffusion-v1-4)
+HUGGINGFACE_IMAGE_MODEL=CompVis/stable-diffusion-v1-4
 # HUGGINGFACE_API_TOKEN - HuggingFace API token (Example: my-huggingface-api-token)
 HUGGINGFACE_API_TOKEN=your-huggingface-api-token
 
+### STABLE DIFFUSION WEBUI
+# SD_WEBUI_URL - Stable diffusion webui API URL (Example: http://127.0.0.1:7860)
+SD_WEBUI_URL=http://127.0.0.1:7860
+# SD_WEBUI_AUTH - Stable diffusion webui username:password pair (Example: username:password)
+SD_WEBUI_AUTH=
+
 ################################################################################
 ### AUDIO TO TEXT PROVIDER
 ################################################################################
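For reference, a filled-in block using the new keys might look like the following; the values here are illustrative, not defaults, and SD_WEBUI_AUTH can stay empty when the WebUI does not require authentication:

IMAGE_PROVIDER=sdwebui
IMAGE_SIZE=512
HUGGINGFACE_IMAGE_MODEL=stabilityai/stable-diffusion-2-1
HUGGINGFACE_API_TOKEN=hf-example-token
SD_WEBUI_URL=http://127.0.0.1:7860
SD_WEBUI_AUTH=username:password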

autogpt/commands/image_gen.py

@@ -14,11 +14,12 @@ from autogpt.workspace import path_in_workspace
 CFG = Config()
 
 
-def generate_image(prompt: str) -> str:
+def generate_image(prompt: str, size: int = 256) -> str:
     """Generate an image from a prompt.
 
     Args:
         prompt (str): The prompt to use
+        size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
 
     Returns:
         str: The filename of the image
@@ -27,11 +28,14 @@ def generate_image(prompt: str) -> str:
 
     # DALL-E
     if CFG.image_provider == "dalle":
-        return generate_image_with_dalle(prompt, filename)
-    elif CFG.image_provider == "sd":
+        return generate_image_with_dalle(prompt, filename, size)
+    # HuggingFace
+    elif CFG.image_provider == "huggingface":
         return generate_image_with_hf(prompt, filename)
-    else:
-        return "No Image Provider Set"
+    # SD WebUI
+    elif CFG.image_provider == "sdwebui":
+        return generate_image_with_sd_webui(prompt, filename, size)
+    return "No Image Provider Set"
 
 
 def generate_image_with_hf(prompt: str, filename: str) -> str:
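With this dispatch, callers choose a backend purely through configuration. A hypothetical call path (provider string and prompt are examples):

from autogpt.commands.image_gen import CFG, generate_image

CFG.image_provider = "sdwebui"  # or "dalle" / "huggingface"
result = generate_image("astronaut riding a horse", size=512)
print(result)  # "Saved to disk:<filename>" on success, "No Image Provider Set" otherwise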
@@ -45,13 +49,16 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
         str: The filename of the image
     """
     API_URL = (
-        "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
+        f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
     )
     if CFG.huggingface_api_token is None:
         raise ValueError(
             "You need to set your Hugging Face API token in the config file."
         )
-    headers = {"Authorization": f"Bearer {CFG.huggingface_api_token}"}
+    headers = {
+        "Authorization": f"Bearer {CFG.huggingface_api_token}",
+        "X-Use-Cache": "false",
+    }
 
     response = requests.post(
         API_URL,
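For context, a minimal standalone sketch of the request this code now builds; the token and prompt are placeholders, and the X-Use-Cache: false header asks the HuggingFace Inference API not to serve a cached result, so repeated prompts yield fresh images:

import requests

model = "CompVis/stable-diffusion-v1-4"  # any text-to-image model id
response = requests.post(
    f"https://api-inference.huggingface.co/models/{model}",
    headers={
        "Authorization": "Bearer hf-example-token",  # placeholder token
        "X-Use-Cache": "false",  # bypass the inference cache
    },
    json={"inputs": "astronaut riding a horse"},
)
with open("image.jpg", "wb") as f:
    f.write(response.content)  # the API returns raw image bytes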
@@ -81,10 +88,18 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
     """
     openai.api_key = CFG.openai_api_key
 
+    # Check for supported image sizes
+    if size not in [256, 512, 1024]:
+        closest = min([256, 512, 1024], key=lambda x: abs(x - size))
+        print(
+            f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to {closest}, was {size}."
+        )
+        size = closest
+
     response = openai.Image.create(
         prompt=prompt,
         n=1,
-        size="256x256",
+        size=f"{size}x{size}",
         response_format="b64_json",
     )
 
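The clamp simply snaps an unsupported size to the numerically nearest allowed value; a quick check of the behavior:

sizes = [256, 512, 1024]
print(min(sizes, key=lambda x: abs(x - 600)))  # 512, since |512 - 600| = 88 is smallest
print(min(sizes, key=lambda x: abs(x - 100)))  # 256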
@@ -96,3 +111,53 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
         png.write(image_data)
 
     return f"Saved to disk:{filename}"
+
+
+def generate_image_with_sd_webui(
+    prompt: str,
+    filename: str,
+    size: int = 512,
+    negative_prompt: str = "",
+    extra: dict = {},
+) -> str:
+    """Generate an image with Stable Diffusion webui.
+    Args:
+        prompt (str): The prompt to use
+        filename (str): The filename to save the image to
+        size (int, optional): The size of the image. Defaults to 512.
+        negative_prompt (str, optional): The negative prompt to use. Defaults to "".
+        extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
+    Returns:
+        str: The filename of the image
+    """
+    # Create a session and set the basic auth if needed
+    s = requests.Session()
+    if CFG.sd_webui_auth:
+        username, password = CFG.sd_webui_auth.split(":")
+        s.auth = (username, password or "")
+
+    # Generate the image through the session so the basic auth is applied
+    response = s.post(
+        f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
+        json={
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "sampler_index": "DDIM",
+            "steps": 20,
+            "cfg_scale": 7.0,
+            "width": size,
+            "height": size,
+            "n_iter": 1,
+            **extra,
+        },
+    )
+
+    print(f"Image Generated for prompt:{prompt}")
+
+    # Save the image to disk
+    response = response.json()
+    b64 = b64decode(response["images"][0].split(",", 1)[0])
+    image = Image.open(io.BytesIO(b64))
+    image.save(path_in_workspace(filename))
+
+    return f"Saved to disk:{filename}"

autogpt/config.py

@@ -85,10 +85,16 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "autogpt")
 
         self.image_provider = os.getenv("IMAGE_PROVIDER")
+        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
         self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
+        self.huggingface_image_model = os.getenv(
+            "HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
+        )
         self.huggingface_audio_to_text_model = os.getenv(
             "HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
         )
+        self.sd_webui_url = os.getenv("SD_WEBUI_URL", "http://localhost:7860")
+        self.sd_webui_auth = os.getenv("SD_WEBUI_AUTH")
 
         # Selenium browser settings
         self.selenium_web_browser = os.getenv("USE_WEB_BROWSER", "chrome")
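With the variables left unset, a sketch of how these fields resolve (note that the template's example URL uses 127.0.0.1 while the in-code fallback is localhost; both point at a local WebUI):

from autogpt.config import Config

CFG = Config()
print(CFG.image_size)               # 256, via int(os.getenv("IMAGE_SIZE", 256))
print(CFG.huggingface_image_model)  # CompVis/stable-diffusion-v1-4
print(CFG.sd_webui_url)             # http://localhost:7860
print(CFG.sd_webui_auth)            # None, so basic auth is skipped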

tests/test_image_gen.py (new file, 102 lines)

@@ -0,0 +1,102 @@
import hashlib
import os
import unittest

from PIL import Image

from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
from autogpt.config import Config
from autogpt.workspace import path_in_workspace


def lst(txt):
    # Extract the filename from a "Saved to disk:<filename>" result string
    return txt.split(":")[1].strip()


@unittest.skipIf(os.getenv("CI"), "Skipping image generation tests")
class TestImageGen(unittest.TestCase):
    def setUp(self):
        self.config = Config()

    def test_dalle(self):
        self.config.image_provider = "dalle"

        # Test using size 256
        result = lst(generate_image("astronaut riding a horse", 256))
        image_path = path_in_workspace(result)
        self.assertTrue(image_path.exists())
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (256, 256))
        image_path.unlink()

        # Test using size 512
        result = lst(generate_image("astronaut riding a horse", 512))
        image_path = path_in_workspace(result)
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (512, 512))
        image_path.unlink()

    def test_huggingface(self):
        self.config.image_provider = "huggingface"

        # Test using SD 1.4 model and size 512
        self.config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
        result = lst(generate_image("astronaut riding a horse", 512))
        image_path = path_in_workspace(result)
        self.assertTrue(image_path.exists())
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (512, 512))
        image_path.unlink()

        # Test using SD 2.1 768 model and size 768
        self.config.huggingface_image_model = "stabilityai/stable-diffusion-2-1"
        result = lst(generate_image("astronaut riding a horse", 768))
        image_path = path_in_workspace(result)
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (768, 768))
        image_path.unlink()

    def test_sd_webui(self):
        self.config.image_provider = "sdwebui"
        # Disabled by default: requires a running Stable Diffusion WebUI server
        return

        # Test using size 128 (filename is required by generate_image_with_sd_webui)
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse", "sd_webui_test.jpg", size=128
            )
        )
        image_path = path_in_workspace(result)
        self.assertTrue(image_path.exists())
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (128, 128))
        image_path.unlink()

        # Test using size 64 and negative prompt
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse",
                "sd_webui_test.jpg",
                negative_prompt="horse",
                size=64,
                extra={"seed": 123},
            )
        )
        image_path = path_in_workspace(result)
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (64, 64))
            neg_image_hash = hashlib.md5(img.tobytes()).hexdigest()
        image_path.unlink()

        # Same test as above but without the negative prompt
        result = lst(
            generate_image_with_sd_webui(
                "astronaut riding a horse",
                "sd_webui_test.jpg",
                size=64,
                extra={"seed": 123},
            )
        )
        image_path = path_in_workspace(result)
        with Image.open(image_path) as img:
            self.assertEqual(img.size, (64, 64))
            image_hash = hashlib.md5(img.tobytes()).hexdigest()
        image_path.unlink()

        # Same seed, so only the negative prompt should change the output
        self.assertNotEqual(image_hash, neg_image_hash)


if __name__ == "__main__":
    unittest.main()
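Since these tests call live image APIs, the whole class is skipped whenever CI is set. Locally they can be run with python -m unittest tests.test_image_gen; the DALL-E and HuggingFace tests need valid API credentials, and test_sd_webui stays disabled until its early return is removed.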