support nc

pull/49/head
yupeng.zhou 2024-05-08 03:47:07 +08:00
parent e59b6d67db
commit 0f285117e1
6 changed files with 98 additions and 25 deletions

View File

@ -28,7 +28,9 @@ Official implementation of **[StoryDiffusion: Consistent Self-Attention for Long
https://github.com/HVision-NKU/StoryDiffusion/assets/49511209/d5b80f8f-09b0-48cd-8b10-daff46d422af
### Update History
You can visit [here](update.md) to view the update history.
### 🌠 **Key Features:**
StoryDiffusion can create a magic story by generating consistent images and videos. Our work mainly has two parts:

View File

@ -546,6 +546,12 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
sa32, sa64 = sa32_, sa64_
id_length = id_length_
clipped_prompts = prompts[:]
nc_indexs = []
for ind,prompt in enumerate(clipped_prompts):
if "[NC]" in prompt:
nc_indexs.append(ind)
if ind < id_length:
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
print(prompts)
@ -569,14 +575,14 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
yield total_results
real_images = []
write = False
for real_prompt in real_prompts:
for ind,real_prompt in enumerate(real_prompts):
setup_seed(seed_)
cur_step = 0
real_prompt = apply_style_positive(style_name, real_prompt)
if _model_type == "original":
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
elif _model_type == "Photomaker":
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False ).images[0])
else:
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
total_results = [real_images[-1]] + total_results
@ -715,6 +721,19 @@ with gr.Blocks(css=css) as demo:
]),
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
],
[0,0.5,0.5,2,"a man, wearing black suit",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"on the road, near the forest",
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
"[NC]A tiger appeared in the forest, at night ",
"very frightened, open mouth, in the forest, at night",
"running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
]),
"Comic book","Only Using Textual Description",get_image_path_list('./examples/Robert'),1024,1024
],
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["wake up in the bed",

View File

@ -583,6 +583,12 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
sa32, sa64 = sa32_, sa64_
id_length = id_length_
clipped_prompts = prompts[:]
nc_indexs = []
for ind,prompt in enumerate(clipped_prompts):
if "[NC]" in prompt:
nc_indexs.append(ind)
if ind < id_length:
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
print(prompts)
@ -606,14 +612,14 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
yield total_results
real_images = []
write = False
for real_prompt in real_prompts:
for ind,real_prompt in enumerate(real_prompts):
setup_seed(seed_)
cur_step = 0
real_prompt = apply_style_positive(style_name, real_prompt)
if _model_type == "original":
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
elif _model_type == "Photomaker":
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False).images[0])
else:
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
total_results = [real_images[-1]] + total_results
@ -752,6 +758,19 @@ with gr.Blocks(css=css) as demo:
]),
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
],
[0,0.5,0.5,2,"a man, wearing black suit",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"on the road, near the forest",
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
"[NC]A tiger appeared in the forest, at night ",
"very frightened, open mouth, in the forest, at night",
"running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
]),
"Comic book","Only Using Textual Description",get_image_path_list('./examples/Robert'),1024,1024
],
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["wake up in the bed",

BIN
results_examples/image1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 MiB

22
update.md Normal file
View File

@ -0,0 +1,22 @@
## Update History
### Update 2024-05-08
- Support [NC] in the Ref Image Model (Photomaker works best at 1024x1024 but may cost a lot of GPU memory; I recommend using as large a resolution as possible)
<img src="results_examples/image1.png" height=100>
- Merged a pull request by @cryptowooser to support the latest Pillow. You may need to update Pillow if you are using an old version.
### Todo
- Support adding captions on all images for the classical comic typesetting style
### Welcome to contribute
- Various layout styles.

View File

@ -136,6 +136,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
class_tokens_mask: Optional[torch.LongTensor] = None,
nc_flag: bool = False,
):
device = device or self._execution_device
@ -171,8 +172,14 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
else:
clean_input_ids.append(token_id)
clean_index += 1
if len(class_token_index) != 1:
if nc_flag:
return None, None, None
if len(class_token_index) > 1:
raise ValueError(
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
Trigger word: {self.trigger_word}, Prompt: {prompt}."
)
elif len(class_token_index) == 0 and not nc_flag:
raise ValueError(
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
Trigger word: {self.trigger_word}, Prompt: {prompt}."
@ -257,6 +264,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
class_tokens_mask: Optional[torch.LongTensor] = None,
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
nc_flag = False,
):
r"""
Function invoked when calling the pipeline for generation.
@ -361,13 +369,15 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
class_tokens_mask=class_tokens_mask,
nc_flag = nc_flag,
)
# 4. Encode input prompt without the trigger word for delayed conditioning
# encode, remove trigger word token, then decode
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
tokens_text_only.remove(trigger_word_token)
if not nc_flag:
tokens_text_only.remove(trigger_word_token)
prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
print(prompt_text_only)
(
@ -396,17 +406,19 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts
# 6. Get the update text embedding with the stacked ID embedding
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if not nc_flag:
# 6. Get the update text embedding with the stacked ID embedding
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
pooled_prompt_embeds = None
negative_prompt_embeds_arr.append(negative_prompt_embeds)
negative_prompt_embeds = None
@ -416,8 +428,6 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
prompt_embeds_text_only = None
prompt_embeds_arr.append(prompt_embeds)
prompt_embeds = None
pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
pooled_prompt_embeds = None
pooled_prompt_embeds_text_only_arr.append(pooled_prompt_embeds_text_only)
pooled_prompt_embeds_text_only = None
# 7. Prepare timesteps
@ -426,8 +436,11 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_prompt_embeds = torch.cat(negative_prompt_embeds_arr ,dim =0)
print(negative_prompt_embeds.shape)
prompt_embeds = torch.cat(prompt_embeds_arr ,dim = 0)
print(prompt_embeds.shape)
if not nc_flag:
prompt_embeds = torch.cat(prompt_embeds_arr ,dim = 0)
print(prompt_embeds.shape)
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_arr,dim = 0)
print(pooled_prompt_embeds.shape)
prompt_embeds_text_only = torch.cat(prompt_embeds_text_only_arr ,dim = 0)
print(prompt_embeds_text_only.shape)
@ -436,8 +449,6 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds_arr ,dim = 0)
print(negative_pooled_prompt_embeds.shape)
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_arr,dim = 0)
print(pooled_prompt_embeds.shape)
# 8. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
@ -445,7 +456,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
num_channels_latents,
height,
width,
prompt_embeds.dtype,
prompt_embeds.dtype if not nc_flag else prompt_embeds_text_only.dtype,
device,
generator,
latents,
@ -464,7 +475,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
dtype=prompt_embeds.dtype if not nc_flag else prompt_embeds_text_only.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
@ -486,7 +497,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if i <= start_merge_step:
if i <= start_merge_step or nc_flag:
current_prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
)