Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1410a1bcdc | |||
| a9109dbb2b | |||
| 6874d2b57f | |||
| d8012a4825 | |||
| 0e9416d6a3 | |||
| 03dfb7f0b4 | |||
| fe0a0ebe88 | |||
| eeeb28a9ad | |||
| c05356497a | |||
| 1d4ad34af0 | |||
| 20ce68f945 | |||
| 110ffe2589 | |||
| 0b7225e918 | |||
| db7b7bd983 | |||
| 6a0a312370 | |||
| c28d3c82ce | |||
| bcb6cc16df | |||
| 4d1e4e24e5 | |||
| a808a85390 | |||
| 4c54519e1a | |||
| 25f11424f6 | |||
| 89300131d2 | |||
| 6c56f05097 | |||
| 77fc197f70 | |||
| edf22c052e | |||
| 5755d16868 |
@@ -76,25 +76,34 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## StableDiffusionImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusionImg2ImgPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## StableDiffusionInpaintPipeline
|
||||
[[autodoc]] StableDiffusionInpaintPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## StableDiffusionImageVariationPipeline
|
||||
[[autodoc]] StableDiffusionImageVariationPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
|
||||
## StableDiffusionUpscalePipeline
|
||||
@@ -102,3 +111,5 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
@@ -76,6 +76,33 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im
|
||||
|
||||
[[autodoc]] DPMSolverMultistepScheduler
|
||||
|
||||
#### Heun scheduler inspired by Karras et al. paper
|
||||
|
||||
Algorithm 1 of [Karras et al.](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] HeunDiscreteScheduler
|
||||
|
||||
#### DPM Discrete Scheduler inspired by Karras et al. paper
|
||||
|
||||
Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] KDPM2DiscreteScheduler
|
||||
|
||||
#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et al. paper
|
||||
|
||||
Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] KDPM2AncestralDiscreteScheduler
|
||||
|
||||
#### Variance exploding, stochastic sampling from Karras et. al
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
|
||||
@@ -86,7 +113,6 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).
|
||||
|
||||
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
|
||||
[[autodoc]] LMSDiscreteScheduler
|
||||
|
||||
#### Pseudo numerical methods for diffusion models (PNDM)
|
||||
|
||||
@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]
|
||||
|
||||
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
|
||||
|
||||
|
||||
## Sliced VAE decode for larger batches
|
||||
|
||||
To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
|
||||
|
||||
You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
|
||||
|
||||
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
|
||||
|
||||
```Python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
pipe.enable_vae_slicing()
|
||||
images = pipe([prompt] * 32).images
|
||||
```
|
||||
|
||||
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
|
||||
|
||||
|
||||
## Offloading to CPU with accelerate for memory savings
|
||||
|
||||
For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
|
||||
|
||||
@@ -378,21 +378,3 @@ dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler"
|
||||
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
[[autodoc]] modeling_utils.ModelMixin
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
[[autodoc]] pipeline_utils.DiffusionPipeline
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
[[autodoc]] modeling_flax_utils.FlaxModelMixin
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
[[autodoc]] pipeline_flax_utils.FlaxDiffusionPipeline
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
@@ -602,7 +602,7 @@ For example, this could be used to place a logo on a shirt and make it blend sea
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
image_path = "./path-to-image.png"
|
||||
inner_image_path = "./path-to-inner-image.png"
|
||||
@@ -612,10 +612,11 @@ init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
|
||||
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
|
||||
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
custom_pipeline="img2img_inpainting",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -623,6 +624,8 @@ prompt = "Your prompt here!"
|
||||
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
### Text Based Inpainting Stable Diffusion
|
||||
|
||||
Use a text prompt to generate the mask for the area to be inpainted.
|
||||
|
||||
@@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
|
||||
|
||||
And launch the training using
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
@@ -193,6 +195,17 @@ accelerate launch train_dreambooth.py \
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Using DreamBooth for pipelines other than Stable Diffusion
|
||||
|
||||
AltDiffusion also supports DreamBooth now. The running command is basically the same as above; all you need to do is replace the `MODEL_NAME` like this:
|
||||
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
|
||||
|
||||
```
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
|
||||
or
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
|
||||
|
||||
@@ -14,18 +14,38 @@ from torch.utils.data import Dataset
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
    """Return the text-encoder class declared by a checkpoint's config.

    Loads the configuration stored in the ``text_encoder`` subfolder of
    ``pretrained_model_name_or_path`` and maps the first entry of its
    ``architectures`` list to the matching model class.

    Note: the revision is read from the module-level ``args.revision``, so
    ``args`` must already be defined when this function is called.

    Args:
        pretrained_model_name_or_path: Value forwarded to
            ``PretrainedConfig.from_pretrained``.

    Returns:
        ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation`` (the
        class itself, not an instance).

    Raises:
        ValueError: If the declared architecture is neither of the two above.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    architecture = encoder_config.architectures[0]

    # The imports stay inside the branches so that only the module backing the
    # selected architecture is imported.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    else:
        raise ValueError(f"{architecture} is not supported.")

    return encoder_cls
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -124,6 +144,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
@@ -356,7 +377,7 @@ def main(args):
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
@@ -406,19 +427,24 @@ def main(args):
|
||||
|
||||
# Load the tokenizer
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name,
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
@@ -603,23 +629,31 @@ def main(args):
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
|
||||
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
|
||||
noise, noise_prior = torch.chunk(noise, 2, dim=0)
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
@@ -638,6 +672,17 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
revision=args.revision,
|
||||
)
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
pipeline.save_pretrained(save_path)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
@@ -649,7 +694,7 @@ def main(args):
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
|
||||
@@ -42,6 +42,8 @@ If you have already cloned the repo, then you won't need to go through these ste
|
||||
#### Hardware
|
||||
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
@@ -15,13 +15,12 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -36,6 +35,13 @@ def parse_args():
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
@@ -335,10 +341,24 @@ def main():
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
# Freeze vae and text_encoder
|
||||
vae.requires_grad_(False)
|
||||
@@ -562,9 +582,17 @@ def main():
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
@@ -600,14 +628,12 @@ def main():
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
revision=args.revision,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c
|
||||
|
||||
And launch the training using
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
|
||||
@@ -16,9 +16,8 @@ import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
@@ -26,7 +25,7 @@ from packaging import version
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
@@ -51,11 +50,11 @@ else:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -73,6 +72,13 @@ def parse_args():
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@@ -405,9 +411,21 @@ def main():
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
@@ -532,9 +550,17 @@ def main():
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
@@ -556,7 +582,8 @@ def main():
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % args.save_steps == 0:
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args)
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
@@ -569,18 +596,18 @@ def main():
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
revision=args.revision,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
# Also save the newly trained embeddings
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args)
|
||||
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
@@ -320,7 +320,12 @@ def main(args):
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
|
||||
ema_model = EMAModel(
|
||||
accelerator.unwrap_model(model),
|
||||
inv_gamma=args.ema_inv_gamma,
|
||||
power=args.ema_power,
|
||||
max_value=args.ema_max_decay,
|
||||
)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
|
||||
@@ -666,17 +666,29 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.original_config_file is None:
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
prediction_type = "epsilon"
|
||||
if args.original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
args.original_config_file = "./v2-inference-v.yaml"
|
||||
prediction_type
|
||||
else:
|
||||
# model_type = "v1"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@@ -97,6 +97,7 @@ _deps = [
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"safetensors",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
"scipy",
|
||||
"regex!=2019.12.17",
|
||||
@@ -184,10 +185,11 @@ extras["test"] = deps_list(
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"safetensors",
|
||||
"sentencepiece",
|
||||
"scipy",
|
||||
"torchvision",
|
||||
"transformers"
|
||||
"transformers",
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
|
||||
@@ -46,8 +46,11 @@ if is_torch_available():
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
|
||||
@@ -24,6 +24,8 @@ import re
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
@@ -502,6 +504,12 @@ class ConfigMixin:
|
||||
config_dict["_class_name"] = self.__class__.__name__
|
||||
config_dict["_diffusers_version"] = __version__
|
||||
|
||||
def to_json_saveable(value):
|
||||
if isinstance(value, np.ndarray):
|
||||
value = value.tolist()
|
||||
return value
|
||||
|
||||
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
|
||||
@@ -21,6 +21,7 @@ deps = {
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"safetensors": "safetensors",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
"scipy": "scipy",
|
||||
"regex": "regex!=2019.12.17",
|
||||
|
||||
@@ -332,7 +332,7 @@ class FlaxModelMixin:
|
||||
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
|
||||
" using `from_pt=True`."
|
||||
" using `from_pt=True`."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
|
||||
+135
-63
@@ -30,8 +30,10 @@ from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -51,6 +53,9 @@ if is_accelerate_available():
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
@@ -84,10 +89,13 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
"""
|
||||
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
||||
f"at '{checkpoint_file}'. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
||||
)
|
||||
@@ -375,75 +383,39 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif subfolder is not None and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
):
|
||||
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
|
||||
model_file = None
|
||||
if is_safetensors_available():
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
weights_name=SAFETENSORS_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
# restore default dtype
|
||||
except:
|
||||
pass
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
@@ -500,6 +472,21 @@ class ModelMixin(torch.nn.Module):
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file)
|
||||
dtype = set(v.dtype for v in state_dict.values())
|
||||
|
||||
if len(dtype) > 1 and torch.float32 not in dtype:
|
||||
raise ValueError(
|
||||
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
|
||||
f" make sure that {model_file} weights have only one dtype."
|
||||
)
|
||||
elif len(dtype) > 1 and torch.float32 in dtype:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = dtype.pop()
|
||||
|
||||
# move model to correct dtype
|
||||
model = model.to(dtype)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
@@ -691,3 +678,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
*,
|
||||
weights_name,
|
||||
subfolder,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
resume_download,
|
||||
local_files_only,
|
||||
use_auth_token,
|
||||
user_agent,
|
||||
revision,
|
||||
):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
|
||||
# Load from a PyTorch checkpoint
|
||||
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
|
||||
return model_file
|
||||
elif subfolder is not None and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
||||
):
|
||||
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
||||
return model_file
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=weights_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
return model_file
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {weights_name} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {weights_name}"
|
||||
)
|
||||
|
||||
@@ -104,6 +104,8 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
Hidden states dimension inside each head
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
only_cross_attention (`bool`, defaults to `False`):
|
||||
Whether to only apply cross attention.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
@@ -111,10 +113,11 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
n_heads: int
|
||||
d_head: int
|
||||
dropout: float = 0.0
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
# self attention
|
||||
# self attention (or cross_attention if only_cross_attention is True)
|
||||
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
# cross attention
|
||||
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
@@ -126,7 +129,10 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
# self attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
|
||||
if self.only_cross_attention:
|
||||
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
|
||||
else:
|
||||
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# cross attention
|
||||
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
Number of transformers block
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
use_linear_projection (`bool`, defaults to `False`): tbd
|
||||
only_cross_attention (`bool`, defaults to `False`): tbd
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
d_head: int
|
||||
depth: int = 1
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
||||
|
||||
inner_dim = self.n_heads * self.d_head
|
||||
self.proj_in = nn.Conv(
|
||||
inner_dim,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if self.use_linear_projection:
|
||||
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
|
||||
else:
|
||||
self.proj_in = nn.Conv(
|
||||
inner_dim,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.transformer_blocks = [
|
||||
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
|
||||
FlaxBasicTransformerBlock(
|
||||
inner_dim,
|
||||
self.n_heads,
|
||||
self.d_head,
|
||||
dropout=self.dropout,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
|
||||
self.proj_out = nn.Conv(
|
||||
inner_dim,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if self.use_linear_projection:
|
||||
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
|
||||
else:
|
||||
self.proj_out = nn.Conv(
|
||||
inner_dim,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
batch, height, width, channels = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
||||
if self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
||||
|
||||
for transformer_block in self.transformer_blocks:
|
||||
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
||||
if self.use_linear_projection:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
||||
else:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
num_layers: int = 1
|
||||
attn_num_head_channels: int = 1
|
||||
add_downsample: bool = True
|
||||
use_linear_projection: bool = False
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -68,6 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
@@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
num_layers: int = 1
|
||||
attn_num_head_channels: int = 1
|
||||
add_upsample: bool = True
|
||||
use_linear_projection: bool = False
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -201,6 +207,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
@@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
attn_num_head_channels: int = 1
|
||||
use_linear_projection: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -331,6 +340,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.in_channels // self.attn_num_head_channels,
|
||||
depth=1,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
|
||||
@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8):
|
||||
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the cross attention features.
|
||||
@@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
"DownBlock2D",
|
||||
)
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
attention_head_dim: int = 8
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8
|
||||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
freq_shift: int = 0
|
||||
|
||||
@@ -134,6 +136,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
|
||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
|
||||
only_cross_attention = self.only_cross_attention
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
||||
|
||||
attention_head_dim = self.attention_head_dim
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
|
||||
|
||||
# down
|
||||
down_blocks = []
|
||||
output_channel = block_out_channels[0]
|
||||
@@ -148,8 +158,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
out_channels=output_channel,
|
||||
dropout=self.dropout,
|
||||
num_layers=self.layers_per_block,
|
||||
attn_num_head_channels=self.attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
@@ -169,13 +181,16 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
attn_num_head_channels=self.attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# up
|
||||
up_blocks = []
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(self.up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
@@ -190,9 +205,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
attn_num_head_channels=self.attention_head_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
add_upsample=not is_final_block,
|
||||
dropout=self.dropout,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -565,6 +565,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.use_slicing = False
|
||||
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
h = self.encoder(x)
|
||||
@@ -576,7 +577,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
@@ -585,6 +586,34 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def enable_slicing(self):
    r"""
    Turn on sliced VAE decoding.

    Sets the `use_slicing` flag. While the flag is set, `decode` handles a batch with more than one element by
    decoding it one slice at a time and concatenating the results, which lowers peak memory use and makes larger
    batch sizes possible.
    """
    slicing_enabled = True
    self.use_slicing = slicing_enabled
|
||||
|
||||
def disable_slicing(self):
    r"""
    Turn off sliced VAE decoding.

    Clears the `use_slicing` flag, so that after an earlier call to `enable_slicing` the whole batch is decoded
    in a single step again.
    """
    slicing_enabled = False
    self.use_slicing = slicing_enabled
|
||||
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    r"""
    Decode latents `z`, optionally one batch element at a time.

    When `use_slicing` is set and `z` has more than one element along its first dimension, every size-1 slice
    is passed through `_decode` separately and the resulting samples are concatenated; otherwise `z` goes
    through `_decode` in one call.

    Returns a `DecoderOutput` holding the decoded sample, or a 1-tuple containing it when `return_dict` is
    `False`.
    """
    decode_per_slice = self.use_slicing and z.shape[0] > 1

    if not decode_per_slice:
        sample = self._decode(z).sample
    else:
        # Decoding one slice at a time keeps only a single element's activations alive at once.
        per_slice_samples = []
        for latent_slice in z.split(1):
            per_slice_samples.append(self._decode(latent_slice).sample)
        sample = torch.cat(per_slice_samples)

    if return_dict:
        return DecoderOutput(sample=sample)
    return (sample,)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download PyTorch weights
|
||||
ignore_patterns = "*.bin"
|
||||
# make sure we don't download PyTorch weights, unless when using from_pt
|
||||
ignore_patterns = "*.bin" if not from_pt else []
|
||||
|
||||
if cls != FlaxDiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import model_info, snapshot_download
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
@@ -44,6 +44,7 @@ from .utils import (
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
|
||||
audios: np.ndarray
|
||||
|
||||
|
||||
def is_safetensors_compatible(info) -> bool:
    """
    Check whether every PyTorch `.bin` weight file of a repository has a `.safetensors` counterpart.

    Args:
        info: object exposing `siblings`, an iterable of entries that each carry an `rfilename` attribute (at the
            call site this is the value returned by `model_info`).

    Returns:
        `bool`: `True` when the repository lists at least one `.safetensors` file and no `.bin` file lacks its
        counterpart, `False` otherwise. A warning naming the first missing counterpart is logged.
    """
    # Hub repo filenames are "/"-separated relative paths, so they are split/joined with `posixpath`.
    # `os.path.join` would insert "\\" on Windows and the rebuilt name would never match an entry of `filenames`.
    import posixpath

    filenames = {sibling.rfilename for sibling in info.siblings}

    # Without any safetensors file there is nothing to be compatible with.
    if not any(filename.endswith(".safetensors") for filename in filenames):
        return False

    for pt_filename in (filename for filename in filenames if filename.endswith(".bin")):
        prefix, raw = posixpath.split(pt_filename)
        if raw == "pytorch_model.bin":
            # transformers specific
            sf_filename = posixpath.join(prefix, "model.safetensors")
        else:
            sf_filename = pt_filename[: -len(".bin")] + ".safetensors"

        if sf_filename not in filenames:
            logger.warning(f"{sf_filename} not found")
            return False

    return True
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
@@ -459,7 +477,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = "*.msgpack"
|
||||
ignore_patterns = ["*.msgpack"]
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
@@ -473,6 +491,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
if is_safetensors_available():
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
if is_safetensors_compatible(info):
|
||||
ignore_patterns.append("*.bin")
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -740,7 +767,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
return pil_images
|
||||
|
||||
def progress_bar(self, iterable):
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
if not hasattr(self, "_progress_bar_config"):
|
||||
self._progress_bar_config = {}
|
||||
elif not isinstance(self._progress_bar_config, dict):
|
||||
@@ -748,7 +775,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
||||
)
|
||||
|
||||
return tqdm(iterable, **self._progress_bar_config)
|
||||
if iterable is not None:
|
||||
return tqdm(iterable, **self._progress_bar_config)
|
||||
elif total is not None:
|
||||
return tqdm(total=total, **self._progress_bar_config)
|
||||
else:
|
||||
raise ValueError("Either `total` or `iterable` has to be defined.")
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self._progress_bar_config = kwargs
|
||||
|
||||
@@ -216,6 +216,22 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding for this pipeline.

    Delegates to `self.vae.enable_slicing()`. With slicing enabled the VAE splits the input tensor into slices
    and decodes them in several steps, which saves memory and allows larger batch sizes.
    """
    vae = self.vae
    vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding for this pipeline.

    Delegates to `self.vae.disable_slicing()`. If `enable_vae_slicing` was called earlier, decoding goes back
    to being computed in one step.
    """
    vae = self.vae
    vae.disable_slicing()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
@@ -445,7 +461,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -542,25 +557,29 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -433,7 +433,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
@@ -484,7 +484,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -563,7 +562,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@@ -575,25 +574,29 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -63,15 +63,14 @@ if is_transformers_available() and is_flax_available():
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
images (`np.ndarray`)
|
||||
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
images: np.ndarray
|
||||
nsfw_content_detected: List[bool]
|
||||
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
|
||||
@@ -475,7 +475,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
@@ -528,7 +528,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -608,7 +607,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@@ -622,66 +621,70 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
generator = extra_step_kwargs.pop("generator", None)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
source_latent_model_input = torch.cat([source_latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
source_latent_model_input = torch.cat([source_latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
concat_latent_model_input = torch.stack(
|
||||
[
|
||||
source_latent_model_input[0],
|
||||
latent_model_input[0],
|
||||
source_latent_model_input[1],
|
||||
latent_model_input[1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_text_embeddings = torch.stack(
|
||||
[
|
||||
source_text_embeddings[0],
|
||||
text_embeddings[0],
|
||||
source_text_embeddings[1],
|
||||
text_embeddings[1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_noise_pred = self.unet(
|
||||
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
|
||||
).sample
|
||||
# predict the noise residual
|
||||
concat_latent_model_input = torch.stack(
|
||||
[
|
||||
source_latent_model_input[0],
|
||||
latent_model_input[0],
|
||||
source_latent_model_input[1],
|
||||
latent_model_input[1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_text_embeddings = torch.stack(
|
||||
[
|
||||
source_text_embeddings[0],
|
||||
text_embeddings[0],
|
||||
source_text_embeddings[1],
|
||||
text_embeddings[1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_noise_pred = self.unet(
|
||||
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
(
|
||||
source_noise_pred_uncond,
|
||||
noise_pred_uncond,
|
||||
source_noise_pred_text,
|
||||
noise_pred_text,
|
||||
) = concat_noise_pred.chunk(4, dim=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
|
||||
source_noise_pred_text - source_noise_pred_uncond
|
||||
)
|
||||
# perform guidance
|
||||
(
|
||||
source_noise_pred_uncond,
|
||||
noise_pred_uncond,
|
||||
source_noise_pred_text,
|
||||
noise_pred_text,
|
||||
) = concat_noise_pred.chunk(4, dim=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
|
||||
source_noise_pred_text - source_noise_pred_uncond
|
||||
)
|
||||
|
||||
# Sample source_latents from the posterior distribution.
|
||||
prev_source_latents = posterior_sample(
|
||||
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
|
||||
)
|
||||
# Compute noise.
|
||||
noise = compute_noise(
|
||||
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
|
||||
)
|
||||
source_latents = prev_source_latents
|
||||
# Sample source_latents from the posterior distribution.
|
||||
prev_source_latents = posterior_sample(
|
||||
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
|
||||
)
|
||||
# Compute noise.
|
||||
noise = compute_noise(
|
||||
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
|
||||
)
|
||||
source_latents = prev_source_latents
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
|
||||
).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -289,7 +289,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
jit: bool = False,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -317,9 +316,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
@@ -383,6 +379,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
|
||||
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||
else:
|
||||
images = np.asarray(images)
|
||||
has_nsfw_concept = False
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -205,7 +205,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -241,7 +241,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
@@ -259,7 +259,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
+3
-25
@@ -5,7 +5,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -68,6 +67,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
@@ -134,27 +135,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
@@ -165,7 +145,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
@@ -262,7 +241,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -372,7 +350,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
# preprocess mask
|
||||
if not isinstance(mask_image, np.ndarray):
|
||||
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
||||
mask_image = preprocess_mask(mask_image, 8)
|
||||
mask_image = mask_image.astype(latents_dtype)
|
||||
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
|
||||
|
||||
|
||||
@@ -215,6 +215,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
@@ -444,7 +460,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -541,25 +556,29 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
+19
-16
@@ -342,7 +342,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -442,25 +441,29 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -442,7 +442,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
@@ -493,7 +493,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -572,7 +571,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@@ -584,25 +583,29 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -566,7 +566,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -656,7 +655,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
@@ -700,28 +699,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 10. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 11. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
+24
-21
@@ -457,7 +457,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
|
||||
init_image = init_image.to(device=self.device, dtype=dtype)
|
||||
@@ -492,7 +492,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -578,7 +577,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@@ -595,29 +594,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -469,7 +469,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
@@ -511,30 +511,34 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
|
||||
).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
|
||||
@@ -546,7 +546,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
sld_threshold: Optional[float] = 0.01,
|
||||
sld_momentum_scale: Optional[float] = 0.3,
|
||||
sld_mom_beta: Optional[float] = 0.4,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -669,63 +668,71 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
|
||||
safety_momentum = None
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * (3 if enable_safety_guidance else 2))
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
|
||||
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
|
||||
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
|
||||
|
||||
# default classifier free guidance
|
||||
noise_guidance = noise_pred_text - noise_pred_uncond
|
||||
# default classifier free guidance
|
||||
noise_guidance = noise_pred_text - noise_pred_uncond
|
||||
|
||||
# Perform SLD guidance
|
||||
if enable_safety_guidance:
|
||||
if safety_momentum is None:
|
||||
safety_momentum = torch.zeros_like(noise_guidance)
|
||||
noise_pred_safety_concept = noise_pred_out[2]
|
||||
# Perform SLD guidance
|
||||
if enable_safety_guidance:
|
||||
if safety_momentum is None:
|
||||
safety_momentum = torch.zeros_like(noise_guidance)
|
||||
noise_pred_safety_concept = noise_pred_out[2]
|
||||
|
||||
# Equation 6
|
||||
scale = torch.clamp(
|
||||
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
|
||||
)
|
||||
# Equation 6
|
||||
scale = torch.clamp(
|
||||
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
|
||||
)
|
||||
|
||||
# Equation 6
|
||||
safety_concept_scale = torch.where(
|
||||
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
|
||||
)
|
||||
# Equation 6
|
||||
safety_concept_scale = torch.where(
|
||||
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
|
||||
torch.zeros_like(scale),
|
||||
scale,
|
||||
)
|
||||
|
||||
# Equation 4
|
||||
noise_guidance_safety = torch.mul(
|
||||
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
|
||||
)
|
||||
# Equation 4
|
||||
noise_guidance_safety = torch.mul(
|
||||
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
|
||||
)
|
||||
|
||||
# Equation 7
|
||||
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
|
||||
# Equation 7
|
||||
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
|
||||
|
||||
# Equation 8
|
||||
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
|
||||
# Equation 8
|
||||
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
|
||||
|
||||
if i >= sld_warmup_steps: # Warmup
|
||||
# Equation 3
|
||||
noise_guidance = noise_guidance - noise_guidance_safety
|
||||
if i >= sld_warmup_steps: # Warmup
|
||||
# Equation 3
|
||||
noise_guidance = noise_guidance - noise_guidance_safety
|
||||
|
||||
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
|
||||
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -22,7 +22,10 @@ if is_torch_available():
|
||||
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
||||
from .scheduling_euler_discrete import EulerDiscreteScheduler
|
||||
from .scheduling_heun_discrete import HeunDiscreteScheduler
|
||||
from .scheduling_ipndm import IPNDMScheduler
|
||||
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
||||
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
|
||||
from .scheduling_karras_ve import KarrasVeScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -114,6 +114,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -122,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
@@ -138,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
@@ -355,5 +356,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    """
    Return `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample` for the given timesteps.

    `alpha_prod` is `self.alphas_cumprod[timesteps]`. Both coefficients are flattened and padded with trailing
    singleton axes until they have as many axes as `sample`, so they broadcast over it.

    Args:
        sample (`torch.FloatTensor`): tensor multiplied by the `sqrt(1 - alpha_prod)` coefficient.
        noise (`torch.FloatTensor`): tensor multiplied by the `sqrt(alpha_prod)` coefficient.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.FloatTensor`: the velocity tensor.
    """
    # Side effect kept from the original: the stored `alphas_cumprod` is moved to the device and dtype of
    # `sample`, and the converted tensor replaces the attribute.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    alpha_prod = self.alphas_cumprod[timesteps.to(sample.device)]

    # One leading axis for the selected timesteps followed by singleton axes for the rest of `sample`.
    trailing_axes = (1,) * (len(sample.shape) - 1)

    noise_coeff = (alpha_prod**0.5).flatten().reshape(-1, *trailing_axes)
    sample_coeff = ((1 - alpha_prod) ** 0.5).flatten().reshape(-1, *trailing_axes)

    return noise_coeff * noise - sample_coeff * sample
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -106,6 +106,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -114,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
@@ -129,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
@@ -345,5 +346,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    """
    Combine `noise` and `sample` into a velocity tensor using the cumulative alphas at `timesteps`.

    The result is `sqrt(a) * noise - sqrt(1 - a) * sample`, where `a = self.alphas_cumprod[timesteps]` is
    flattened and given trailing singleton axes so that it has as many axes as `sample`.

    Args:
        sample (`torch.FloatTensor`): tensor scaled by `sqrt(1 - a)`.
        noise (`torch.FloatTensor`): tensor scaled by `sqrt(a)`.
        timesteps (`torch.IntTensor`): indices used to look up `self.alphas_cumprod`.

    Returns:
        `torch.FloatTensor`: `sqrt(a) * noise - sqrt(1 - a) * sample`.
    """

    def _match_rank(coeff):
        # Flatten, then append `None` axes until the coefficient has as many axes as `sample`.
        flat = coeff.flatten()
        missing = max(len(sample.shape) - len(flat.shape), 0)
        return flat[(...,) + (None,) * missing]

    # Make sure alphas_cumprod and timesteps live on the same device (and dtype) as sample; the converted
    # alphas_cumprod is written back to the scheduler.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)
    selected = self.alphas_cumprod[timesteps]

    scale_noise = _match_rank(selected**0.5)
    scale_sample = _match_rank((1 - selected) ** 0.5)

    return scale_noise * noise - scale_sample * sample
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -118,6 +118,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -126,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
@@ -146,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
@@ -246,6 +247,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
@@ -254,6 +258,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -68,6 +68,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -76,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -69,6 +69,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -77,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
|
||||
k-diffusion implementation by Katherine Crowson:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 2
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
    """
    Build the beta schedule and the derived alpha tables, then initialise the timestep state.

    Args:
        num_train_timesteps (`int`): number of betas generated for the `linear` / `scaled_linear` schedules.
        beta_start (`float`): first value of the generated beta schedule.
        beta_end (`float`): last value of the generated beta schedule.
        beta_schedule (`str`): `linear` or `scaled_linear`; only consulted when `trained_betas` is `None`.
        trained_betas (`np.ndarray` or `List[float]`, optional): explicit betas; when given they are used
            directly (converted to a float32 tensor) instead of a generated schedule.

    Raises:
        NotImplementedError: if `trained_betas` is `None` and `beta_schedule` is not a supported name.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly in sqrt(beta) space, then square.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values: one inference step per training step, no device move
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
def index_for_timestep(self, timestep):
    """
    Return the position in `self.timesteps` of an entry equal to `timestep`.

    When several entries match, the last match is returned while `self.state_in_first_order` is true and
    the first match otherwise.
    """
    matches = (self.timesteps == timestep).nonzero()
    which = -1 if self.state_in_first_order else 0
    return matches[which].item()
|
||||
|
||||
def scale_model_input(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    The sample is divided by `sqrt(sigma**2 + 1)`, where `sigma` is the entry of `self.sigmas` at the index
    that `self.index_for_timestep(timestep)` returns.

    Args:
        sample (`torch.FloatTensor`): input sample
        timestep (`float` or `torch.FloatTensor`): current timestep

    Returns:
        `torch.FloatTensor`: scaled input sample
    """
    sigma = self.sigmas[self.index_for_timestep(timestep)]
    return sample / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, optional):
            number of training timesteps the schedule spans. Falls back to
            `self.config.num_train_timesteps` when not given.
    """
    self.num_inference_steps = num_inference_steps
    train_steps = num_train_timesteps or self.config.num_train_timesteps

    # evenly spaced (possibly fractional) training timesteps, in descending order
    schedule = np.linspace(0, train_steps - 1, num_inference_steps, dtype=float)[::-1].copy()

    # sigma for every training timestep, interpolated at the schedule positions, plus a terminal 0
    all_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    picked = np.interp(schedule, np.arange(0, len(all_sigmas)), all_sigmas)
    picked = np.concatenate([picked, [0.0]]).astype(np.float32)
    sigma_t = torch.from_numpy(picked).to(device=device)

    # every interior sigma appears twice, the first and the terminal one only once
    self.sigmas = torch.cat([sigma_t[:1], sigma_t[1:-1].repeat_interleave(2), sigma_t[-1:]])

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = self.sigmas.max()

    # every timestep after the first appears twice
    steps_t = torch.from_numpy(schedule)
    steps_t = torch.cat([steps_t[:1], steps_t[1:].repeat_interleave(2)])

    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = steps_t.to(device, dtype=torch.float32)
    else:
        self.timesteps = steps_t.to(device=device)

    # empty dt and derivative
    self.prev_derivative = None
    self.dt = None
|
||||
|
||||
@property
def state_in_first_order(self):
    """`True` while no step size is stored in `self.dt`."""
    stored_dt = self.dt
    return stored_dt is None
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
Args:
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
|
||||
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# 2nd order / Heun's method
|
||||
sigma = self.sigmas[step_index - 1]
|
||||
sigma_next = self.sigmas[step_index]
|
||||
|
||||
# currently only gamma=0 is supported. This usually works best anyways.
|
||||
# We can support gamma in the future but then need to scale the timestep before
|
||||
# passing it to the model which requires a change in API
|
||||
gamma = 0
|
||||
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
|
||||
if self.state_in_first_order:
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
# 3. 1st order derivative
|
||||
dt = sigma_next - sigma_hat
|
||||
|
||||
# store for 2nd order step
|
||||
self.prev_derivative = derivative
|
||||
self.dt = dt
|
||||
self.sample = sample
|
||||
else:
|
||||
# 2. 2nd order / Heun's method
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
derivative = (self.prev_derivative + derivative) / 2
|
||||
|
||||
# 3. Retrieve 1st order derivative
|
||||
dt = self.dt
|
||||
sample = self.sample
|
||||
|
||||
# free dt and derivative
|
||||
# Note, this puts the scheduler in "first order mode"
|
||||
self.prev_derivative = None
|
||||
self.dt = None
|
||||
self.sample = None
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, with one sigma looked up per entry of `timesteps`.

    As a side effect `self.sigmas` and `self.timesteps` are moved to the device of `original_samples`
    (`self.sigmas` is also cast to its dtype).
    """
    target = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    self.sigmas = self.sigmas.to(device=target, dtype=original_samples.dtype)

    if target.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        cast = {"dtype": torch.float32}
    else:
        cast = {}
    self.timesteps = self.timesteps.to(target, **cast)
    timesteps = timesteps.to(target, **cast)

    sigma = self.sigmas[[self.index_for_timestep(t) for t in timesteps]].flatten()

    # append trailing singleton axes until sigma has as many axes as the samples
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
|
||||
|
||||
def __len__(self):
    """Return `num_train_timesteps` as stored in the scheduler config."""
    cfg = self.config
    return cfg.num_train_timesteps
|
||||
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -37,8 +38,12 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, num_train_timesteps: int = 1000):
|
||||
def __init__(
|
||||
self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
|
||||
):
|
||||
# set `betas`, `alphas`, `timesteps`
|
||||
self.set_timesteps(num_train_timesteps)
|
||||
|
||||
@@ -65,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
|
||||
steps = torch.cat([steps, torch.tensor([0.0])])
|
||||
|
||||
self.betas = torch.sin(steps * math.pi / 2) ** 2
|
||||
if self.config.trained_betas is not None:
|
||||
self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
|
||||
else:
|
||||
self.betas = torch.sin(steps * math.pi / 2) ** 2
|
||||
|
||||
self.alphas = (1.0 - self.betas**2) ** 0.5
|
||||
|
||||
timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
|
||||
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
|
||||
|
||||
Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
|
||||
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 2
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
    """
    Build the beta schedule, derive `alphas` / `alphas_cumprod` and initialise the timestep tables.

    Args:
        num_train_timesteps (`int`): number of betas generated for the `linear` / `scaled_linear` schedules.
        beta_start (`float`): first value of the generated beta schedule.
        beta_end (`float`): last value of the generated beta schedule.
        beta_schedule (`str`): `"linear"` or `"scaled_linear"`; only consulted when `trained_betas` is `None`.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            explicit betas that bypass `beta_start`, `beta_end` and `beta_schedule`.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of the supported names.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
def index_for_timestep(self, timestep):
    """
    Return the position of `timestep` inside `self.timesteps`.

    When several entries match, the last match is returned while the scheduler is in first-order state
    and the first match otherwise.
    """
    hits = (self.timesteps == timestep).nonzero()
    chosen = -1 if self.state_in_first_order else 0
    return hits[chosen].item()
|
||||
|
||||
def scale_model_input(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`torch.FloatTensor`): input sample
        timestep (`float` or `torch.FloatTensor`): current timestep

    Returns:
        `torch.FloatTensor`: `sample` divided by `sqrt(sigma**2 + 1)`, with `sigma` looked up for `timestep`
    """
    position = self.index_for_timestep(timestep)
    current_sigma = self.sigmas[position]
    return sample / ((current_sigma**2 + 1) ** 0.5)
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, optional):
            number of training timesteps the schedule spans. Falls back to
            `self.config.num_train_timesteps` when not given.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # evenly spaced (possibly fractional) training timesteps, in descending order
    timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)

    # sigmas at the schedule positions, followed by a terminal 0
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)

    # compute up and down sigmas
    sigmas_next = sigmas.roll(-1)
    sigmas_next[-1] = 0.0
    sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
    # the terminal sigma is 0, so the expression above evaluates to 0/0 (NaN) for the last entry
    sigmas_up[-1] = 0.0
    sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
    sigmas_down[-1] = 0.0

    # every entry after the first appears twice; the last one is appended once more
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
    self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = self.sigmas.max()

    timesteps = torch.from_numpy(timesteps)
    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = timesteps.to(device, dtype=torch.float32)
    else:
        # move the timesteps next to the sigma tables, as the docstring promises
        self.timesteps = timesteps.to(device=device)

    # no stored sample: the next `step` call is a first-order step
    self.sample = None
|
||||
|
||||
@property
def state_in_first_order(self):
    """`True` while no sample is stored in `self.sample`."""
    stored = self.sample
    return stored is None
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
Args:
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
|
||||
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# 2nd order / KPDM2's method
|
||||
sigma = self.sigmas[step_index - 1]
|
||||
sigma_next = self.sigmas[step_index]
|
||||
sigma_up = self.sigmas_up[step_index - 1]
|
||||
sigma_down = self.sigmas_down[step_index - 1]
|
||||
|
||||
# currently only gamma=0 is supported. This usually works best anyways.
|
||||
# We can support gamma in the future but then need to scale the timestep before
|
||||
# passing it to the model which requires a change in API
|
||||
gamma = 0
|
||||
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
||||
|
||||
device = model_output.device
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
||||
device
|
||||
)
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
|
||||
if self.state_in_first_order:
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
# 3. 1st order derivative
|
||||
dt = sigma_next - sigma_hat
|
||||
|
||||
# store for 2nd order step
|
||||
self.sample = sample
|
||||
self.dt = dt
|
||||
prev_sample = sample + derivative * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
dt = sigma_down - sigma_hat
|
||||
|
||||
sample = self.sample
|
||||
self.sample = None
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
prev_sample = prev_sample + noise * sigma_up
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, with one sigma looked up per entry of `timesteps`.

    As a side effect `self.sigmas` and `self.timesteps` are moved to the device of `original_samples`
    (`self.sigmas` is also cast to its dtype).
    """
    sample_device = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    self.sigmas = self.sigmas.to(device=sample_device, dtype=original_samples.dtype)

    if sample_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        extra = {"dtype": torch.float32}
    else:
        extra = {}
    self.timesteps = self.timesteps.to(sample_device, **extra)
    timesteps = timesteps.to(sample_device, **extra)

    positions = [self.index_for_timestep(t) for t in timesteps]
    sigma = self.sigmas[positions].flatten()

    # append trailing singleton axes until sigma has as many axes as the samples
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
|
||||
|
||||
def __len__(self):
    """Return `num_train_timesteps` as stored in the scheduler config."""
    cfg = self.config
    return cfg.num_train_timesteps
|
||||
@@ -0,0 +1,283 @@
|
||||
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
|
||||
|
||||
Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
|
||||
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 2
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
    """
    Build the beta schedule, derive `alphas` / `alphas_cumprod` and initialise the timestep tables.

    Args:
        num_train_timesteps (`int`): number of betas generated for the `linear` / `scaled_linear` schedules.
        beta_start (`float`): first value of the generated beta schedule.
        beta_end (`float`): last value of the generated beta schedule.
        beta_schedule (`str`): `"linear"` or `"scaled_linear"`; only consulted when `trained_betas` is `None`.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            explicit betas that bypass `beta_start`, `beta_end` and `beta_schedule`.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of the supported names.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
def index_for_timestep(self, timestep):
    """
    Return the position of `timestep` inside `self.timesteps`.

    When several entries match, the last match is returned while the scheduler is in first-order state
    and the first match otherwise.
    """
    found = (self.timesteps == timestep).nonzero()
    pick = -1 if self.state_in_first_order else 0
    return found[pick].item()
|
||||
|
||||
def scale_model_input(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    The sigma is read from `self.sigmas` in first-order state and from `self.sigmas_interpol` otherwise.

    Args:
        sample (`torch.FloatTensor`): input sample
        timestep (`float` or `torch.FloatTensor`): current timestep

    Returns:
        `torch.FloatTensor`: `sample` divided by `sqrt(sigma**2 + 1)`
    """
    step_index = self.index_for_timestep(timestep)
    table = self.sigmas if self.state_in_first_order else self.sigmas_interpol
    sigma = table[step_index]
    return sample / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, optional):
            number of training timesteps the schedule spans. Falls back to
            `self.config.num_train_timesteps` when not given.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # evenly spaced (possibly fractional) training timesteps, in descending order
    timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)

    # sigmas at the schedule positions, followed by a terminal 0
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)

    # interpolate sigmas: log-space midpoint between each sigma and the one before it
    sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()

    # every entry after the first appears twice; the last one is appended once more
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_interpol = torch.cat(
        [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
    )

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = self.sigmas.max()

    timesteps = torch.from_numpy(timesteps)
    if str(device).startswith("mps"):
        # mps does not support float64, so cast while moving instead of moving a float64 tensor first
        timesteps = timesteps.to(device, dtype=torch.float32)
    else:
        timesteps = timesteps.to(device)

    # interpolate timesteps and interleave them with the regular ones
    timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
    interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
    timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

    if str(device).startswith("mps"):
        # keep the final table in float32 as well
        self.timesteps = timesteps.to(torch.float32)
    else:
        self.timesteps = timesteps

    # no stored sample: the next `step` call is a first-order step
    self.sample = None
|
||||
|
||||
def sigma_to_t(self, sigma):
    """
    Map sigma values to fractional positions in `self.log_sigmas` by linear interpolation in log-space.

    Args:
        sigma (`torch.Tensor`): sigma values to convert.

    Returns:
        `torch.Tensor`: interpolated positions, viewed with the same shape as `sigma`.
    """
    table = self.log_sigmas
    log_sigma = sigma.log()

    # signed log-distance of every query to every table entry
    dists = log_sigma - table[:, None]

    # lower bracketing index, capped so that `low_idx + 1` stays inside the table
    low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=table.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = table[low_idx], table[high_idx]

    # interpolation weight between the two bracketing entries, limited to [0, 1]
    w = ((low - log_sigma) / (low - high)).clamp(0, 1)

    # transform interpolation to time range
    return ((1 - w) * low_idx + w * high_idx).view(sigma.shape)
|
||||
|
||||
@property
def state_in_first_order(self):
    """`True` while no sample is stored in `self.sample`."""
    stored = self.sample
    return stored is None
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
Args:
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
|
||||
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_interpol = self.sigmas_interpol[step_index + 1]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# 2nd order / KDPM2's method
|
||||
sigma = self.sigmas[step_index - 1]
|
||||
sigma_interpol = self.sigmas_interpol[step_index]
|
||||
sigma_next = self.sigmas[step_index]
|
||||
|
||||
# currently only gamma=0 is supported. This usually works best anyways.
|
||||
# We can support gamma in the future but then need to scale the timestep before
|
||||
# passing it to the model which requires a change in API
|
||||
gamma = 0
|
||||
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
|
||||
if self.state_in_first_order:
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
# 3. 1st order derivative
|
||||
dt = sigma_interpol - sigma_hat
|
||||
|
||||
# store for 2nd order step
|
||||
self.sample = sample
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
pred_original_sample = sample - sigma_interpol * model_output
|
||||
derivative = (sample - pred_original_sample) / sigma_interpol
|
||||
|
||||
dt = sigma_next - sigma_hat
|
||||
|
||||
sample = self.sample
|
||||
self.sample = None
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, with one sigma looked up per entry of `timesteps`.

    As a side effect `self.sigmas` and `self.timesteps` are moved to the device of `original_samples`
    (`self.sigmas` is also cast to its dtype).
    """
    dest = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    self.sigmas = self.sigmas.to(device=dest, dtype=original_samples.dtype)

    if dest.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        opts = {"dtype": torch.float32}
    else:
        opts = {}
    self.timesteps = self.timesteps.to(dest, **opts)
    timesteps = timesteps.to(dest, **opts)

    lookup = [self.index_for_timestep(t) for t in timesteps]
    sigma = self.sigmas[lookup].flatten()

    # append trailing singleton axes until sigma has as many axes as the samples
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
|
||||
|
||||
def __len__(self):
    """Return `num_train_timesteps` as stored in the scheduler config."""
    cfg = self.config
    return cfg.num_train_timesteps
|
||||
@@ -77,6 +77,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
order = 2
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -68,6 +68,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -76,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -90,6 +90,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -98,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
skip_prk_steps: bool = False,
|
||||
set_alpha_to_one: bool = False,
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
|
||||
@@ -102,6 +102,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -66,6 +66,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
correct_steps (`int`): number of correction steps performed on a produced sample.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -38,6 +38,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
|
||||
self.sigmas = None
|
||||
|
||||
@@ -138,6 +138,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
|
||||
The ending cumulative gamma value.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -28,6 +28,7 @@ from .import_utils import (
|
||||
is_inflect_available,
|
||||
is_modelcards_available,
|
||||
is_onnx_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
@@ -69,6 +70,7 @@ CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
@@ -81,6 +83,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@@ -362,6 +362,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeunDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class IPNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -392,6 +407,36 @@ class KarrasVeScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class KDPM2DiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
@@ -55,7 +56,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
logger.info("Disabling PyTorch because USE_TORCH is set")
|
||||
_torch_available = False
|
||||
|
||||
|
||||
@@ -109,6 +110,17 @@ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
else:
|
||||
_flax_available = False
|
||||
|
||||
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_safetensors_available = importlib.util.find_spec("safetensors") is not None
|
||||
if _safetensors_available:
|
||||
try:
|
||||
_safetensors_version = importlib_metadata.version("safetensors")
|
||||
logger.info(f"Safetensors version {_safetensors_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_safetensors_available = False
|
||||
else:
|
||||
logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
|
||||
_safetensors_available = False
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
@@ -145,7 +157,13 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_onnxruntime_version = "N/A"
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
if _onnx_available:
|
||||
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
|
||||
candidates = (
|
||||
"onnxruntime",
|
||||
"onnxruntime-gpu",
|
||||
"onnxruntime-directml",
|
||||
"onnxruntime-openvino",
|
||||
"ort_nightly_directml",
|
||||
)
|
||||
_onnxruntime_version = None
|
||||
# For the metadata, we have to look for every onnxruntime distribution listed in `candidates`
|
||||
for pkg in candidates:
|
||||
@@ -190,6 +208,10 @@ def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_safetensors_available():
|
||||
return _safetensors_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
@@ -639,3 +639,29 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
|
||||
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
|
||||
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
|
||||
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
|
||||
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
|
||||
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
assert sample.shape == latents.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import FlaxUNet2DConditionModel
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
|
||||
from parameterized import parameterized
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
|
||||
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
|
||||
dtype = jnp.bfloat16 if fp16 else jnp.float32
|
||||
image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
|
||||
return image
|
||||
|
||||
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
|
||||
dtype = jnp.bfloat16 if fp16 else jnp.float32
|
||||
revision = "bf16" if fp16 else None
|
||||
|
||||
model, params = FlaxUNet2DConditionModel.from_pretrained(
|
||||
model_id, subfolder="unet", dtype=dtype, revision=revision
|
||||
)
|
||||
return model, params
|
||||
|
||||
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
|
||||
dtype = jnp.bfloat16 if fp16 else jnp.float32
|
||||
hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
|
||||
return hidden_states
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
|
||||
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
|
||||
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
|
||||
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
|
||||
model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
|
||||
latents = self.get_latents(seed, fp16=True)
|
||||
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
|
||||
|
||||
sample = model.apply(
|
||||
{"params": params},
|
||||
latents,
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
).sample
|
||||
|
||||
assert sample.shape == latents.shape
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
|
||||
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
|
||||
|
||||
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
|
||||
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
|
||||
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
|
||||
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
|
||||
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
|
||||
model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
|
||||
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
|
||||
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
|
||||
|
||||
sample = model.apply(
|
||||
{"params": params},
|
||||
latents,
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
).sample
|
||||
|
||||
assert sample.shape == latents.shape
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
|
||||
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
|
||||
|
||||
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
|
||||
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
|
||||
@@ -557,6 +557,46 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
|
||||
|
||||
def test_stable_diffusion_vae_slicing(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# build the pipeline from the dummy components, without a safety checker
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
image_count = 4
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output_1 = sd_pipe(
|
||||
[prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
|
||||
)
|
||||
|
||||
# make sure sliced vae decode yields the same result
|
||||
sd_pipe.enable_vae_slicing()
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output_2 = sd_pipe(
|
||||
[prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
|
||||
)
|
||||
|
||||
# there is a small discrepancy at image borders vs. full batch decode
|
||||
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
@@ -765,18 +805,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
prompt = "hey"
|
||||
|
||||
output = sd_pipe(prompt, number_of_steps=1, output_type="np")
|
||||
output = sd_pipe(prompt, num_inference_steps=1, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (64, 64)
|
||||
|
||||
output = sd_pipe(prompt, number_of_steps=1, height=96, width=96, output_type="np")
|
||||
output = sd_pipe(prompt, num_inference_steps=1, height=96, width=96, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (96, 96)
|
||||
|
||||
config = dict(sd_pipe.unet.config)
|
||||
config["sample_size"] = 96
|
||||
sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device)
|
||||
output = sd_pipe(prompt, number_of_steps=1, output_type="np")
|
||||
output = sd_pipe(prompt, num_inference_steps=1, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (192, 192)
|
||||
|
||||
@@ -886,6 +926,45 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
assert mem_bytes > 3.75 * 10**9
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_vae_slicing(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
# enable vae slicing
|
||||
pipe.enable_vae_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output_chunked = pipe(
|
||||
[prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image_chunked = output_chunked.images
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
# make sure that less than 4 GB is allocated
|
||||
assert mem_bytes < 4e9
|
||||
|
||||
# disable vae slicing
|
||||
pipe.disable_vae_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output = pipe(
|
||||
[prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# make sure that more than 4 GB is allocated
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
assert mem_bytes > 4e9
|
||||
# There is a small discrepancy at the image borders vs. a fully batched version.
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
@@ -928,7 +1007,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
|
||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
@@ -980,7 +1059,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
assert number_of_steps == 50
|
||||
|
||||
def test_stable_diffusion_low_cpu_mem_usage(self):
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
@@ -351,13 +351,13 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
||||
elif step == 37:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
@@ -386,7 +386,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
assert number_of_steps == 50
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -635,7 +635,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 38
|
||||
assert number_of_steps == 37
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -484,4 +484,4 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 38
|
||||
assert number_of_steps == 37
|
||||
|
||||
@@ -668,7 +668,7 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([1.0757, 1.1860, 1.1410, 0.4645, -0.2476, 0.6100, -0.7755, -0.8841, -0.9497])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
@@ -692,7 +692,7 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 21
|
||||
assert number_of_steps == 20
|
||||
|
||||
def test_stable_diffusion_low_cpu_mem_usage(self):
|
||||
pipeline_id = "stabilityai/stable-diffusion-2-base"
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
|
||||
from diffusers.utils import is_flax_available, slow
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
|
||||
def test_stable_diffusion_flax(self):
|
||||
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2",
|
||||
revision="bf16",
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
num_samples = jax.device_count()
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = sd_pipe.prepare_inputs(prompt)
|
||||
|
||||
params = replicate(params)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
prng_seed = jax.random.PRNGKey(0)
|
||||
prng_seed = jax.random.split(prng_seed, jax.device_count())
|
||||
|
||||
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
|
||||
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
|
||||
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
image_slice = images[0, 253:256, 253:256, -1]
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
|
||||
expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
|
||||
print(f"output_slice: {output_slice}")
|
||||
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_dpm_flax(self):
|
||||
model_id = "stabilityai/stable-diffusion-2"
|
||||
scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=scheduler,
|
||||
revision="bf16",
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
params["scheduler"] = scheduler_params
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
num_samples = jax.device_count()
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = sd_pipe.prepare_inputs(prompt)
|
||||
|
||||
params = replicate(params)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
prng_seed = jax.random.PRNGKey(0)
|
||||
prng_seed = jax.random.split(prng_seed, jax.device_count())
|
||||
|
||||
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
|
||||
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
|
||||
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
image_slice = images[0, 253:256, 253:256, -1]
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
|
||||
expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
|
||||
print(f"output_slice: {output_slice}")
|
||||
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
|
||||
@@ -306,7 +306,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
assert image.shape == (1, 768, 768, 3)
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_attention_slicing_v_pred(self):
|
||||
@@ -385,7 +385,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 768, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-3
|
||||
assert np.abs(expected_image - image).max() < 5e-1
|
||||
|
||||
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
|
||||
number_of_steps = 0
|
||||
|
||||
@@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -105,7 +105,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger "
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
@@ -117,7 +117,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
|
||||
|
||||
@@ -125,4 +125,4 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
@@ -27,7 +27,7 @@ from diffusers.utils import torch_device
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -57,6 +57,24 @@ class ModelTesterMixin:
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
    """Check that a save/load round trip keeps the dtype the model was saved in.

    The model is cast to each tested dtype, saved to a temporary directory and
    reloaded twice: once with ``low_cpu_mem_usage=True`` and once with
    ``low_cpu_mem_usage=False``. Both reloaded models must report the saved dtype.
    """
    init_dict, _ = self.prepare_init_args_and_inputs_for_common()

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        # The bfloat16 case is not exercised when running on mps.
        if dtype == torch.bfloat16 and torch_device == "mps":
            continue

        with tempfile.TemporaryDirectory() as save_dir:
            model.to(dtype)
            model.save_pretrained(save_dir)

            # Same expectation for both loading code paths.
            for low_cpu_mem_usage in (True, False):
                reloaded = self.model_class.from_pretrained(save_dir, low_cpu_mem_usage=low_cpu_mem_usage)
                assert reloaded.dtype == dtype
|
||||
|
||||
def test_determinism(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
+23
-2
@@ -92,6 +92,24 @@ class DownloadTests(unittest.TestCase):
|
||||
# None of the downloaded files should be a flax file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
|
||||
assert not any(f.endswith(".msgpack") for f in files)
|
||||
# We need to never convert this tiny model to safetensors for this test to pass
|
||||
assert not any(f.endswith(".safetensors") for f in files)
|
||||
|
||||
def test_download_safetensors(self):
    """Check that no PyTorch ``.bin`` weight file ends up in the download cache.

    The pipeline is loaded into a fresh cache directory and every file found
    under the cached repo's ``snapshots`` folder is inspected.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): presumably this hub repo ships safetensors weights next to
        # PyTorch `.bin` weights - confirm against the repo contents.
        _ = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
            safety_checker=None,
            cache_dir=tmpdirname,
        )

        # The cache dir contains one folder for the repo; collect every file name below its snapshots.
        snapshots_dir = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
        files = [name for _root, _dirs, names in os.walk(snapshots_dir) for name in names]

        # None of the downloaded files should be a PyTorch `.bin` weight file.
        assert not any(f.endswith(".bin") for f in files)
|
||||
|
||||
def test_download_no_safety_checker(self):
|
||||
prompt = "hello"
|
||||
@@ -636,9 +654,12 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||
assert (
|
||||
cap_logger.out
|
||||
== "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
|
||||
)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
# 1. Load models
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
|
||||
+319
-8
@@ -30,7 +30,10 @@ from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
@@ -333,7 +336,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
@@ -547,7 +550,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_add_noise_device(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
if scheduler_class == IPNDMScheduler:
|
||||
# Skip until #990 is addressed
|
||||
continue
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
@@ -583,6 +585,20 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
def test_trained_betas(self):
    """Check that explicitly passed ``trained_betas`` survive a save/load round trip.

    ``VQDiffusionScheduler`` is excluded from this check.
    """
    tested_classes = [cls for cls in self.scheduler_classes if cls != VQDiffusionScheduler]

    for scheduler_class in tested_classes:
        config = self.get_scheduler_config()
        original = scheduler_class(**config, trained_betas=np.array([0.0, 0.1]))

        with tempfile.TemporaryDirectory() as save_dir:
            original.save_pretrained(save_dir)
            restored = scheduler_class.from_pretrained(save_dir)

        # The reloaded scheduler must expose exactly the same betas.
        assert original.betas.tolist() == restored.betas.tolist()
|
||||
|
||||
|
||||
class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (DDPMScheduler,)
|
||||
@@ -860,7 +876,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
@@ -990,6 +1006,22 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert abs(result_mean.item() - 0.3301) < 1e-3
|
||||
|
||||
def test_fp16_support(self):
    """Run a 10-step loop on a half-precision sample and check the result is still float16.

    The scheduler is configured with ``thresholding=True`` and
    ``dynamic_thresholding_ratio=0``.
    """
    scheduler_cls = self.scheduler_classes[0]
    config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
    scheduler = scheduler_cls(**config)

    num_inference_steps = 10
    model = self.dummy_model()
    sample = self.dummy_sample_deter.half()
    scheduler.set_timesteps(num_inference_steps)

    for t in scheduler.timesteps:
        residual = model(sample, t)
        sample = scheduler.step(residual, t, sample).prev_sample

    assert sample.dtype == torch.float16
|
||||
|
||||
|
||||
class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (PNDMScheduler,)
|
||||
@@ -1037,7 +1069,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
@@ -1406,7 +1438,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"trained_betas": None,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -1488,7 +1519,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"trained_betas": None,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -1579,7 +1609,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"trained_betas": None,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -1717,7 +1746,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
@@ -1876,3 +1905,285 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
def test_add_noise_device(self):
|
||||
pass
|
||||
|
||||
|
||||
class HeunDiscreteSchedulerTest(SchedulerCommonTest):
    """Tests for ``HeunDiscreteScheduler``: config sweeps plus two full denoising loops."""

    scheduler_classes = (HeunDiscreteScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding individual entries."""
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Sweep ``num_train_timesteps`` through ``check_over_configs``."""
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        """Sweep paired ``beta_start`` / ``beta_end`` values through ``check_over_configs``."""
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Sweep the supported ``beta_schedule`` names through ``check_over_configs``."""
        for schedule in ["linear", "scaled_linear"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_full_loop_no_noise(self):
        """Run the full loop with timesteps set without a device and compare against reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # The reference values and tolerances are the same on every device (cpu, mps and CUDA).
        assert abs(result_sum.item() - 0.1233) < 1e-2
        assert abs(result_mean.item() - 0.0002) < 1e-3

    def test_full_loop_device(self):
        """Run the full loop with timesteps created on ``torch_device`` and compare against reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if str(torch_device).startswith("mps"):
            # Larger tolerance on mps; the sum is not checked there.
            assert abs(result_mean.item() - 0.0002) < 1e-2
        else:
            # cpu and CUDA share the same reference values.
            assert abs(result_sum.item() - 0.1233) < 1e-2
            assert abs(result_mean.item() - 0.0002) < 1e-3
|
||||
|
||||
|
||||
class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
    """Tests for ``KDPM2DiscreteScheduler``: config sweeps plus two full denoising loops."""

    scheduler_classes = (KDPM2DiscreteScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding individual entries."""
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Sweep ``num_train_timesteps`` through ``check_over_configs``."""
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        """Sweep paired ``beta_start`` / ``beta_end`` values through ``check_over_configs``."""
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Sweep the supported ``beta_schedule`` names through ``check_over_configs``."""
        for schedule in ["linear", "scaled_linear"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_full_loop_no_noise(self):
        """Run the full loop with timesteps set without a device and compare against reference values."""
        # This test body is not executed on mps.
        if torch_device == "mps":
            return
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # cpu and CUDA share the same reference values and tolerances.
        assert abs(result_sum.item() - 20.4125) < 1e-2
        assert abs(result_mean.item() - 0.0266) < 1e-3

    def test_full_loop_device(self):
        """Run the full loop with timesteps created on ``torch_device`` and compare against reference values."""
        # This test body is not executed on mps.
        if torch_device == "mps":
            return
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # cpu and CUDA share the same reference values and tolerances.
        assert abs(result_sum.item() - 20.4125) < 1e-2
        assert abs(result_mean.item() - 0.0266) < 1e-3
|
||||
|
||||
|
||||
class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
    """Tests for ``KDPM2AncestralDiscreteScheduler``: config sweeps plus two seeded full denoising loops."""

    scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding individual entries."""
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Sweep ``num_train_timesteps`` through ``check_over_configs``."""
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        """Sweep paired ``beta_start`` / ``beta_end`` values through ``check_over_configs``."""
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Sweep the supported ``beta_schedule`` names through ``check_over_configs``."""
        for schedule in ["linear", "scaled_linear"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_full_loop_no_noise(self):
        """Run the seeded full loop with timesteps set without a device and compare against reference values."""
        # Not executed on mps; the generator below is created directly on `torch_device`.
        if torch_device == "mps":
            return
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        generator = torch.Generator(device=torch_device).manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample, generator=generator)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # mps never reaches this point (early return above), so only cpu vs. other devices matter.
        if torch_device == "cpu":
            assert abs(result_sum.item() - 13849.3945) < 1e-2
            assert abs(result_mean.item() - 18.0331) < 5e-3
        else:
            # CUDA
            assert abs(result_sum.item() - 13913.0449) < 1e-2
            assert abs(result_mean.item() - 18.1159) < 5e-3

    def test_full_loop_device(self):
        """Run the seeded full loop with timesteps created on ``torch_device`` and compare against references."""
        # Not executed on mps; the generator below is created directly on `torch_device`.
        if torch_device == "mps":
            return
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        generator = torch.Generator(device=torch_device).manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample, generator=generator)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if str(torch_device).startswith("cpu"):
            assert abs(result_sum.item() - 13849.3945) < 1e-2
            assert abs(result_mean.item() - 18.0331) < 5e-3
        else:
            # CUDA
            # NOTE(review): the reference sum (13913.0459) and the mean tolerance (1e-3) differ slightly
            # from the CUDA values used in test_full_loop_no_noise (13913.0449, 5e-3) - confirm intended.
            assert abs(result_sum.item() - 13913.0459) < 1e-2
            assert abs(result_mean.item() - 18.1159) < 1e-3
|
||||
|
||||
@@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
@@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
@@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
|
||||
Reference in New Issue
Block a user