Wuerstchen fixes (#4942)
* fix arguments and make example code work * change arguments in combined test * Add default timesteps * style * fixed test * fix broken test * formatting * fix docstrings * fix num_images_per_prompt * fix doc styles * please dont change this * fix tests * rename to DEFAULT_STAGE_C_TIMESTEPS --------- Co-authored-by: Dominic Rampas <d6582533@gmail.com>
This commit is contained in:
parent
6c6a246461
commit
16a056a7b5
@ -17,6 +17,7 @@ After the initial paper release, we have improved numerous things in the archite
|
||||
- Multi Aspect Resolution Sampling
|
||||
- Better quality
|
||||
|
||||
|
||||
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
|
||||
|
||||
- v2-base
|
||||
@ -24,7 +25,7 @@ We are releasing 3 checkpoints for the text-conditional image generation model (
|
||||
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
|
||||
|
||||
We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations.
|
||||
A comparison can be seen here:
|
||||
A comparison can be seen here:
|
||||
|
||||
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
|
||||
|
||||
@ -35,27 +36,18 @@ For the sake of usability Würstchen can be used with a single pipeline. This pi
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
num_images_per_prompt = 2
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"warp-diffusion/wuerstchen", torch_dtype=dtype
|
||||
).to(device)
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
caption = "Anthropomorphic cat dressed as a fire fighter"
|
||||
negative_prompt = ""
|
||||
|
||||
output = pipeline(
|
||||
prompt=caption,
|
||||
height=1024,
|
||||
images = pipe(
|
||||
caption,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
height=1536,
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
prior_guidance_scale=4.0,
|
||||
decoder_guidance_scale=0.0,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
output_type="pil",
|
||||
num_images_per_prompt=2,
|
||||
).images
|
||||
```
|
||||
|
||||
@ -64,27 +56,29 @@ For explanation purposes, we can also initialize the two main pipelines of Würs
|
||||
```python
|
||||
import torch
|
||||
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
num_images_per_prompt = 2
|
||||
|
||||
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
|
||||
"warp-diffusion/wuerstchen-prior", torch_dtype=dtype
|
||||
"warp-ai/wuerstchen-prior", torch_dtype=dtype
|
||||
).to(device)
|
||||
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
|
||||
"warp-diffusion/wuerstchen", torch_dtype=dtype
|
||||
"warp-ai/wuerstchen", torch_dtype=dtype
|
||||
).to(device)
|
||||
|
||||
caption = "A captivating artwork of a mysterious stone golem"
|
||||
caption = "Anthropomorphic cat dressed as a fire fighter"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_output = prior_pipeline(
|
||||
prompt=caption,
|
||||
height=1024,
|
||||
width=1024,
|
||||
width=1536,
|
||||
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
decoder_output = decoder_pipeline(
|
||||
@ -109,13 +103,12 @@ pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullg
|
||||
|
||||
- Due to the high compression employed by Würstchen, generations can lack a good amount
|
||||
of detail. To our human eye, this is especially noticeable in faces, hands etc.
|
||||
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
|
||||
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
|
||||
after 1024x1024 is 1152x1152
|
||||
- The model lacks the ability to render correct text in images
|
||||
- The model often does not achieve photorealism
|
||||
- Difficult compositional prompts are hard for the model
|
||||
|
||||
|
||||
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
|
||||
|
||||
## WuerschenPipeline
|
||||
|
||||
@ -91,12 +91,12 @@ prior_pipeline = WuerstchenPriorPipeline(
|
||||
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
|
||||
)
|
||||
|
||||
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")
|
||||
prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
|
||||
|
||||
decoder_pipeline = WuerstchenDecoderPipeline(
|
||||
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
|
||||
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
|
||||
|
||||
# Wuerstchen pipeline
|
||||
wuerstchen_pipeline = WuerstchenCombinedPipeline(
|
||||
@ -112,4 +112,4 @@ wuerstchen_pipeline = WuerstchenCombinedPipeline(
|
||||
prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
)
|
||||
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")
|
||||
wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
|
||||
|
||||
@ -24,7 +24,7 @@ else:
|
||||
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
|
||||
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
|
||||
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
|
||||
_import_structure["pipeline_wuerstchen_prior"] = ["WuerstchenPriorPipeline"]
|
||||
_import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"]
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
@ -35,11 +35,11 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
|
||||
|
||||
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
|
||||
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain(
|
||||
... "warp-diffusion/wuerstchen", torch_dtype=torch.float16
|
||||
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = pipe(prompt)
|
||||
|
||||
@ -31,9 +31,9 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> from diffusions import WuerstchenCombinedPipeline
|
||||
|
||||
>>> pipe = WuerstchenCombinedPipeline.from_pretrained(
|
||||
... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
>>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> images = pipe(prompt=prompt)
|
||||
```
|
||||
@ -145,16 +145,16 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
prior_num_inference_steps: int = 60,
|
||||
num_inference_steps: int = 12,
|
||||
prior_timesteps: Optional[List[float]] = None,
|
||||
timesteps: Optional[List[float]] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
num_inference_steps: int = 12,
|
||||
decoder_timesteps: Optional[List[float]] = None,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
@ -182,19 +182,20 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
|
||||
to the text `prompt`, usually at the expense of lower image quality.
|
||||
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`prior_timesteps`
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps`
|
||||
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
|
||||
the expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`timesteps`
|
||||
prior_timesteps (`List[float]`, *optional*):
|
||||
Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced
|
||||
`prior_num_inference_steps` timesteps are used. Must be in descending order.
|
||||
timesteps (`List[float]`, *optional*):
|
||||
decoder_timesteps (`List[float]`, *optional*):
|
||||
Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
|
||||
`decoder_num_inference_steps` timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
`num_inference_steps` timesteps are used. Must be in descending order.
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
@ -221,27 +222,28 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
"""
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
width=width,
|
||||
num_inference_steps=prior_num_inference_steps,
|
||||
timesteps=prior_timesteps,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
output_type="pt",
|
||||
return_dict=False,
|
||||
)
|
||||
image_embeddings = prior_outputs[0]
|
||||
|
||||
outputs = self.decoder_pipe(
|
||||
prompt=prompt,
|
||||
image_embeddings=image_embeddings,
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
timesteps=decoder_timesteps,
|
||||
guidance_scale=decoder_guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -35,6 +35,8 @@ from .modeling_wuerstchen_prior import WuerstchenPrior
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@ -42,7 +44,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import WuerstchenPriorPipeline
|
||||
|
||||
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
|
||||
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
|
||||
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
@ -265,7 +267,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_inference_steps: int = 30,
|
||||
num_inference_steps: int = 60,
|
||||
timesteps: List[float] = None,
|
||||
guidance_scale: float = 8.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@ -274,6 +276,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pt",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@ -314,6 +318,12 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -365,7 +375,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
|
||||
|
||||
# 6. Run denoising loop
|
||||
for t in self.progress_bar(timesteps[:-1]):
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
ratio = t.expand(latents.size(0)).to(dtype)
|
||||
|
||||
# 7. Denoise image embeddings
|
||||
@ -390,6 +400,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Denormalize the latents
|
||||
latents = latents * self.config.latent_mean - self.config.latent_std
|
||||
|
||||
|
||||
@ -38,7 +38,8 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"prior_guidance_scale",
|
||||
"decoder_guidance_scale",
|
||||
"negative_prompt",
|
||||
"num_inference_steps",
|
||||
"return_dict",
|
||||
@ -160,7 +161,7 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"prior_guidance_scale": 4.0,
|
||||
"guidance_scale": 4.0,
|
||||
"decoder_guidance_scale": 4.0,
|
||||
"num_inference_steps": 2,
|
||||
"prior_num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user