Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f6a41ed02 | |||
| 99f608218c | |||
| 7f323f0f31 | |||
| 61d37640ad | |||
| 33fafe3d14 | |||
| c4a8979f30 | |||
| f9fd511466 | |||
| 91fbe16b63 | |||
| 8e7d6c03a3 | |||
| b28675c605 | |||
| bd4df2856a | |||
| 11542431a5 | |||
| 81cf3b2f15 | |||
| 534848c370 | |||
| 2daedc0ad3 | |||
| 665c6b47a2 | |||
| 066ea374c8 | |||
| 9cd37557d5 | |||
| 1c6ede9371 |
@@ -55,6 +55,9 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
|
||||
## StableDiffusionControlNetPAGPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGPipeline
|
||||
|
||||
## StableDiffusionControlNetPAGInpaintPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
|
||||
| sgm_uniform | init with `timestep_spacing="trailing"` |
|
||||
| simple | init with `timestep_spacing="trailing"` |
|
||||
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
|
||||
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |
|
||||
|
||||
All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
|
||||
|
||||
|
||||
@@ -38,10 +38,7 @@ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPi
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
||||
from diffusers.training_utils import (
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
)
|
||||
from diffusers.training_utils import cast_training_params, free_memory
|
||||
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -726,7 +723,8 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
clear_objs_and_retain_memory([pipe])
|
||||
del pipe
|
||||
free_memory()
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@@ -0,0 +1,430 @@
|
||||
# ControlNet training example for FLUX
|
||||
|
||||
The `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux).
|
||||
|
||||
Training script provided by LibAI, which is an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/).
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
|
||||
|
||||
> **Gated access**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `huggingface-cli login`
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/controlnet` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Custom Datasets
|
||||
|
||||
We support dataset formats:
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \` to the script and remove line `--jsonl_for_train` mentioned below.
|
||||
|
||||
|
||||
We also support importing data from jsonl(xxx.jsonl),using `--jsonl_for_train` to enable it, here is a brief example of jsonl files:
|
||||
```sh
|
||||
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
|
||||
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
|
||||
|
||||
we can define the num_layers, num_single_layers, which determines the size of the control(default values are num_layers=4, num_single_layers=10)
|
||||
|
||||
|
||||
```bash
|
||||
accelerate launch train_controlnet_flux.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--conditioning_image_column=conditioning_image \
|
||||
--image_column=image \
|
||||
--caption_column=text \
|
||||
--output_dir="path to save model" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_steps=100 \
|
||||
--checkpointing_steps=200 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--num_double_layers=4 \
|
||||
--num_single_layers=0 \
|
||||
--seed=42 \
|
||||
--push_to_hub \
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases.
|
||||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 80GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from diffusers.models.controlnet_flux import FluxControlNetModel
|
||||
|
||||
base_model = 'black-forest-labs/FLUX.1-dev'
|
||||
controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
|
||||
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
pipe = FluxControlNetPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg")resize((1024, 1024))
|
||||
prompt = "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere"
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.6,
|
||||
num_inference_steps=28,
|
||||
guidance_scale=3.5,
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Apply Deepspeed Zero3
|
||||
|
||||
This is an experimental process, I am not sure if it is suitable for everyone, we used this process to successfully train 512 resolution on A100(40g) * 8.
|
||||
Please modify some of the code in the script.
|
||||
### 1.Customize zero3 settings
|
||||
|
||||
Copy the **accelerate_config_zero3.yaml**,modify `num_processes` according to the number of gpus you want to use:
|
||||
|
||||
```bash
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 8
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
### 2.Precompute all inputs (latent, embeddings)
|
||||
|
||||
In the train_controlnet_flux.py, We need to pre-calculate all parameters and put them into batches.So we first need to rewrite the `compute_embeddings` function.
|
||||
|
||||
```python
|
||||
def compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True):
|
||||
|
||||
### compute text embeddings
|
||||
prompt_batch = batch[args.caption_column]
|
||||
captions = []
|
||||
for caption in prompt_batch:
|
||||
if random.random() < proportion_empty_prompts:
|
||||
captions.append("")
|
||||
elif isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
prompt_batch = captions
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
|
||||
prompt_batch, prompt_2=prompt_batch
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
|
||||
text_ids = text_ids.to(dtype=weight_dtype)
|
||||
|
||||
# text_ids [512,3] to [bs,512,3]
|
||||
text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)
|
||||
|
||||
### compute latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
return latents
|
||||
|
||||
# vae encode
|
||||
pixel_values = batch["pixel_values"]
|
||||
pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device)
|
||||
pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
|
||||
pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
pixel_latents = _pack_latents(
|
||||
pixel_latents_tmp,
|
||||
pixel_values.shape[0],
|
||||
pixel_latents_tmp.shape[1],
|
||||
pixel_latents_tmp.shape[2],
|
||||
pixel_latents_tmp.shape[3],
|
||||
)
|
||||
|
||||
control_values = batch["conditioning_pixel_values"]
|
||||
control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device)
|
||||
control_latents = vae.encode(control_values).latent_dist.sample()
|
||||
control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
control_latents = _pack_latents(
|
||||
control_latents,
|
||||
control_values.shape[0],
|
||||
control_latents.shape[1],
|
||||
control_latents.shape[2],
|
||||
control_latents.shape[3],
|
||||
)
|
||||
|
||||
# copied from pipeline_flux_controlnet
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
latent_image_ids = _prepare_latent_image_ids(
|
||||
batch_size=pixel_latents_tmp.shape[0],
|
||||
height=pixel_latents_tmp.shape[2],
|
||||
width=pixel_latents_tmp.shape[3],
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
|
||||
# unet_added_cond_kwargs = {"pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
|
||||
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids, "pixel_latents": pixel_latents, "control_latents": control_latents, "latent_image_ids": latent_image_ids}
|
||||
```
|
||||
|
||||
Because we need images to pass through vae, we need to preprocess the images in the dataset first. At the same time, vae requires more gpu memory, so you may need to modify the `batch_size` below
|
||||
```diff
|
||||
+train_dataset = prepare_train_dataset(train_dataset, accelerator)
|
||||
with accelerator.main_process_first():
|
||||
from datasets.fingerprint import Hasher
|
||||
|
||||
# fingerprint used by the cache for the other processes to load the result
|
||||
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(
|
||||
- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100
|
||||
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10
|
||||
)
|
||||
|
||||
del text_encoders, tokenizers
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Then get the training dataset ready to be passed to the dataloader.
|
||||
-train_dataset = prepare_train_dataset(train_dataset, accelerator)
|
||||
```
|
||||
### 3.Redefine the behavior of getting batchsize
|
||||
|
||||
Now that we have all the preprocessing done, we need to modify the `collate_fn` function.
|
||||
|
||||
```python
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
|
||||
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
pixel_latents = torch.stack([torch.tensor(example["pixel_latents"]) for example in examples])
|
||||
pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
control_latents = torch.stack([torch.tensor(example["control_latents"]) for example in examples])
|
||||
control_latents = control_latents.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
latent_image_ids= torch.stack([torch.tensor(example["latent_image_ids"]) for example in examples])
|
||||
|
||||
prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
|
||||
|
||||
pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
|
||||
text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"conditioning_pixel_values": conditioning_pixel_values,
|
||||
"pixel_latents": pixel_latents,
|
||||
"control_latents": control_latents,
|
||||
"latent_image_ids": latent_image_ids,
|
||||
"prompt_ids": prompt_ids,
|
||||
"unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
|
||||
}
|
||||
```
|
||||
Finally, we just need to modify the way of obtaining various parameters during training.
|
||||
```python
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(flux_controlnet):
|
||||
# Convert images to latent space
|
||||
pixel_latents = batch["pixel_latents"].to(dtype=weight_dtype)
|
||||
control_image = batch["control_latents"].to(dtype=weight_dtype)
|
||||
latent_image_ids = batch["latent_image_ids"].to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
|
||||
bsz = pixel_latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype))
|
||||
|
||||
# apply flow matching
|
||||
noisy_latents = (
|
||||
1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2])
|
||||
) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat(
|
||||
1, pixel_latents.shape[1], pixel_latents.shape[2]
|
||||
) * noise
|
||||
|
||||
guidance_vec = torch.full(
|
||||
(noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype
|
||||
)
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
|
||||
hidden_states=noisy_latents,
|
||||
controlnet_cond=control_image,
|
||||
timestep=t,
|
||||
guidance=guidance_vec,
|
||||
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
|
||||
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
|
||||
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
|
||||
img_ids=latent_image_ids[0],
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = flux_transformer(
|
||||
hidden_states=noisy_latents,
|
||||
timestep=t,
|
||||
guidance=guidance_vec,
|
||||
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
|
||||
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
|
||||
controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
|
||||
if controlnet_block_samples is not None
|
||||
else None,
|
||||
controlnet_single_block_samples=[
|
||||
sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
|
||||
]
|
||||
if controlnet_single_block_samples is not None
|
||||
else None,
|
||||
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
|
||||
img_ids=latent_image_ids[0],
|
||||
return_dict=False,
|
||||
)[0]
|
||||
```
|
||||
Congratulations! You have completed all the required code modifications required for deepspeedzero3.
|
||||
|
||||
### 4.Training with deepspeedzero3
|
||||
|
||||
Start!!!
|
||||
|
||||
```bash
|
||||
export pretrained_model_name_or_path='flux-dev-model-path'
|
||||
export MODEL_TYPE='train_model_type'
|
||||
export TRAIN_JSON_FILE="your_json_file"
|
||||
export CONTROL_TYPE='control_preprocessor_type'
|
||||
export CAPTION_COLUMN='caption_column'
|
||||
|
||||
export CACHE_DIR="/data/train_csr/.cache/huggingface/"
|
||||
export OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE
|
||||
# The first step is to use Python to precompute all caches.Replace the first line below with this line. (I am not sure why using acclerate would cause problems.)
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \
|
||||
|
||||
# The second step is to use the above accelerate config to train
|
||||
accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux.py \
|
||||
--pretrained_model_name_or_path=$pretrained_model_name_or_path \
|
||||
--jsonl_for_train=$TRAIN_JSON_FILE \
|
||||
--conditioning_image_column=$CONTROL_TYPE \
|
||||
--image_column=image \
|
||||
--caption_column=$CAPTION_COLUMN\
|
||||
--cache_dir=$CACHE_DIR \
|
||||
--tracker_project_name=$MODEL_TYPE \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--max_train_steps=500000 \
|
||||
--mixed_precision bf16 \
|
||||
--checkpointing_steps=1000 \
|
||||
--gradient_accumulation_steps=8 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--learning_rate=1e-5 \
|
||||
--num_double_layers=4 \
|
||||
--num_single_layers=0 \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
# --use_adafactor \ dont use
|
||||
# --validation_steps=3 \ not support
|
||||
# --validation_image $VALIDATION_IMAGE \ not support
|
||||
# --validation_prompt "xxx" \ not support
|
||||
```
|
||||
@@ -0,0 +1,9 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
datasets
|
||||
wandb
|
||||
SentencePiece
|
||||
@@ -136,3 +136,28 @@ class ControlNetSD3(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetflux(ExamplesTestsAccelerate):
|
||||
def test_controlnet_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_flux.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
|
||||
--output_dir={tmpdir}
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--conditioning_image_column=conditioning_image
|
||||
--image_column=image
|
||||
--caption_column=text
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--num_double_layers=1
|
||||
--num_single_layers=1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -49,11 +49,7 @@ from diffusers import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
)
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
clear_objs_and_retain_memory(pipeline)
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
if not is_final_validation:
|
||||
controlnet.to(accelerator.device)
|
||||
@@ -1131,7 +1128,9 @@ def main(args):
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
|
||||
clear_objs_and_retain_memory(text_encoders + tokenizers)
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
free_memory()
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -1437,7 +1437,8 @@ def main(args):
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
|
||||
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1480,7 +1481,8 @@ def main(args):
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
|
||||
if args.validation_prompt is None:
|
||||
clear_objs_and_retain_memory([vae])
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -1817,7 +1819,8 @@ def main(args):
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -211,7 +211,8 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -1106,7 +1107,8 @@ def main(args):
|
||||
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
image.save(image_filename)
|
||||
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1453,9 +1455,9 @@ def main(args):
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
clear_objs_and_retain_memory(
|
||||
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
)
|
||||
del tokenizers, text_encoders
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1791,11 +1793,9 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
objs = []
|
||||
if not args.train_text_encoder:
|
||||
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
|
||||
|
||||
clear_objs_and_retain_memory(objs=objs)
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -747,17 +747,22 @@ def main():
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
@@ -782,8 +787,14 @@ def main():
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -328,6 +328,7 @@ else:
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
@@ -778,6 +779,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
|
||||
@@ -532,13 +532,19 @@ class LoraBaseMixin:
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
all_adapters = {
|
||||
adapter for adapters in list_adapters.values() for adapter in adapters
|
||||
} # eg ["adapter1", "adapter2"]
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
|
||||
@@ -516,10 +516,47 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
||||
)
|
||||
|
||||
if len(sds_sd) > 0:
|
||||
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
|
||||
remaining_keys = list(sds_sd.keys())
|
||||
te_state_dict = {}
|
||||
if remaining_keys:
|
||||
if not all(k.startswith("lora_te1") for k in remaining_keys):
|
||||
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
||||
for key in remaining_keys:
|
||||
if not key.endswith("lora_down.weight"):
|
||||
continue
|
||||
|
||||
return ait_sd
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = f"{lora_name}.lora_up.weight"
|
||||
lora_name_alpha = f"{lora_name}.alpha"
|
||||
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
||||
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
down_weight = sds_sd.pop(key)
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
te_state_dict[diffusers_name] = down_weight
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
|
||||
|
||||
if lora_name_alpha in sds_sd:
|
||||
alpha = sds_sd.pop(lora_name_alpha).item()
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
te_state_dict[diffusers_name] *= scale_down
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
|
||||
|
||||
if len(sds_sd) > 0:
|
||||
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
|
||||
|
||||
if te_state_dict:
|
||||
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
|
||||
|
||||
new_state_dict = {**ait_sd, **te_state_dict}
|
||||
return new_state_dict
|
||||
|
||||
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -245,6 +246,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
enc = self.quant_conv(enc)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
@@ -261,21 +274,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
h = self._encode(x)
|
||||
|
||||
if self.quant_conv is not None:
|
||||
moments = self.quant_conv(h)
|
||||
else:
|
||||
moments = h
|
||||
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
@@ -337,6 +342,54 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
@@ -356,6 +409,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||
`tuple` is returned.
|
||||
"""
|
||||
deprecation_message = (
|
||||
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
|
||||
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
|
||||
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
|
||||
)
|
||||
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||
memory_count = (
|
||||
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
|
||||
)
|
||||
|
||||
# Set to 2GB, suitable for CuDNN
|
||||
if memory_count > 2:
|
||||
@@ -115,34 +117,29 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
self.conv_cache = None
|
||||
self.return_conv_cache = True
|
||||
|
||||
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
def fake_context_parallel_forward(
|
||||
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = (
|
||||
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
)
|
||||
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
return inputs
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
del self.conv_cache
|
||||
self.conv_cache = None
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
||||
# hundred megabytes and so let's not do it for now
|
||||
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
||||
if self.return_conv_cache:
|
||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
else:
|
||||
conv_cache = None
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
return output
|
||||
return output, conv_cache
|
||||
|
||||
|
||||
class CogVideoXSpatialNorm3D(nn.Module):
|
||||
@@ -172,7 +169,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
@@ -183,9 +185,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
else:
|
||||
zq = F.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
new_f = norm_f * conv_y + conv_b
|
||||
return new_f, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXResnetBlock3D(nn.Module):
|
||||
@@ -236,6 +241,7 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
@@ -279,34 +285,43 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states = inputs
|
||||
|
||||
if zq is not None:
|
||||
hidden_states = self.norm1(hidden_states, zq)
|
||||
hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
|
||||
else:
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
|
||||
|
||||
if temb is not None:
|
||||
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
hidden_states = self.norm2(hidden_states, zq)
|
||||
hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
|
||||
else:
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
if self.use_conv_shortcut:
|
||||
inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
|
||||
inputs, conv_cache=conv_cache.get("conv_shortcut")
|
||||
)
|
||||
else:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
|
||||
hidden_states = hidden_states + inputs
|
||||
return hidden_states
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXDownBlock3D(nn.Module):
|
||||
@@ -392,8 +407,16 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
r"""Forward method of the `CogVideoXDownBlock3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -402,17 +425,23 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXMidBlock3D(nn.Module):
|
||||
@@ -480,8 +509,16 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
r"""Forward method of the `CogVideoXMidBlock3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -490,13 +527,15 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXUpBlock3D(nn.Module):
|
||||
@@ -584,9 +623,16 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
for resnet in self.resnets:
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -595,17 +641,23 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXEncoder3D(nn.Module):
|
||||
@@ -705,9 +757,18 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -718,28 +779,44 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, temb, None
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states, temb, None)
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, None)
|
||||
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
|
||||
)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class CogVideoXDecoder3D(nn.Module):
|
||||
@@ -846,9 +923,18 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -859,28 +945,45 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
# 1. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, temb, sample
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample)
|
||||
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
||||
hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb, sample)
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = up_block(
|
||||
hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states, new_conv_cache["norm_out"] = self.norm_out(
|
||||
hidden_states, sample, conv_cache=conv_cache.get("norm_out")
|
||||
)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
@@ -981,6 +1084,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
self.use_framewise_batching = True
|
||||
|
||||
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
||||
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
||||
@@ -1019,12 +1123,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
||||
module._clear_fake_context_parallel_cache()
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
@@ -1082,6 +1180,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def enable_framewise_batching(self) -> None:
|
||||
self.use_framewise_batching = True
|
||||
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
module.return_conv_cache = True
|
||||
|
||||
def disable_framewise_batching(self) -> None:
|
||||
self.use_framewise_batching = False
|
||||
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
module.return_conv_cache = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = x.shape
|
||||
|
||||
@@ -1091,19 +1203,26 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
frame_batch_size = self.num_sample_frames_batch_size
|
||||
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
||||
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
||||
enc = []
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
x_intermediate = x[:, :, start_frame:end_frame]
|
||||
x_intermediate = self.encoder(x_intermediate)
|
||||
if self.quant_conv is not None:
|
||||
x_intermediate = self.quant_conv(x_intermediate)
|
||||
enc.append(x_intermediate)
|
||||
conv_cache = None
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
enc = torch.cat(enc, dim=2)
|
||||
if self.use_framewise_batching:
|
||||
enc = []
|
||||
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
x_intermediate = x[:, :, start_frame:end_frame]
|
||||
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
|
||||
if self.quant_conv is not None:
|
||||
x_intermediate = self.quant_conv(x_intermediate)
|
||||
enc.append(x_intermediate)
|
||||
|
||||
enc = torch.cat(enc, dim=2)
|
||||
else:
|
||||
enc, _ = self.encoder(x, conv_cache=conv_cache)
|
||||
if self.quant_conv is not None:
|
||||
enc = self.quant_conv(enc)
|
||||
|
||||
return enc
|
||||
|
||||
@@ -1142,20 +1261,27 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
num_batches = num_frames // frame_batch_size
|
||||
dec = []
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate = self.decoder(z_intermediate)
|
||||
dec.append(z_intermediate)
|
||||
num_batches = max(num_frames // frame_batch_size, 1)
|
||||
conv_cache = None
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
dec = torch.cat(dec, dim=2)
|
||||
if self.use_framewise_batching:
|
||||
dec = []
|
||||
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
||||
dec.append(z_intermediate)
|
||||
|
||||
dec = torch.cat(dec, dim=2)
|
||||
else:
|
||||
if self.post_quant_conv is not None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec, _ = self.decoder(z, conv_cache=conv_cache)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -1238,7 +1364,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for j in range(0, width, overlap_width):
|
||||
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
||||
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
||||
conv_cache = None
|
||||
time = []
|
||||
|
||||
for k in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
@@ -1250,11 +1378,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile)
|
||||
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
|
||||
if self.quant_conv is not None:
|
||||
tile = self.quant_conv(tile)
|
||||
time.append(tile)
|
||||
self._clear_fake_context_parallel_cache()
|
||||
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
@@ -1315,7 +1443,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
num_batches = num_frames // frame_batch_size
|
||||
conv_cache = None
|
||||
time = []
|
||||
|
||||
for k in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
@@ -1329,9 +1459,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
]
|
||||
if self.post_quant_conv is not None:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile = self.decoder(tile)
|
||||
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
|
||||
time.append(tile)
|
||||
self._clear_fake_context_parallel_cache()
|
||||
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
|
||||
@@ -502,16 +502,17 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
|
||||
@@ -31,6 +31,7 @@ from ..utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
_add_variant,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
@@ -228,3 +229,67 @@ def _fetch_index_file(
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
|
||||
def _fetch_index_file_legacy(
|
||||
is_local,
|
||||
pretrained_model_name_or_path,
|
||||
subfolder,
|
||||
use_safetensors,
|
||||
cache_dir,
|
||||
variant,
|
||||
force_download,
|
||||
proxies,
|
||||
local_files_only,
|
||||
token,
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder or "",
|
||||
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
||||
).as_posix()
|
||||
splits = index_file.split(".")
|
||||
split_index = -3 if ".cache" in index_file else -2
|
||||
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
||||
index_file = ".".join(splits)
|
||||
if os.path.exists(index_file):
|
||||
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
||||
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
||||
index_file = Path(index_file)
|
||||
else:
|
||||
index_file = None
|
||||
else:
|
||||
if variant is not None:
|
||||
index_file_in_repo = Path(
|
||||
subfolder or "",
|
||||
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
||||
).as_posix()
|
||||
splits = index_file_in_repo.split(".")
|
||||
split_index = -2
|
||||
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
||||
index_file_in_repo = ".".join(splits)
|
||||
try:
|
||||
index_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=index_file_in_repo,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=None,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
index_file = Path(index_file)
|
||||
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
||||
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
||||
except (EntryNotFoundError, EnvironmentError):
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
@@ -54,6 +54,7 @@ from ..utils.hub_utils import (
|
||||
from .model_loading_utils import (
|
||||
_determine_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_load_state_dict_into_model,
|
||||
load_model_dict_into_meta,
|
||||
load_state_dict,
|
||||
@@ -309,11 +310,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weight_name_split = weights_name.split(".")
|
||||
if len(weight_name_split) in [2, 3]:
|
||||
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
|
||||
else:
|
||||
raise ValueError(f"Invalid {weights_name} provided.")
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
@@ -624,21 +623,26 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
is_sharded = False
|
||||
index_file = None
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
index_file = _fetch_index_file(
|
||||
is_local=is_local,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder or "",
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
variant=variant,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
index_file_kwargs = {
|
||||
"is_local": is_local,
|
||||
"pretrained_model_name_or_path": pretrained_model_name_or_path,
|
||||
"subfolder": subfolder or "",
|
||||
"use_safetensors": use_safetensors,
|
||||
"cache_dir": cache_dir,
|
||||
"variant": variant,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"user_agent": user_agent,
|
||||
"commit_hash": commit_hash,
|
||||
}
|
||||
index_file = _fetch_index_file(**index_file_kwargs)
|
||||
# In case the index file was not found we still have to consider the legacy format.
|
||||
# this becomes applicable when the variant is not None.
|
||||
if variant is not None and (index_file is None or not os.path.exists(index_file)):
|
||||
index_file = _fetch_index_file_legacy(**index_file_kwargs)
|
||||
if index_file is not None and index_file.is_file():
|
||||
is_sharded = True
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils.import_utils import is_torch_version
|
||||
from .normalization import RMSNorm
|
||||
|
||||
|
||||
@@ -151,11 +152,10 @@ class Upsample2D(nn.Module):
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
|
||||
# https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
@@ -170,8 +170,8 @@ class Upsample2D(nn.Module):
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
# Cast back to original dtype
|
||||
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
|
||||
@@ -158,6 +158,7 @@ else:
|
||||
)
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetPAGInpaintPipeline",
|
||||
"AnimateDiffPAGPipeline",
|
||||
"KolorsPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -566,6 +567,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KolorsPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
|
||||
@@ -61,6 +61,7 @@ from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
@@ -148,6 +149,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky", KandinskyInpaintCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
|
||||
+11
-2
@@ -251,6 +251,9 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -577,8 +580,14 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if (
|
||||
height % (self.vae_scale_factor * self.patch_size) != 0
|
||||
or width % (self.vae_scale_factor * self.patch_size) != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
||||
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
|
||||
@@ -747,10 +747,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# set control mode
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if control_mode is not None:
|
||||
if not isinstance(control_mode, int):
|
||||
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
|
||||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
|
||||
|
||||
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
control_images = []
|
||||
@@ -785,16 +787,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
|
||||
control_image = control_images
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
|
||||
raise ValueError(
|
||||
"For Multi-ControlNet, `control_mode` must be a list of the same "
|
||||
+ " length as the number of controlnets (control images) specified"
|
||||
)
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode] * len(control_image)
|
||||
# set control mode
|
||||
control_mode_ = []
|
||||
if isinstance(control_mode, list):
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
control_mode_.append(-1)
|
||||
else:
|
||||
control_mode_.append(cmode)
|
||||
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
control_modes = []
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
cmode = -1
|
||||
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
||||
control_modes.append(control_mode)
|
||||
control_mode = control_modes
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
@@ -840,9 +848,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
guidance = (
|
||||
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
|
||||
)
|
||||
if isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
use_guidance = self.controlnet.nets[0].config.guidance_embeds
|
||||
else:
|
||||
use_guidance = self.controlnet.config.guidance_embeds
|
||||
|
||||
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
|
||||
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
||||
|
||||
# controlnet
|
||||
|
||||
@@ -277,6 +277,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding_side: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
@@ -298,6 +299,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
padding_side (`str`, *optional*):
|
||||
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
|
||||
Default value is picked from the class attribute of the same name.
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"]
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1347,7 +1347,7 @@ class StableDiffusionXLControlNetPAGPipeline(
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6.5 Optionally get Guidance Scale Embedding
|
||||
# 6.1 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
@@ -212,6 +212,9 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
self.set_pag_applied_layers(
|
||||
pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0())
|
||||
@@ -542,8 +545,14 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if (
|
||||
height % (self.vae_scale_factor * self.patch_size) != 0
|
||||
or width % (self.vae_scale_factor * self.patch_size) != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
||||
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
|
||||
@@ -50,7 +50,6 @@ from ..utils import (
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
BaseOutput,
|
||||
PushToHubMixin,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_torch_npu_available,
|
||||
@@ -58,7 +57,7 @@ from ..utils import (
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -735,6 +734,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
|
||||
# a warning if detected.
|
||||
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
|
||||
warn_msg = (
|
||||
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
||||
"Please check your files carefully:\n\n"
|
||||
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
||||
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
||||
"If you find any files in the deprecated format:\n"
|
||||
"1. Remove all existing checkpoint files for this variant.\n"
|
||||
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
||||
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
||||
)
|
||||
logger.warning(warn_msg)
|
||||
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
@@ -745,6 +759,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
|
||||
# with variant being `"fp16"`.
|
||||
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
|
||||
if len(model_variants) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
@@ -1251,6 +1268,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
if not local_files_only:
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
||||
warn_msg = (
|
||||
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
||||
"Please check your files carefully:\n\n"
|
||||
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
||||
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
||||
"If you find any files in the deprecated format:\n"
|
||||
"1. Remove all existing checkpoint files for this variant.\n"
|
||||
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
||||
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
||||
)
|
||||
logger.warning(warn_msg)
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
@@ -1267,9 +1300,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
@@ -1292,13 +1322,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
deprecation_message = (
|
||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
||||
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
|
||||
"modeling files is deprecated."
|
||||
)
|
||||
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
|
||||
@@ -203,6 +203,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -525,8 +528,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if (
|
||||
height % (self.vae_scale_factor * self.patch_size) != 0
|
||||
or width % (self.vae_scale_factor * self.patch_size) != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
||||
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
|
||||
@@ -22,10 +22,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils import deprecate, is_scipy_available
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -113,6 +117,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
@@ -141,11 +148,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_order_final: bool = True,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -263,6 +275,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
@@ -396,6 +411,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -21,11 +21,15 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils import deprecate, is_scipy_available
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
||||
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
||||
@@ -209,6 +216,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
use_lu_lambdas: Optional[bool] = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
@@ -217,8 +225,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
@@ -337,6 +349,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
|
||||
if timesteps is not None and self.config.use_exponential_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_beta_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
|
||||
|
||||
if timesteps is not None:
|
||||
timesteps = np.array(timesteps).astype(np.int64)
|
||||
@@ -388,6 +402,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
@@ -542,6 +559,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -21,11 +21,15 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils import deprecate, is_scipy_available
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -126,6 +130,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -161,13 +168,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
@@ -219,6 +231,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
self.use_exponential_sigmas = use_exponential_sigmas
|
||||
self.use_beta_sigmas = use_beta_sigmas
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -276,6 +289,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_max = (
|
||||
@@ -416,6 +432,38 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
||||
def convert_model_output(
|
||||
self,
|
||||
|
||||
@@ -20,9 +20,14 @@ import torch
|
||||
import torchsde
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import is_scipy_available
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
@@ -162,6 +167,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
noise_sampler_seed (`int`, *optional*, defaults to `None`):
|
||||
The random seed to use for the noise sampler. If `None`, a random seed is generated.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
@@ -185,12 +193,17 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
noise_sampler_seed: Optional[int] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -349,6 +362,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
|
||||
|
||||
@@ -451,6 +467,38 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
@@ -21,11 +21,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import deprecate, is_scipy_available, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -125,6 +128,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
@@ -157,12 +163,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_order_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type == "dpmsolver":
|
||||
deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
|
||||
@@ -307,6 +318,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
|
||||
if timesteps is not None and self.config.use_exponential_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_beta_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
|
||||
|
||||
num_inference_steps = num_inference_steps or len(timesteps)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -333,6 +346,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
@@ -484,6 +500,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -20,11 +20,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils import BaseOutput, is_scipy_available, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
@@ -189,6 +195,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
@@ -197,8 +204,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -241,6 +252,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
self.use_exponential_sigmas = use_exponential_sigmas
|
||||
self.use_beta_sigmas = use_beta_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
@@ -340,6 +352,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_exponential_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_beta_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
|
||||
if (
|
||||
timesteps is not None
|
||||
and self.config.timestep_type == "continuous"
|
||||
@@ -408,6 +422,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
@@ -502,6 +520,37 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -19,9 +19,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import is_scipy_available
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -99,6 +104,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
@@ -120,13 +128,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
clip_sample: Optional[bool] = False,
|
||||
clip_sample_range: float = 1.0,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -258,6 +271,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
|
||||
if timesteps is not None and self.config.use_exponential_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_beta_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
|
||||
|
||||
num_inference_steps = num_inference_steps or len(timesteps)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -296,6 +311,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
@@ -386,6 +404,38 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.dt is None
|
||||
|
||||
@@ -19,10 +19,15 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import is_scipy_available
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -93,6 +98,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -117,12 +125,17 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -258,6 +271,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
@@ -376,6 +392,38 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
@@ -19,9 +19,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import is_scipy_available
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -92,6 +97,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -116,12 +124,17 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -257,6 +270,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
@@ -389,6 +405,38 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.Tensor, np.ndarray],
|
||||
|
||||
@@ -17,6 +17,7 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import scipy.stats
|
||||
import torch
|
||||
from scipy import integrate
|
||||
|
||||
@@ -113,6 +114,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -137,12 +141,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -297,6 +304,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
|
||||
@@ -392,6 +402,38 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -22,11 +22,15 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils import deprecate, is_scipy_available
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -124,6 +128,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -159,13 +166,18 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_order_final: bool = True,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -292,6 +304,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
@@ -425,6 +440,38 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -22,10 +22,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils import deprecate, is_scipy_available
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
@@ -161,6 +165,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
@@ -198,13 +205,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
solver_p: SchedulerMixin = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -337,6 +349,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
@@ -480,6 +495,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
|
||||
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def clear_objs_and_retain_memory(objs: List[Any]):
|
||||
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
|
||||
if len(objs) >= 1:
|
||||
for obj in objs:
|
||||
del obj
|
||||
|
||||
def free_memory():
|
||||
"""Runs garbage collection. Then clears the cache of the available accelerator."""
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
||||
@@ -1352,6 +1352,21 @@ class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetPAGInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -271,8 +271,7 @@ if cache_version < 1:
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
splits = weights_name.split(".")
|
||||
split_index = -2 if weights_name.endswith(".index.json") else -1
|
||||
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
@@ -502,6 +501,19 @@ def _get_checkpoint_shard_files(
|
||||
return cached_folder, sharded_metadata
|
||||
|
||||
|
||||
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
|
||||
if filenames and folder:
|
||||
raise ValueError("Both `filenames` and `folder` cannot be provided.")
|
||||
if not filenames:
|
||||
filenames = []
|
||||
for _, _, files in os.walk(folder):
|
||||
for file in files:
|
||||
filenames.append(os.path.basename(file))
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
|
||||
return any(variant_file_re.match(f) is not None for f in filenames)
|
||||
|
||||
|
||||
class PushToHubMixin:
|
||||
"""
|
||||
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
|
||||
|
||||
@@ -252,6 +252,18 @@ def require_torch_2(test_case):
|
||||
)
|
||||
|
||||
|
||||
def require_torch_version_greater_equal(torch_version):
|
||||
"""Decorator marking a test that requires torch with a specific version or greater."""
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
|
||||
@@ -228,6 +228,26 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_flux_kohya_with_text_encoder(self):
|
||||
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = "optimus is cleaning the house with broomstick"
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.5,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
|
||||
|
||||
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
|
||||
+85
-73
@@ -201,6 +201,32 @@ class PeftLoraLoaderMixinTests:
|
||||
prepared_inputs["input_ids"] = inputs
|
||||
return prepared_inputs
|
||||
|
||||
def _get_lora_state_dicts(self, modules_to_save):
|
||||
state_dicts = {}
|
||||
for module_name, module in modules_to_save.items():
|
||||
if module is not None:
|
||||
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
|
||||
return state_dicts
|
||||
|
||||
def _get_modules_to_save(self, pipe, has_denoiser=False):
|
||||
modules_to_save = {}
|
||||
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
|
||||
|
||||
if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
|
||||
modules_to_save["text_encoder"] = pipe.text_encoder
|
||||
|
||||
if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
|
||||
modules_to_save["text_encoder_2"] = pipe.text_encoder_2
|
||||
|
||||
if has_denoiser:
|
||||
if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
|
||||
modules_to_save["unet"] = pipe.unet
|
||||
|
||||
if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
|
||||
modules_to_save["transformer"] = pipe.transformer
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def test_simple_inference(self):
|
||||
"""
|
||||
Tests a simple inference and makes sure it works as expected
|
||||
@@ -420,45 +446,21 @@ class PeftLoraLoaderMixinTests:
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
|
||||
modules_to_save = self._get_modules_to_save(pipe)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
text_encoder_2_lora_layers=text_encoder_2_state_dict,
|
||||
safe_serialization=False,
|
||||
)
|
||||
else:
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
safe_serialization=False,
|
||||
)
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
safe_serialization=False,
|
||||
)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
for module_name, module in modules_to_save.items():
|
||||
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
@@ -614,54 +616,20 @@ class PeftLoraLoaderMixinTests:
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
text_encoder_state_dict = (
|
||||
get_peft_model_state_dict(pipe.text_encoder)
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules
|
||||
else None
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
|
||||
denoiser_state_dict = get_peft_model_state_dict(denoiser)
|
||||
|
||||
saving_kwargs = {
|
||||
"save_directory": tmpdirname,
|
||||
"safe_serialization": False,
|
||||
}
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
|
||||
else:
|
||||
saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
|
||||
saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
|
||||
|
||||
self.pipeline_class.save_lora_weights(**saving_kwargs)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
for module_name, module in modules_to_save.items():
|
||||
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
@@ -929,12 +897,24 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
pipe.set_adapters("adapter-1")
|
||||
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
|
||||
"Adapter outputs should be different.",
|
||||
)
|
||||
|
||||
pipe.set_adapters("adapter-2")
|
||||
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
|
||||
"Adapter outputs should be different.",
|
||||
)
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"])
|
||||
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
|
||||
"Adapter outputs should be different.",
|
||||
)
|
||||
|
||||
# Fuse and unfuse should lead to the same results
|
||||
self.assertFalse(
|
||||
@@ -960,6 +940,38 @@ class PeftLoraLoaderMixinTests:
|
||||
"output with no lora and output with lora disabled should give same results",
|
||||
)
|
||||
|
||||
def test_wrong_adapter_name_raises_error(self):
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as err_context:
|
||||
pipe.set_adapters("test")
|
||||
|
||||
self.assertTrue("not in the list of present adapters" in str(err_context.exception))
|
||||
|
||||
# test this works.
|
||||
pipe.set_adapters("adapter-1")
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached to text encoder and unet, attaches
|
||||
|
||||
@@ -27,6 +27,7 @@ from diffusers.models.transformers.transformer_2d import Transformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_manual_seed,
|
||||
require_torch_accelerator_with_fp64,
|
||||
require_torch_version_greater_equal,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -120,6 +121,21 @@ class Upsample2DBlockTests(unittest.TestCase):
|
||||
expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
@require_torch_version_greater_equal("2.1")
|
||||
def test_upsample_bfloat16(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
|
||||
upsample = Upsample2D(channels=32, use_conv=False)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor(
|
||||
[-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_conv(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
|
||||
@@ -27,8 +27,9 @@ import numpy as np
|
||||
import requests_mock
|
||||
import torch
|
||||
from accelerate.utils import compute_module_sizes
|
||||
from huggingface_hub import ModelCard, delete_repo
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
@@ -39,7 +40,13 @@ from diffusers.models.attention_processor import (
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
|
||||
from diffusers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
is_torch_npu_available,
|
||||
is_xformers_available,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.hub_utils import _add_variant
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
@@ -100,6 +107,52 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
# make sure that error message states what keys are missing
|
||||
assert "conv_out.bias" in str(error_context.exception)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False),
|
||||
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True),
|
||||
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False),
|
||||
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True),
|
||||
]
|
||||
)
|
||||
def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local):
|
||||
def load_model(path):
|
||||
kwargs = {"variant": "fp16"}
|
||||
if subfolder:
|
||||
kwargs["subfolder"] = subfolder
|
||||
return UNet2DConditionModel.from_pretrained(path, **kwargs)
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
if use_local:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = snapshot_download(repo_id=repo_id)
|
||||
_ = load_model(tmpdirname)
|
||||
else:
|
||||
_ = load_model(repo_id)
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
|
||||
|
||||
# Local tests are already covered down below.
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None, "fp16"),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "unet", "fp16"),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-no-variants", None, None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-no-variants-subfolder", "unet", None),
|
||||
]
|
||||
)
|
||||
def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder, variant=None):
|
||||
def load_model():
|
||||
kwargs = {}
|
||||
if variant:
|
||||
kwargs["variant"] = variant
|
||||
if subfolder:
|
||||
kwargs["subfolder"] = subfolder
|
||||
return UNet2DConditionModel.from_pretrained(repo_id, **kwargs)
|
||||
|
||||
assert load_model()
|
||||
|
||||
def test_cached_files_are_used_when_no_internet(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
@@ -924,6 +977,7 @@ class ModelTesterMixin:
|
||||
# testing if loading works with the variant when the checkpoint is sharded should be
|
||||
# enough.
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
|
||||
|
||||
@@ -976,6 +1030,44 @@ class ModelTesterMixin:
|
||||
new_output = new_model(**inputs_dict)
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
# This test is okay without a GPU because we're not running any execution. We're just serializing
|
||||
# and check if the resultant files are following an expected format.
|
||||
def test_variant_sharded_ckpt_right_format(self):
|
||||
for use_safe in [True, False]:
|
||||
extension = ".safetensors" if use_safe else ".bin"
|
||||
config, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(
|
||||
tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
|
||||
)
|
||||
index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))
|
||||
|
||||
# Now check if the right number of shards exists. First, let's get the number of shards.
|
||||
# Since this number can be dependent on the model being tested, it's important that we calculate it
|
||||
# instead of hardcoding it.
|
||||
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
# Check if the variant is present as a substring in the checkpoints.
|
||||
shard_files = [
|
||||
file
|
||||
for file in os.listdir(tmp_dir)
|
||||
if file.endswith(extension) or ("index" in file and "json" in file)
|
||||
]
|
||||
assert all(variant in f for f in shard_files)
|
||||
|
||||
# Check if the sharded checkpoints were serialized in the right format.
|
||||
shard_files = [file for file in os.listdir(tmp_dir) if file.endswith(extension)]
|
||||
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
|
||||
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
|
||||
@@ -1036,9 +1036,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
@@ -1046,11 +1052,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
|
||||
)
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
@@ -1080,20 +1090,30 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto")
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
|
||||
)
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
@@ -1121,18 +1141,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_with_variant_from_hub(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
"hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
|
||||
)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -0,0 +1,245 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This model implementation is heavily based on:
|
||||
|
||||
import inspect
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionControlNetPAGInpaintPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPAGInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset({"control_image"}) # skip `image` and `mask` for now, only test for control_image
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
# Copied from tests.pipelines.controlnet.test_controlnet_inpaint.ControlNetInpaintPipelineFastTests.get_dummy_components
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
control_image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
init_image = init_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
|
||||
image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"pag_scale": 3.0,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"control_image": control_image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline (expect same output when pag is disabled)
|
||||
pipe_sd = StableDiffusionControlNetInpaintPipeline(**components)
|
||||
pipe_sd = pipe_sd.to(device)
|
||||
pipe_sd.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
pipe_pag = self.pipeline_class(**components)
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["pag_scale"] = 0.0
|
||||
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag enabled
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
|
||||
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
|
||||
|
||||
def test_pag_cfg(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe_pag(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (
|
||||
1,
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||
expected_slice = np.array(
|
||||
[0.7488756, 0.61194265, 0.53382546, 0.5993959, 0.6193306, 0.56880975, 0.41277143, 0.5050145, 0.49376273]
|
||||
)
|
||||
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["guidance_scale"] = 0.0
|
||||
image = pipe_pag(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (
|
||||
1,
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||
expected_slice = np.array(
|
||||
[0.7410303, 0.5989337, 0.530866, 0.60571927, 0.6162597, 0.5719856, 0.4187478, 0.5101238, 0.4978468]
|
||||
)
|
||||
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
|
||||
@@ -30,6 +30,7 @@ import requests_mock
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from parameterized import parameterized
|
||||
from PIL import Image
|
||||
from requests.exceptions import HTTPError
|
||||
@@ -551,6 +552,50 @@ class DownloadTests(unittest.TestCase):
|
||||
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
def test_download_variants_with_sharded_checkpoints(self):
|
||||
# Here we test for downloading of "variant" files belonging to the `unet` and
|
||||
# the `text_encoder`. Their checkpoints can be sharded.
|
||||
for use_safetensors in [True, False]:
|
||||
for variant in ["fp16", None]:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe-variants-right-format",
|
||||
safety_checker=None,
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# Check for `model_ext` and `variant`.
|
||||
model_ext = ".safetensors" if use_safetensors else ".bin"
|
||||
unexpected_ext = ".bin" if use_safetensors else ".safetensors"
|
||||
model_files = [f for f in files if f.endswith(model_ext)]
|
||||
assert not any(f.endswith(unexpected_ext) for f in files)
|
||||
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)
|
||||
|
||||
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
|
||||
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
|
||||
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
|
||||
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
|
||||
|
||||
for is_local in [True, False]:
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
local_repo_id = repo_id
|
||||
if is_local:
|
||||
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
|
||||
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
local_repo_id,
|
||||
safety_checker=None,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
|
||||
|
||||
def test_download_safetensors_only_variant_exists_for_model(self):
|
||||
variant = None
|
||||
use_safetensors = True
|
||||
@@ -655,7 +700,7 @@ class DownloadTests(unittest.TestCase):
|
||||
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe)
|
||||
pipe_2 = StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname, safe_serialization=use_safe, variant=variant
|
||||
)
|
||||
@@ -1646,7 +1691,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
def test_error_no_variant_available(self):
|
||||
variant = "fp16"
|
||||
with self.assertRaises(ValueError) as error_context:
|
||||
_ = StableDiffusionPipeline.download(
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
|
||||
)
|
||||
|
||||
|
||||
@@ -1824,6 +1824,74 @@ class PipelineTesterMixin:
|
||||
# accounts for models that modify the number of inference steps based on strength
|
||||
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
|
||||
|
||||
def test_serialization_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
model_components = [
|
||||
component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module)
|
||||
]
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
|
||||
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
for subfolder in os.listdir(tmpdir):
|
||||
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||
folder_path = os.path.join(tmpdir, subfolder)
|
||||
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
|
||||
def test_loading_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
variant = "fp16"
|
||||
|
||||
def is_nan(tensor):
|
||||
if tensor.ndimension() == 0:
|
||||
has_nan = torch.isnan(tensor).item()
|
||||
else:
|
||||
has_nan = torch.isnan(tensor).any()
|
||||
return has_nan
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
|
||||
|
||||
model_components_pipe = {
|
||||
component_name: component
|
||||
for component_name, component in pipe.components.items()
|
||||
if isinstance(component, nn.Module)
|
||||
}
|
||||
model_components_pipe_loaded = {
|
||||
component_name: component
|
||||
for component_name, component in pipe_loaded.components.items()
|
||||
if isinstance(component, nn.Module)
|
||||
}
|
||||
for component_name in model_components_pipe:
|
||||
pipe_component = model_components_pipe[component_name]
|
||||
pipe_loaded_component = model_components_pipe_loaded[component_name]
|
||||
for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
|
||||
# nan check for luminanext (mps).
|
||||
if not (is_nan(p1) and is_nan(p2)):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_loading_with_incorrect_variants_raises_error(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Don't save with variants.
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_ = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
|
||||
|
||||
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
|
||||
|
||||
def test_StableDiffusionMixin_component(self):
|
||||
"""Any pipeline that have LDMFuncMixin should have vae and unet components."""
|
||||
if not issubclass(self.pipeline_class, StableDiffusionMixin):
|
||||
|
||||
Reference in New Issue
Block a user