From 18cbe26efdab691f4b2eccc5f8c4c3951208cc79 Mon Sep 17 00:00:00 2001 From: "yupeng.zhou" Date: Tue, 14 May 2024 00:27:21 +0800 Subject: [PATCH] twoperson --- gradio_app_sdxl_specific_id_low_vram.py | 231 +++--- ...gradio_app_sdxl_specific_id_old_version.py | 782 ++++++++++++++++++ storydiffusionpipeline.py | 0 3 files changed, 898 insertions(+), 115 deletions(-) create mode 100644 oldversion/gradio_app_sdxl_specific_id_old_version.py create mode 100644 storydiffusionpipeline.py diff --git a/gradio_app_sdxl_specific_id_low_vram.py b/gradio_app_sdxl_specific_id_low_vram.py index a656027..3239948 100644 --- a/gradio_app_sdxl_specific_id_low_vram.py +++ b/gradio_app_sdxl_specific_id_low_vram.py @@ -21,7 +21,7 @@ if is_torch2_available(): AttnProcessor2_0 as AttnProcessor else: from utils.gradio_utils import AttnProcessor - +import datetime import diffusers from diffusers import StableDiffusionXLPipeline from utils import PhotoMakerStableDiffusionXLPipeline @@ -181,83 +181,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module): cur_step += 1 indices1024,indices4096 = cal_attn_indice_xl_effcient_memory(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype) - return hidden_states - def __call1__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - attn_indices = None, - ): - # print("hidden state shape",hidden_states.shape,self.id_length) - residual = hidden_states - # if encoder_hidden_states is not None: - # raise Exception("not implement") - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - input_ndim = hidden_states.ndim - - if input_ndim == 4: - total_batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2) - total_batch_size,nums_token,channel = hidden_states.shape - img_nums = total_batch_size//2 - hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel) - batch_size, sequence_length, _ = hidden_states.shape - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states # B, N, C - else: - encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # print(key.shape,value.shape,query.shape,attention_mask.shape) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - #print(query.shape,key.shape,value.shape,attention_mask.shape) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - # if input_ndim == 4: - # tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # if attn.residual_connection: - # tile_hidden_states = tile_hidden_states + residual - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width) - if attn.residual_connection: - hidden_states = hidden_states + residual - hidden_states = hidden_states / attn.rescale_output_factor - # print(hidden_states.shape) return hidden_states def __call2__( self, @@ -393,6 +316,60 @@ css = '''