support mul character
parent
5f62ae939a
commit
ae1398a451
Binary file not shown.
After Width: | Height: | Size: 6.8 KiB |
Binary file not shown.
After Width: | Height: | Size: 48 KiB |
|
@ -15,6 +15,7 @@ from PIL import Image
|
|||
from tqdm.auto import tqdm
|
||||
from datetime import datetime
|
||||
from utils.gradio_utils import is_torch2_available
|
||||
from utils.gradio_utils import get_id_prompt_index, character_to_dict,get_cur_id_list, process_original_prompt, get_ref_character
|
||||
if is_torch2_available():
|
||||
from utils.gradio_utils import \
|
||||
AttnProcessor2_0 as AttnProcessor
|
||||
|
@ -101,9 +102,11 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|||
global sa32, sa64
|
||||
global write
|
||||
global height,width
|
||||
global character_dict,character_index_dict,invert_character_index_dict,cur_character,ref_indexs_dict,ref_totals,cur_character
|
||||
if attn_count == 0 and cur_step == 0:
|
||||
indices1024,indices4096 = cal_attn_indice_xl_effcient_memory(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
|
||||
if write:
|
||||
assert len(cur_character) == 1
|
||||
if hidden_states.shape[1] == (height//32) * (width//32):
|
||||
indices = indices1024
|
||||
else:
|
||||
|
@ -112,12 +115,18 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|||
total_batch_size,nums_token,channel = hidden_states.shape
|
||||
img_nums = total_batch_size // 2
|
||||
hidden_states = hidden_states.reshape(-1,img_nums,nums_token,channel)
|
||||
self.id_bank[cur_step] = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel).clone() for img_ind in range(img_nums)]
|
||||
# print(img_nums,len(indices),hidden_states.shape,self.total_length)
|
||||
if cur_character[0] not in self.id_bank:
|
||||
self.id_bank[cur_character[0]] = {}
|
||||
self.id_bank[cur_character[0]][cur_step] = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel).clone() for img_ind in range(img_nums)]
|
||||
hidden_states = hidden_states.reshape(-1,nums_token,channel)
|
||||
#self.id_bank[cur_step] = [hidden_states[:self.id_length].clone(), hidden_states[self.id_length:].clone()]
|
||||
else:
|
||||
#encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),self.id_bank[cur_step][1].to(self.device)))
|
||||
encoder_arr = [tensor.to(self.device) for tensor in self.id_bank[cur_step]]
|
||||
# TODO: ADD Multipersion Control
|
||||
encoder_arr = []
|
||||
for character in cur_character:
|
||||
encoder_arr = encoder_arr + [tensor.to(self.device) for tensor in self.id_bank[character][cur_step]]
|
||||
# 判断随机数是否大于0.5
|
||||
if cur_step <1:
|
||||
hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
|
||||
|
@ -140,7 +149,18 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|||
hidden_states = hidden_states.reshape(-1,img_nums,nums_token,channel)
|
||||
encoder_arr = [hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel) for img_ind in range(img_nums)]
|
||||
for img_ind in range(img_nums):
|
||||
encoder_hidden_states_tmp = torch.cat(encoder_arr[0:img_ind] + encoder_arr[img_ind+1:] + [hidden_states[:,img_ind,:,:]],dim=1)
|
||||
# print(img_nums)
|
||||
# assert img_nums != 1
|
||||
img_ind_list = [i for i in range(img_nums)]
|
||||
# print(img_ind_list,img_ind)
|
||||
img_ind_list.remove(img_ind)
|
||||
# print(img_ind,invert_character_index_dict[img_ind])
|
||||
# print(character_index_dict[invert_character_index_dict[img_ind]])
|
||||
# print(img_ind_list)
|
||||
# print(img_ind,img_ind_list)
|
||||
encoder_hidden_states_tmp = torch.cat([encoder_arr[img_ind] for img_ind in img_ind_list] + [hidden_states[:,img_ind,:,:]],dim=1)
|
||||
|
||||
|
||||
hidden_states[:,img_ind,:,:] = self.__call2__(attn, hidden_states[:,img_ind,:,:],encoder_hidden_states_tmp,None,temb)
|
||||
else:
|
||||
_,nums_token,channel = hidden_states.shape
|
||||
|
@ -150,6 +170,7 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|||
# print(len(indices))
|
||||
# encoder_arr = [encoder_hidden_states[:,img_ind,indices[img_ind],:].reshape(2,-1,channel) for img_ind in range(img_nums)]
|
||||
encoder_hidden_states_tmp = torch.cat(encoder_arr+[hidden_states[:,0,:,:]],dim=1)
|
||||
# print(len(encoder_arr),encoder_hidden_states_tmp.shape)
|
||||
hidden_states[:,0,:,:] = self.__call2__(attn, hidden_states[:,0,:,:],encoder_hidden_states_tmp,None,temb)
|
||||
hidden_states = hidden_states.reshape(-1,nums_token,channel)
|
||||
else:
|
||||
|
@ -552,12 +573,18 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
else:
|
||||
unet = pipe.unet
|
||||
# unet.set_attn_processor(copy.deepcopy(attn_procs))
|
||||
if _model_type != "original":
|
||||
input_id_images = []
|
||||
for img in _upload_images:
|
||||
print(img)
|
||||
input_id_images.append(load_image(img))
|
||||
|
||||
|
||||
|
||||
prompts = prompt_array.splitlines()
|
||||
global character_dict,character_index_dict,invert_character_index_dict,ref_indexs_dict,ref_totals
|
||||
character_dict,character_list = character_to_dict(general_prompt)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
|
||||
if start_merge_step > 30:
|
||||
start_merge_step = 30
|
||||
|
@ -572,41 +599,81 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
nc_indexs.append(ind)
|
||||
if ind < id_length:
|
||||
raise gr.Error(f"The first {id_length} row is id prompts, cannot use [NC]!")
|
||||
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
|
||||
prompts = [prompt if "[NC]" not in prompt else prompt.replace("[NC]","") for prompt in clipped_prompts]
|
||||
|
||||
prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
|
||||
print(prompts)
|
||||
id_prompts = prompts[:id_length]
|
||||
real_prompts = prompts[id_length:]
|
||||
#id_prompts = prompts[:id_length]
|
||||
character_index_dict,invert_character_index_dict,replace_prompts,ref_indexs_dict,ref_totals = process_original_prompt(character_dict,prompts.copy(),id_length)
|
||||
if _model_type != "original":
|
||||
input_id_images_dict = {}
|
||||
if len(_upload_images) != len(character_dict.keys()):
|
||||
raise gr.Error(f"You upload images({len(_upload_images)}) is not equal to the number of characters({len(character_dict.keys())})!")
|
||||
for ind,img in enumerate(_upload_images):
|
||||
input_id_images_dict[character_list[ind]] = [load_image(img)]
|
||||
print(character_dict)
|
||||
print(character_index_dict)
|
||||
print(invert_character_index_dict)
|
||||
# real_prompts = prompts[id_length:]
|
||||
torch.cuda.empty_cache()
|
||||
write = True
|
||||
cur_step = 0
|
||||
|
||||
attn_count = 0
|
||||
id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
|
||||
# id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
|
||||
# print(id_prompts)
|
||||
setup_seed(seed_)
|
||||
total_results = []
|
||||
if _model_type == "original":
|
||||
id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
|
||||
elif _model_type == "Photomaker":
|
||||
id_images = pipe(id_prompts,input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
|
||||
else:
|
||||
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
|
||||
total_results = id_images + total_results
|
||||
yield total_results
|
||||
real_images = []
|
||||
write = False
|
||||
for ind,real_prompt in enumerate(real_prompts):
|
||||
id_images = []
|
||||
results_dict = {}
|
||||
global cur_character
|
||||
for character_key in character_dict.keys():
|
||||
cur_character = [character_key]
|
||||
ref_indexs = ref_indexs_dict[character_key]
|
||||
print(character_key,ref_indexs)
|
||||
current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
|
||||
print(current_prompts)
|
||||
setup_seed(seed_)
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed_)
|
||||
cur_step = 0
|
||||
cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
|
||||
if _model_type == "original":
|
||||
id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
|
||||
elif _model_type == "Photomaker":
|
||||
id_images = pipe(cur_positive_prompts,input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
|
||||
else:
|
||||
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
|
||||
|
||||
# total_results = id_images + total_results
|
||||
# yield total_results
|
||||
print(id_images)
|
||||
for ind,img in enumerate(id_images):
|
||||
print(ref_indexs[ind])
|
||||
results_dict[ref_indexs[ind]] = img
|
||||
# real_images = []
|
||||
write = False
|
||||
|
||||
real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
|
||||
print(real_prompts_inds)
|
||||
|
||||
for real_prompts_ind in real_prompts_inds:
|
||||
real_prompt = replace_prompts[real_prompts_ind]
|
||||
cur_character = get_ref_character(prompts[real_prompts_ind],character_dict)
|
||||
print(cur_character,real_prompt)
|
||||
setup_seed(seed_)
|
||||
if len(cur_character) > 1 and _model_type == "Photomaker":
|
||||
raise gr.Error("Temporarily Not Support Multiple character in Ref Image Mode!")
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed_)
|
||||
cur_step = 0
|
||||
real_prompt = apply_style_positive(style_name, real_prompt)
|
||||
if _model_type == "original":
|
||||
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
results_dict[real_prompts_ind] = (pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])
|
||||
elif _model_type == "Photomaker":
|
||||
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if ind+id_length in nc_indexs else False).images[0])
|
||||
results_dict[real_prompts_ind] = (pipe(real_prompt, input_id_images=input_id_images_dict[cur_character[0]] if real_prompts_ind not in nc_indexs else input_id_images_dict[character_list[0]], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if real_prompts_ind in nc_indexs else False).images[0])
|
||||
else:
|
||||
raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
|
||||
total_results = [real_images[-1]] + total_results
|
||||
yield total_results
|
||||
|
||||
total_results = [results_dict[ind] for ind in range(len(prompts))]
|
||||
if _comic_type != "No typesetting (default)":
|
||||
captions= prompt_array.splitlines()
|
||||
captions = [caption.replace("[NC]","") for caption in captions]
|
||||
|
@ -616,7 +683,7 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
|
|||
font_path = os.path.join("fonts", font_choice)
|
||||
print(f"Attempting to load font from path: {font_path}")
|
||||
font = ImageFont.truetype(font_path, int(45))
|
||||
total_results = get_comic(id_images + real_images, _comic_type, captions=captions, font=font) + total_results
|
||||
total_results = get_comic(total_results, _comic_type, captions=captions, font=font) + total_results
|
||||
yield total_results
|
||||
|
||||
|
||||
|
@ -660,7 +727,7 @@ with gr.Blocks(css=css) as demo:
|
|||
uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
|
||||
with gr.Column(visible=False) as clear_button:
|
||||
remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
|
||||
general_prompt = gr.Textbox(value='', label="(1) Textual Description for Character", interactive=True)
|
||||
general_prompt = gr.Textbox(value='',lines = 2, label="(1) Textual Description for Character", interactive=True)
|
||||
negative_prompt = gr.Textbox(value='', label="(2) Negative_prompt", interactive=True)
|
||||
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
|
||||
prompt_array = gr.Textbox(lines = 3,value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
|
||||
|
@ -733,65 +800,69 @@ with gr.Blocks(css=css) as demo:
|
|||
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[0,0.5,0.5,2,"a man, wearing black suit",
|
||||
[0,0.5,0.5,2,"[Bob] A man, wearing a black suit\n[Alice]a woman, wearing a white shirt",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"on the road, near the forest",
|
||||
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
|
||||
array2string(["[Bob] at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"[Bob] on the road, near the forest",
|
||||
"[Alice] is make a call at home # [Bob] invited [Alice] to join him on an adventure.",
|
||||
"[NC]A tiger appeared in the forest, at night ",
|
||||
"very frightened, open mouth, in the forest, at night",
|
||||
"running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
|
||||
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
|
||||
"[NC] The car on the road, near the forest #They drives to the forest in search of treasure.",
|
||||
"[Bob] very frightened, open mouth, in the forest, at night",
|
||||
"[Alice] very frightened, open mouth, in the forest, at night",
|
||||
"[Bob] and [Alice] running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, They discovers the treasure house!",
|
||||
"[Bob] and [Alice] in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
|
||||
]),
|
||||
"Comic book","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
|
||||
],
|
||||
[0,0.5,0.5,2,"a man, wearing black suit",
|
||||
[0,0.5,0.5,2,"[Bob] A man img, wearing a black suit\n[Alice]a woman img, wearing a white shirt",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"on the road, near the forest",
|
||||
"[NC] The car on the road, near the forest #He drives to the forest in search of treasure.",
|
||||
array2string(["[Bob] at home, read new paper #at home, The newspaper says there is a treasure house in the forest.",
|
||||
"[Bob] on the road, near the forest",
|
||||
"[Alice] is make a call at home # [Bob] invited [Alice] to join him on an adventure.",
|
||||
"[NC] The car on the road, near the forest #They drives to the forest in search of treasure.",
|
||||
"[NC]A tiger appeared in the forest, at night ",
|
||||
"very frightened, open mouth, in the forest, at night",
|
||||
"running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, he discovers the treasure house!",
|
||||
"in the house filled with treasure, laughing, at night #He is overjoyed inside the house."
|
||||
"[Bob] very frightened, open mouth, in the forest, at night",
|
||||
"[Alice] very frightened, open mouth, in the forest, at night",
|
||||
"[Bob] running very fast, in the forest, at night",
|
||||
"[NC] A house in the forest, at night #Suddenly, They discovers the treasure house!",
|
||||
"[Bob] in the house filled with treasure, laughing, at night #They are overjoyed inside the house."
|
||||
]),
|
||||
"Comic book","Using Ref Images",get_image_path_list('./examples/Robert'),1024,1024
|
||||
"Comic book","Using Ref Images",get_image_path_list('./examples/twoperson'),1024,1024
|
||||
],
|
||||
[1,0.5,0.5,3,"a woman img, wearing a white T-shirt, blue loose hair",
|
||||
[1,0.5,0.5,3,"[Taylor]a woman img, wearing a white T-shirt, blue loose hair",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["wake up in the bed",
|
||||
"have breakfast",
|
||||
"is on the road, go to company",
|
||||
"work in the company",
|
||||
"Take a walk next to the company at noon",
|
||||
"lying in bed at night"]),
|
||||
array2string(["[Taylor]wake up in the bed",
|
||||
"[Taylor]have breakfast",
|
||||
"[Taylor]is on the road, go to company",
|
||||
"[Taylor]work in the company",
|
||||
"[Taylor]Take a walk next to the company at noon",
|
||||
"[Taylor]lying in bed at night"]),
|
||||
"Japanese Anime", "Using Ref Images",get_image_path_list('./examples/taylor'),768,768
|
||||
],
|
||||
[0,0.5,0.5,3,"a man, wearing black jacket",
|
||||
[0,0.5,0.5,3,"[Bob]a man, wearing black jacket",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string(["wake up in the bed",
|
||||
"have breakfast",
|
||||
"is on the road, go to the company, close look",
|
||||
"work in the company",
|
||||
"laughing happily",
|
||||
"lying in bed at night"
|
||||
array2string(["[Bob]wake up in the bed",
|
||||
"[Bob]have breakfast",
|
||||
"[Bob]is on the road, go to the company, close look",
|
||||
"[Bob]work in the company",
|
||||
"[Bob]laughing happily",
|
||||
"[Bob]lying in bed at night"
|
||||
]),
|
||||
"Japanese Anime","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
|
||||
],
|
||||
[0,0.3,0.5,3,"a girl, wearing white shirt, black skirt, black tie, yellow hair",
|
||||
[0,0.3,0.5,3,"[Kitty]a girl, wearing white shirt, black skirt, black tie, yellow hair",
|
||||
"bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, three crus, fused feet, fused thigh, extra crus, ugly fingers, horn, cartoon, cg, 3d, unreal, animate, amputation, disconnected limbs",
|
||||
array2string([
|
||||
"at home #at home, began to go to drawing",
|
||||
"sitting alone on a park bench.",
|
||||
"reading a book on a park bench.",
|
||||
"[Kitty]at home #at home, began to go to drawing",
|
||||
"[Kitty]sitting alone on a park bench.",
|
||||
"[Kitty]reading a book on a park bench.",
|
||||
"[NC]A squirrel approaches, peeking over the bench. ",
|
||||
"look around in the park. # She looks around and enjoys the beauty of nature.",
|
||||
"[Kitty]look around in the park. # She looks around and enjoys the beauty of nature.",
|
||||
"[NC]leaf falls from the tree, landing on the sketchbook.",
|
||||
"picks up the leaf, examining its details closely.",
|
||||
"[Kitty]picks up the leaf, examining its details closely.",
|
||||
"[NC]The brown squirrel appear.",
|
||||
"is very happy # She is very happy to see the squirrel again",
|
||||
"[Kitty]is very happy # She is very happy to see the squirrel again",
|
||||
"[NC]The brown squirrel takes the cracker and scampers up a tree. # She gives the squirrel cracker"]),
|
||||
"Japanese Anime","Only Using Textual Description",get_image_path_list('./examples/taylor'),768,768
|
||||
]
|
||||
|
@ -804,4 +875,4 @@ with gr.Blocks(css=css) as demo:
|
|||
gr.Markdown(article)
|
||||
|
||||
|
||||
demo.launch(server_name="0.0.0.0", share = False)
|
||||
demo.launch(server_name="0.0.0.0", share = True)
|
|
@ -1,8 +1,11 @@
|
|||
from calendar import c
|
||||
from operator import invert
|
||||
from webbrowser import get
|
||||
import torch
|
||||
import random
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import gradio as gr
|
||||
|
||||
class SpatialAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
|
@ -425,4 +428,92 @@ class AttnProcessor2_0(torch.nn.Module):
|
|||
|
||||
|
||||
def is_torch2_available():
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
# 将列表转换为字典的函数
|
||||
def character_to_dict(general_prompt):
|
||||
character_dict = {}
|
||||
generate_prompt_arr = general_prompt.splitlines()
|
||||
character_index_dict = {}
|
||||
invert_character_index_dict = {}
|
||||
character_list = []
|
||||
for ind,string in enumerate(generate_prompt_arr):
|
||||
# 分割字符串寻找key和value
|
||||
start = string.find('[')
|
||||
end = string.find(']')
|
||||
if start != -1 and end != -1:
|
||||
key = string[start:end+1]
|
||||
value = string[end+1:]
|
||||
if "#" in value:
|
||||
value = value.rpartition('#')[0]
|
||||
if key in character_dict:
|
||||
raise gr.Error("duplicate character descirption: " + key)
|
||||
character_dict[key] = value
|
||||
character_list.append(key)
|
||||
|
||||
|
||||
return character_dict,character_list
|
||||
|
||||
def get_id_prompt_index(character_dict,id_prompts):
|
||||
replace_id_prompts = []
|
||||
character_index_dict = {}
|
||||
invert_character_index_dict = {}
|
||||
for ind,id_prompt in enumerate(id_prompts):
|
||||
for key in character_dict.keys():
|
||||
if key in id_prompt:
|
||||
if key not in character_index_dict:
|
||||
character_index_dict[key] = []
|
||||
character_index_dict[key].append(ind)
|
||||
invert_character_index_dict[ind] = key
|
||||
replace_id_prompts.append(id_prompt.replace(key,character_dict[key]))
|
||||
|
||||
return character_index_dict,invert_character_index_dict,replace_id_prompts
|
||||
|
||||
def get_cur_id_list(real_prompt,character_dict,character_index_dict):
|
||||
list_arr = []
|
||||
for keys in character_index_dict.keys():
|
||||
if keys in real_prompt:
|
||||
list_arr = list_arr + character_index_dict[keys]
|
||||
real_prompt = real_prompt.replace(keys,character_dict[keys])
|
||||
return list_arr,real_prompt
|
||||
|
||||
def process_original_prompt(character_dict,prompts,id_length):
|
||||
replace_prompts = []
|
||||
character_index_dict = {}
|
||||
invert_character_index_dict = {}
|
||||
for ind,prompt in enumerate(prompts):
|
||||
for key in character_dict.keys():
|
||||
if key in prompt:
|
||||
if key not in character_index_dict:
|
||||
character_index_dict[key] = []
|
||||
character_index_dict[key].append(ind)
|
||||
if ind not in invert_character_index_dict:
|
||||
invert_character_index_dict[ind] = []
|
||||
invert_character_index_dict[ind].append(key)
|
||||
cur_prompt = prompt
|
||||
if ind in invert_character_index_dict:
|
||||
for key in invert_character_index_dict[ind]:
|
||||
cur_prompt = cur_prompt.replace(key,character_dict[key])
|
||||
replace_prompts.append(cur_prompt)
|
||||
ref_index_dict = {}
|
||||
ref_totals = []
|
||||
print(character_index_dict)
|
||||
for character_key in character_index_dict.keys():
|
||||
if character_key not in character_index_dict:
|
||||
raise gr.Error("{} not have prompt description, please remove it".format(character_key))
|
||||
index_list = character_index_dict[character_key]
|
||||
index_list = [index for index in index_list if len(invert_character_index_dict[index]) == 1]
|
||||
if len(index_list) < id_length:
|
||||
raise gr.Error(f"{character_key} not have enough prompt description, need no less than {id_length}, but you give {len(index_list)}")
|
||||
ref_index_dict[character_key] = index_list[:id_length]
|
||||
ref_totals = ref_totals + index_list[:id_length]
|
||||
return character_index_dict,invert_character_index_dict,replace_prompts,ref_index_dict,ref_totals
|
||||
|
||||
|
||||
def get_ref_character(real_prompt,character_dict):
|
||||
list_arr = []
|
||||
for keys in character_dict.keys():
|
||||
if keys in real_prompt:
|
||||
list_arr = list_arr + [keys]
|
||||
return list_arr
|
Loading…
Reference in New Issue