support nc
parent
e59b6d67db
commit
0f285117e1
|
@ -28,7 +28,9 @@ Official implementation of **[StoryDiffusion: Consistent Self-Attention for Long
|
|||
https://github.com/HVision-NKU/StoryDiffusion/assets/49511209/d5b80f8f-09b0-48cd-8b10-daff46d422af
|
||||
|
||||
|
||||
### Update History
|
||||
|
||||
You can visit [here](update.md) to see the update history.
|
||||
|
||||
### 🌠 **Key Features:**
|
||||
StoryDiffusion can create a magic story by generating consistent images and videos. Our work mainly has two parts:
|
||||
|
|
|
@ -546,6 +546,12 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
sa32, sa64 = sa32_, sa64_
|
||||
id_length = id_length_
|
||||
clipped_prompts = prompts[:]
|
||||
nc_indexs = []
|
||||
for ind,prompt in enumerate(clipped_prompts):
|
||||
if "[NC]" in prompt:
|
||||
nc_indexs.append(ind)
|
||||
if ind < id_length:
|
||||
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
|
||||
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
|
||||
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
|
||||
print(prompts)
|
||||
|
@ -569,14 +575,14 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
yield total_results
|
||||
real_images = []
|
||||
write = False
|
||||
for real_prompt in real_prompts:
|
||||
for ind,real_prompt in enumerate(real_prompts):
|
||||
setup_seed(seed_)
|
||||
cur_step = 0
|
||||
real_prompt = apply_style_positive(style_name, real_prompt)
|
||||
if _model_type == "original":
|
||||
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
elif _model_type == "Photomaker":
|
||||
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False ).images[0])
|
||||
else:
|
||||
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
|
||||
total_results = [real_images[-1]] + total_results
|
||||
|
@ -715,6 +721,19 @@ with gr.Blocks(css=css) as demo:
|
|||
]),
|
||||
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
|
||||
],
|
||||
[0,0.5,0.5,2,"a man, wearing black suit",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"on the road, near the forest",
|
||||
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
|
||||
"[NC]A tiger appeared in the forest, at night ",
|
||||
"very frightened, open mouth, in the forest, at night",
|
||||
"running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
|
||||
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
|
||||
]),
|
||||
"Comic book","Only Using Textual Description",get_image_path_list('./examples/Robert'),1024,1024
|
||||
],
|
||||
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["wake up in the bed",
|
||||
|
|
|
@ -583,6 +583,12 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
sa32, sa64 = sa32_, sa64_
|
||||
id_length = id_length_
|
||||
clipped_prompts = prompts[:]
|
||||
nc_indexs = []
|
||||
for ind,prompt in enumerate(clipped_prompts):
|
||||
if "[NC]" in prompt:
|
||||
nc_indexs.append(ind)
|
||||
if ind < id_length:
|
||||
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
|
||||
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
|
||||
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
|
||||
print(prompts)
|
||||
|
@ -606,14 +612,14 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
yield total_results
|
||||
real_images = []
|
||||
write = False
|
||||
for real_prompt in real_prompts:
|
||||
for ind,real_prompt in enumerate(real_prompts):
|
||||
setup_seed(seed_)
|
||||
cur_step = 0
|
||||
real_prompt = apply_style_positive(style_name, real_prompt)
|
||||
if _model_type == "original":
|
||||
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
elif _model_type == "Photomaker":
|
||||
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False).images[0])
|
||||
else:
|
||||
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
|
||||
total_results = [real_images[-1]] + total_results
|
||||
|
@ -752,6 +758,19 @@ with gr.Blocks(css=css) as demo:
|
|||
]),
|
||||
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
|
||||
],
|
||||
[0,0.5,0.5,2,"a man, wearing black suit",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"on the road, near the forest",
|
||||
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
|
||||
"[NC]A tiger appeared in the forest, at night ",
|
||||
"very frightened, open mouth, in the forest, at night",
|
||||
"running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
|
||||
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
|
||||
]),
|
||||
"Comic book","Only Using Textual Description",get_image_path_list('./examples/Robert'),1024,1024
|
||||
],
|
||||
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["wake up in the bed",
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 7.9 MiB |
|
@ -0,0 +1,22 @@
|
|||
## Update History
|
||||
|
||||
### Update 2024-05-08
|
||||
|
||||
- Support [NC] in the Ref Image Model (PhotoMaker works best at 1024x1024 but may cost a lot of GPU memory; I recommend using as large a resolution as possible)
|
||||
|
||||
<img src="results_examples/image1.png" height=100>
|
||||
|
||||
- Merged a push by @cryptowooser to support the latest Pillow. You may need to update Pillow if you are using an old version.
|
||||
|
||||
|
||||
|
||||
### Todo
|
||||
|
||||
- Support adding captions on all images for the classical comic typesetting style
|
||||
|
||||
|
||||
|
||||
|
||||
### Welcome to contribute
|
||||
|
||||
- Various layout styles.
|
|
@ -136,6 +136,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
nc_flag: bool = False,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
|
@ -171,8 +172,14 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
else:
|
||||
clean_input_ids.append(token_id)
|
||||
clean_index += 1
|
||||
|
||||
if len(class_token_index) != 1:
|
||||
if nc_flag:
|
||||
return None, None, None
|
||||
if len(class_token_index) > 1:
|
||||
raise ValueError(
|
||||
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
|
||||
Trigger word: {self.trigger_word}, Prompt: {prompt}."
|
||||
)
|
||||
elif len(class_token_index) == 0 and not nc_flag:
|
||||
raise ValueError(
|
||||
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
|
||||
Trigger word: {self.trigger_word}, Prompt: {prompt}."
|
||||
|
@ -257,6 +264,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
nc_flag = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
@ -361,13 +369,15 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
class_tokens_mask=class_tokens_mask,
|
||||
nc_flag = nc_flag,
|
||||
)
|
||||
|
||||
# 4. Encode input prompt without the trigger word for delayed conditioning
|
||||
# encode, remove trigger word token, then decode
|
||||
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
|
||||
trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
|
||||
tokens_text_only.remove(trigger_word_token)
|
||||
if not nc_flag:
|
||||
tokens_text_only.remove(trigger_word_token)
|
||||
prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
|
||||
print(prompt_text_only)
|
||||
(
|
||||
|
@ -396,17 +406,19 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
|
||||
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts
|
||||
|
||||
# 6. Get the update text embedding with the stacked ID embedding
|
||||
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
if not nc_flag:
|
||||
# 6. Get the update text embedding with the stacked ID embedding
|
||||
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
|
||||
pooled_prompt_embeds = None
|
||||
|
||||
negative_prompt_embeds_arr.append(negative_prompt_embeds)
|
||||
negative_prompt_embeds = None
|
||||
|
@ -416,8 +428,6 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
prompt_embeds_text_only = None
|
||||
prompt_embeds_arr.append(prompt_embeds)
|
||||
prompt_embeds = None
|
||||
pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
|
||||
pooled_prompt_embeds = None
|
||||
pooled_prompt_embeds_text_only_arr.append(pooled_prompt_embeds_text_only)
|
||||
pooled_prompt_embeds_text_only = None
|
||||
# 7. Prepare timesteps
|
||||
|
@ -426,8 +436,11 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
|
||||
negative_prompt_embeds = torch.cat(negative_prompt_embeds_arr ,dim =0)
|
||||
print(negative_prompt_embeds.shape)
|
||||
prompt_embeds = torch.cat(prompt_embeds_arr ,dim = 0)
|
||||
print(prompt_embeds.shape)
|
||||
if not nc_flag:
|
||||
prompt_embeds = torch.cat(prompt_embeds_arr ,dim = 0)
|
||||
print(prompt_embeds.shape)
|
||||
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_arr,dim = 0)
|
||||
print(pooled_prompt_embeds.shape)
|
||||
|
||||
prompt_embeds_text_only = torch.cat(prompt_embeds_text_only_arr ,dim = 0)
|
||||
print(prompt_embeds_text_only.shape)
|
||||
|
@ -436,8 +449,6 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
|
||||
negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds_arr ,dim = 0)
|
||||
print(negative_pooled_prompt_embeds.shape)
|
||||
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_arr,dim = 0)
|
||||
print(pooled_prompt_embeds.shape)
|
||||
# 8. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
|
@ -445,7 +456,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
prompt_embeds.dtype if not nc_flag else prompt_embeds_text_only.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
|
@ -464,7 +475,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
dtype=prompt_embeds.dtype if not nc_flag else prompt_embeds_text_only.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
@ -486,7 +497,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if i <= start_merge_step:
|
||||
if i <= start_merge_step or nc_flag:
|
||||
current_prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue