pull/18/head
chenxwh 2024-05-04 23:32:42 +00:00
parent 8a2b11ab66
commit 2fcbb79cd9
2 changed files with 7 additions and 5 deletions

View File

@ -5,7 +5,7 @@
<div align="center">
## StoryDiffusion: Consistent Self-Attention for Long-Range Image and Video Generation [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)]()
[[Paper](https://arxiv.org/abs/2405.01434)] &emsp; [[Project Page](https://storydiffusion.github.io/)] &emsp; [[🤗 Comic Generation Demo ](https://huggingface.co/spaces/YupengZhou/StoryDiffusion)] <br>
[[Paper](https://arxiv.org/abs/2405.01434)] &emsp; [[Project Page](https://storydiffusion.github.io/)] &emsp; [[🤗 Comic Generation Demo ](https://huggingface.co/spaces/YupengZhou/StoryDiffusion)] [![Replicate](https://replicate.com/cjwbw/StoryDiffusion/badge)](https://replicate.com/cjwbw/StoryDiffusion) <br>
</div>

View File

@ -457,6 +457,7 @@ class Predictor(BasePredictor):
self.pipe_realvision.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
self.pipe_realvision.fuse_lora()
@torch.inference_mode()
def predict(
self,
sd_model: str = Input(
@ -559,7 +560,7 @@ class Predictor(BasePredictor):
default=768,
),
num_steps: int = Input(
description="Number of sample steps", ge=25, le=50, default=50
description="Number of sample steps", ge=20, le=50, default=25
),
guidance_scale: float = Input(
description="Scale for classifier-free guidance", ge=0.1, le=10, default=5
@ -569,20 +570,19 @@ class Predictor(BasePredictor):
),
sa32_setting: float = Input(
description="The degree of Paired Attention at 32 x 32 self-attention layers",
default=0.7,
default=0.5,
ge=0,
le=1.0,
),
sa64_setting: float = Input(
description="The degree of Paired Attention at 64 x 64 self-attention layers",
default=0.7,
default=0.5,
ge=0,
le=1.0,
),
num_ids: int = Input(
description="Number of id images in total images. This should not exceed total number of line-separated prompts",
default=3,
choices=[2, 3, 4],
),
output_format: str = Input(
description="Format of the output images",
@ -776,4 +776,6 @@ class Predictor(BasePredictor):
sample.save(output_filename, **save_params)
output_paths.append(Path(output_filename))
del pipe
return ModelOutput(comic=Path(comic_out), individual_images=output_paths)