support mul character

pull/85/head
yupeng.zhou 2024-05-12 21:59:49 +08:00
parent 5f62ae939a
commit ae1398a451
4 changed files with 232 additions and 70 deletions

BIN
examples/twoperson/1.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 KiB

BIN
examples/twoperson/2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

View File

@ -15,6 +15,7 @@ from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
from utils.gradio_utils import is_torch2_available
from utils.gradio_utils import get_id_prompt_index, character_to_dict,get_cur_id_list, process_original_prompt, get_ref_character
if is_torch2_available():
from utils.gradio_utils import \
AttnProcessor2_0 as AttnProcessor
@ -101,9 +102,11 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
global sa32, sa64
global write
global height,width
global character_dict,character_index_dict,invert_character_index_dict,cur_character,ref_indexs_dict,ref_totals,cur_character
if attn_count == 0 and cur_step == 0:
indices1024,indices4096 = cal_attn_indice_xl_effcient_memory(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
if write:
assert len(cur_character) == 1
if hidden_states.shape[1] == (height//32) * (width//32):
indices = indices1024
else:
@ -112,12 +115,18 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
total_batch_size,nums_token,channel = hidden_states.shape
img_nums = total_batch_size // 2
hidden_states = hidden_states.reshape(-1,img_nums,nums_token,channel)
self.id_bank[cur_step] = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel).clone() for img_ind in range(img_nums)]
# print(img_nums,len(indices),hidden_states.shape,self.total_length)
if cur_character[0] not in self.id_bank:
self.id_bank[cur_character[0]] = {}
self.id_bank[cur_character[0]][cur_step] = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel).clone() for img_ind in range(img_nums)]
hidden_states = hidden_states.reshape(-1,nums_token,channel)
#self.id_bank[cur_step] = [hidden_states[:self.id_length].clone(), hidden_states[self.id_length:].clone()]
else:
#encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),self.id_bank[cur_step][1].to(self.device)))
encoder_arr = [tensor.to(self.device) for tensor in self.id_bank[cur_step]]
# TODO: ADD Multipersion Control
encoder_arr = []
for character in cur_character:
encoder_arr = encoder_arr + [tensor.to(self.device) for tensor in self.id_bank[character][cur_step]]
# 判断随机数是否大于0.5
if cur_step <1:
hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
@ -140,7 +149,18 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.reshape(-1,img_nums,nums_token,channel)
encoder_arr = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel) for img_ind in range(img_nums)]
for img_ind in range(img_nums):
encoder_hidden_states_tmp = torch.cat(encoder_arr[0:img_ind] + encoder_arr[img_ind+1:] + [hidden_states[:,img_ind,:,:]],dim=1)
# print(img_nums)
# assert img_nums != 1
img_ind_list = [i for i in range(img_nums)]
# print(img_ind_list,img_ind)
img_ind_list.remove(img_ind)
# print(img_ind,invert_character_index_dict[img_ind])
# print(character_index_dict[invert_character_index_dict[img_ind]])
# print(img_ind_list)
# print(img_ind,img_ind_list)
encoder_hidden_states_tmp = torch.cat([encoder_arr[img_ind] for img_ind in img_ind_list] + [hidden_states[:,img_ind,:,:]],dim=1)
hidden_states[:,img_ind,:,:] = self.__call2__(attn, hidden_states[:,img_ind,:,:],encoder_hidden_states_tmp,None,temb)
else:
_,nums_token,channel = hidden_states.shape
@ -150,6 +170,7 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
# print(len(indices))
# encoder_arr = [encoder_hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel) for img_ind in range(img_nums)]
encoder_hidden_states_tmp = torch.cat(encoder_arr+[hidden_states[:,0,:,:]],dim=1)
# print(len(encoder_arr),encoder_hidden_states_tmp.shape)
hidden_states[:,0,:,:] = self.__call2__(attn, hidden_states[:,0,:,:],encoder_hidden_states_tmp,None,temb)
hidden_states = hidden_states.reshape(-1,nums_token,channel)
else:
@ -552,12 +573,18 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
else:
unet = pipe.unet
# unet.set_attn_processor(copy.deepcopy(attn_procs))
if _model_type != "original":
input_id_images = []
for img in _upload_images:
print(img)
input_id_images.append(load_image(img))
prompts = prompt_array.splitlines()
global character_dict,character_index_dict,invert_character_index_dict,ref_indexs_dict,ref_totals
character_dict,character_list = character_to_dict(general_prompt)
start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
if start_merge_step > 30:
start_merge_step = 30
@ -572,41 +599,81 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
nc_indexs.append(ind)
if ind < id_length:
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
prompts = [prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
print(prompts)
id_prompts = prompts[:id_length]
real_prompts = prompts[id_length:]
#id_prompts = prompts[:id_length]
character_index_dict,invert_character_index_dict,replace_prompts,ref_indexs_dict,ref_totals = process_original_prompt(character_dict,prompts.copy(),id_length)
if _model_type != "original":
input_id_images_dict = {}
if len(_upload_images) != len(character_dict.keys()):
raise gr.Error(f"You upload images({len(_upload_images)}) is not equal to the number of characters({len(character_dict.keys())})!")
for ind,img in enumerate(_upload_images):
input_id_images_dict[character_list[ind]] = [load_image(img)]
print(character_dict)
print(character_index_dict)
print(invert_character_index_dict)
# real_prompts = prompts[id_length:]
torch.cuda.empty_cache()
write = True
cur_step = 0
attn_count = 0
id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
# id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
# print(id_prompts)
setup_seed(seed_)
total_results = []
if _model_type == "original":
id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
elif _model_type == "Photomaker":
id_images = pipe(id_prompts,input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
else:
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
total_results = id_images + total_results
yield total_results
real_images = []
write = False
for ind,real_prompt in enumerate(real_prompts):
id_images = []
results_dict = {}
global cur_character
for character_key in character_dict.keys():
cur_character = [character_key]
ref_indexs = ref_indexs_dict[character_key]
print(character_key,ref_indexs)
current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
print(current_prompts)
setup_seed(seed_)
generator = torch.Generator(device="cuda").manual_seed(seed_)
cur_step = 0
cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
if _model_type == "original":
id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
elif _model_type == "Photomaker":
id_images = pipe(cur_positive_prompts,input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
else:
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
# total_results = id_images + total_results
# yield total_results
print(id_images)
for ind,img in enumerate(id_images):
print(ref_indexs[ind])
results_dict[ref_indexs[ind]] = img
# real_images = []
write = False
real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
print(real_prompts_inds)
for real_prompts_ind in real_prompts_inds:
real_prompt = replace_prompts[real_prompts_ind]
cur_character = get_ref_character(prompts[real_prompts_ind],character_dict)
print(cur_character,real_prompt)
setup_seed(seed_)
if len(cur_character) > 1 and _model_type == "Photomaker":
raise gr.Error("Temporarily Not Support Multiple character in Ref Image Mode!")
generator = torch.Generator(device="cuda").manual_seed(seed_)
cur_step = 0
real_prompt = apply_style_positive(style_name, real_prompt)
if _model_type == "original":
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
results_dict[real_prompts_ind] = (pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
elif _model_type == "Photomaker":
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False).images[0])
results_dict[real_prompts_ind] = (pipe(real_prompt, input_id_images=input_id_images_dict[cur_character[0]] if real_prompts_ind not in nc_indexs else input_id_images_dict[character_list[0]], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if real_prompts_ind in nc_indexs else False).images[0])
else:
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
total_results = [real_images[-1]] + total_results
yield total_results
total_results = [results_dict[ind] for ind in range(len(prompts))]
if _comic_type != "No typesetting (default)":
captions= prompt_array.splitlines()
captions = [caption.replace("[NC]","") for caption in captions]
@ -616,7 +683,7 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
font_path = os.path.join("fonts", font_choice)
print(f"Attempting to load font from path: {font_path}")
font = ImageFont.truetype(font_path, int(45))
total_results = get_comic(id_images + real_images, _comic_type, captions=captions, font=font) + total_results
total_results = get_comic(total_results, _comic_type, captions=captions, font=font) + total_results
yield total_results
@ -660,7 +727,7 @@ with gr.Blocks(css=css) as demo:
uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
with gr.Column(visible=False) as clear_button:
remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
general_prompt = gr.Textbox(value='', label="(1) Textual Description for Character", interactive=True)
general_prompt = gr.Textbox(value='',lines = 2, label="(1) Textual Description for Character", interactive=True)
negative_prompt = gr.Textbox(value='', label="(2) Negative_prompt", interactive=True)
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
prompt_array = gr.Textbox(lines = 3,value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
@ -733,65 +800,69 @@ with gr.Blocks(css=css) as demo:
gr.Examples(
examples=[
[0,0.5,0.5,2,"a man, wearing black suit",
[0,0.5,0.5,2,"[Bob] A man, wearing a black suit\n[Alice]a woman, wearing a white shirt",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"on the road, near the forest",
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
array2string(["[Bob] at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"[Bob] on the road, near the forest",
"[Alice] is make a call at home # [Bob] invited [Alice] to join him on an adventure.",
"[NC]A tiger appeared in the forest, at night ",
"very frightened, open mouth, in the forest, at night",
"running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
"[NC] The car on the road, near the forest #They drives to the forest in search of treasure.",
"[Bob] very frightened, open mouth, in the forest, at night",
"[Alice] very frightened, open mouth, in the forest, at night",
"[Bob] and [Alice] running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, They discovers the treasure house!",
"[Bob] and [Alice] in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
]),
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
],
[0,0.5,0.5,2,"a man, wearing black suit",
[0,0.5,0.5,2,"[Bob] A man img, wearing a black suit\n[Alice]a woman img, wearing a white shirt",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"on the road, near the forest",
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
array2string(["[Bob] at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
"[Bob] on the road, near the forest",
"[Alice] is make a call at home # [Bob] invited [Alice] to join him on an adventure.",
"[NC] The car on the road, near the forest #They drives to the forest in search of treasure.",
"[NC]A tiger appeared in the forest, at night ",
"very frightened, open mouth, in the forest, at night",
"running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
"[Bob] very frightened, open mouth, in the forest, at night",
"[Alice] very frightened, open mouth, in the forest, at night",
"[Bob] running very fast, in the forest, at night",
"[NC] A house in the forest, at night #Suddenly, They discovers the treasure house!",
"[Bob] in the house filled with treasure, laughing, at night #They are overjoyed inside the house."
]),
"Comic book","Using Ref Images",get_image_path_list('./examples/Robert'),1024,1024
"Comic book","Using Ref Images",get_image_path_list('./examples/twoperson'),1024,1024
],
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
[1,0.5,0.5,3,"[Taylor]a woman img, wearing a white T-shirt, blue loose hair",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["wake up in the bed",
"have breakfast",
"is on the road, go to company",
"work in the company",
"Take a walk next to the company at noon",
"lying in bed at night"]),
array2string(["[Taylor]wake up in the bed",
"[Taylor]have breakfast",
"[Taylor]is on the road, go to company",
"[Taylor]work in the company",
"[Taylor]Take a walk next to the company at noon",
"[Taylor]lying in bed at night"]),
"Japanese Anime", "Using Ref Images",get_image_path_list('./examples/taylor'),768,768
],
[0,0.5,0.5,3,"a man, wearing black jacket",
[0,0.5,0.5,3,"[Bob]a man, wearing black jacket",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string(["wake up in the bed",
"have breakfast",
"is on the road, go to the company, close look",
"work in the company",
"laughing happily",
"lying in bed at night"
array2string(["[Bob]wake up in the bed",
"[Bob]have breakfast",
"[Bob]is on the road, go to the company, close look",
"[Bob]work in the company",
"[Bob]laughing happily",
"[Bob]lying in bed at night"
]),
"Japanese Anime","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
],
[0,0.3,0.5,3,"a girl, wearing white shirt, black skirt, black tie, yellow hair",
[0,0.3,0.5,3,"[Kitty]a girl, wearing white shirt, black skirt, black tie, yellow hair",
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
array2string([
"at home #at home, began to go to drawing",
"sitting alone on a park bench.",
"reading a book on a park bench.",
"[Kitty]at home #at home, began to go to drawing",
"[Kitty]sitting alone on a park bench.",
"[Kitty]reading a book on a park bench.",
"[NC]A squirrel approaches, peeking over the bench. ",
"look around in the park. # She looks around and enjoys the beauty of nature.",
"[Kitty]look around in the park. # She looks around and enjoys the beauty of nature.",
"[NC]leaf falls from the tree, landing on the sketchbook.",
"picks up the leaf, examining its details closely.",
"[Kitty]picks up the leaf, examining its details closely.",
"[NC]The brown squirrel appear.",
"is very happy # She is very happy to see the squirrel again",
"[Kitty]is very happy # She is very happy to see the squirrel again",
"[NC]The brown squirrel takes the cracker and scampers up a tree. # She gives the squirrel cracker"]),
"Japanese Anime","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
]
@ -804,4 +875,4 @@ with gr.Blocks(css=css) as demo:
gr.Markdown(article)
demo.launch(server_name="0.0.0.0", share = False)
demo.launch(server_name="0.0.0.0", share = True)

View File

@ -1,8 +1,11 @@
from calendar import c
from operator import invert
from webbrowser import get
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
class SpatialAttnProcessor2_0(torch.nn.Module):
r"""
@ -425,4 +428,92 @@ class AttnProcessor2_0(torch.nn.Module):
def is_torch2_available():
    """Report whether ``F`` exposes ``scaled_dot_product_attention``.

    Returns:
        bool: ``True`` when the attribute exists on ``F``, ``False`` otherwise.
    """
    return hasattr(F, "scaled_dot_product_attention")
# Helper that converts the list of character descriptions into a dictionary.
def character_to_dict(general_prompt):
    """Parse per-character descriptions out of the general prompt.

    Every line of ``general_prompt`` is scanned for a bracketed tag such as
    ``[Bob]``. The tag (brackets included) becomes the key and the text after
    the closing bracket becomes the value. If the value contains ``#``, the
    text from the last ``#`` onwards is dropped. Lines without a bracketed tag
    are ignored.

    Args:
        general_prompt: Multi-line string, one character description per line.

    Returns:
        A ``(character_dict, character_list)`` tuple: the tag -> description
        mapping, and the tags in the order they were encountered.

    Raises:
        gr.Error: If the same tag appears on more than one line.
    """
    character_dict = {}
    character_list = []
    for line in general_prompt.splitlines():
        # Split the line into the bracketed key and the description value.
        start = line.find('[')
        # Search for the closing bracket only after the opening one, so a
        # stray ']' earlier in the line cannot yield an empty/inverted key.
        end = line.find(']', start) if start != -1 else -1
        if start == -1 or end == -1:
            continue
        key = line[start:end + 1]
        value = line[end + 1:]
        if "#" in value:
            value = value.rpartition('#')[0]
        if key in character_dict:
            raise gr.Error("duplicate character descirption: " + key)
        character_dict[key] = value
        character_list.append(key)
    return character_dict, character_list
def get_id_prompt_index(character_dict, id_prompts):
    """Index id prompts by the character tag they mention and expand the tags.

    For every prompt and every tag in ``character_dict`` that occurs in that
    prompt, the prompt index is recorded under the tag, the tag is recorded
    under the index, and a copy of the prompt with that tag replaced by its
    description is appended to the output list.

    Note that a prompt mentioning no tag contributes nothing to the returned
    prompt list, a prompt mentioning several tags contributes one entry per
    tag, and the inverse mapping keeps only the last matching tag per index.

    Args:
        character_dict: Mapping of tag -> character description.
        id_prompts: Sequence of prompt strings.

    Returns:
        A ``(character_index_dict, invert_character_index_dict,
        replace_id_prompts)`` tuple.
    """
    replace_id_prompts = []
    character_index_dict = {}
    invert_character_index_dict = {}
    for ind, id_prompt in enumerate(id_prompts):
        for key, description in character_dict.items():
            if key not in id_prompt:
                continue
            character_index_dict.setdefault(key, []).append(ind)
            invert_character_index_dict[ind] = key
            replace_id_prompts.append(id_prompt.replace(key, description))
    return character_index_dict, invert_character_index_dict, replace_id_prompts
def get_cur_id_list(real_prompt, character_dict, character_index_dict):
    """Collect the id indices of every character mentioned in a prompt.

    Each tag in ``character_index_dict`` that occurs in ``real_prompt``
    contributes its index list to the result, and the tag is replaced in the
    prompt by its description from ``character_dict``.

    Args:
        real_prompt: Prompt string that may contain character tags.
        character_dict: Mapping of tag -> character description.
        character_index_dict: Mapping of tag -> list of indices.

    Returns:
        A ``(indices, expanded_prompt)`` tuple. ``indices`` is a new list; the
        lists stored in ``character_index_dict`` are not modified.
    """
    id_indices = []
    for key, indices in character_index_dict.items():
        if key in real_prompt:
            # extend() on a fresh list avoids rebuilding the list each time.
            id_indices.extend(indices)
            real_prompt = real_prompt.replace(key, character_dict[key])
    return id_indices, real_prompt
def process_original_prompt(character_dict, prompts, id_length):
    """Expand character tags in prompts and pick reference prompts per character.

    Args:
        character_dict: Mapping of tag (e.g. ``"[Bob]"``) -> description.
        prompts: Sequence of prompt strings that may contain tags.
        id_length: Number of single-character prompts required per character.

    Returns:
        A 5-tuple ``(character_index_dict, invert_character_index_dict,
        replace_prompts, ref_index_dict, ref_totals)``:

        * ``character_index_dict``: tag -> indices of prompts mentioning it.
        * ``invert_character_index_dict``: prompt index -> tags it mentions.
        * ``replace_prompts``: prompts with every mentioned tag replaced by
          its description.
        * ``ref_index_dict``: tag -> first ``id_length`` prompt indices that
          mention only that tag.
        * ``ref_totals``: concatenation of all lists in ``ref_index_dict``.

    Raises:
        gr.Error: If a character in ``character_dict`` is not mentioned by any
            prompt, or has fewer than ``id_length`` single-character prompts.
    """
    replace_prompts = []
    character_index_dict = {}
    invert_character_index_dict = {}
    for ind, prompt in enumerate(prompts):
        cur_prompt = prompt
        for key, description in character_dict.items():
            # Membership is tested against the untouched prompt; only the
            # working copy is rewritten.
            if key in prompt:
                character_index_dict.setdefault(key, []).append(ind)
                invert_character_index_dict.setdefault(ind, []).append(key)
                cur_prompt = cur_prompt.replace(key, description)
        replace_prompts.append(cur_prompt)

    print(character_index_dict)

    # Validate against the declared characters. Checking the index dict
    # against itself (as before) could never detect an unused character.
    for character_key in character_dict:
        if character_key not in character_index_dict:
            raise gr.Error("{} not have prompt description, please remove it".format(character_key))

    ref_index_dict = {}
    ref_totals = []
    for character_key, all_indices in character_index_dict.items():
        # Only prompts that mention exactly one character can serve as
        # reference prompts for that character.
        index_list = [
            index for index in all_indices
            if len(invert_character_index_dict[index]) == 1
        ]
        if len(index_list) < id_length:
            raise gr.Error(f"{character_key} not have enough prompt description, need no less than {id_length}, but you give {len(index_list)}")
        ref_index_dict[character_key] = index_list[:id_length]
        ref_totals.extend(index_list[:id_length])
    return character_index_dict, invert_character_index_dict, replace_prompts, ref_index_dict, ref_totals
def get_ref_character(real_prompt, character_dict):
    """Return the character tags from ``character_dict`` found in a prompt.

    Args:
        real_prompt: Prompt string that may contain character tags.
        character_dict: Mapping whose keys are the character tags.

    Returns:
        list: The keys of ``character_dict`` that occur as substrings of
        ``real_prompt``, in the dictionary's iteration order.
    """
    return [key for key in character_dict if key in real_prompt]