Compare commits

..

1 Commits

Author SHA1 Message Date
yiyixuxu 3e93944bf7 up 2024-09-26 02:32:50 +02:00
158 changed files with 875 additions and 24086 deletions
@@ -25,7 +25,7 @@ jobs:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
# Checkout to correct ref
# If workflow dispatch
@@ -7,7 +7,7 @@ on:
jobs:
build:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
+1 -1
View File
@@ -16,7 +16,7 @@ concurrency:
jobs:
check_dependencies:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
@@ -16,7 +16,7 @@ concurrency:
jobs:
check_flax_dependencies:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
+2 -2
View File
@@ -20,7 +20,7 @@ env:
jobs:
check_code_quality:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
@@ -40,7 +40,7 @@ jobs:
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
+2 -2
View File
@@ -29,7 +29,7 @@ env:
jobs:
check_code_quality:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
@@ -49,7 +49,7 @@ jobs:
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
@@ -16,7 +16,7 @@ concurrency:
jobs:
check_torch_dependencies:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
+2 -2
View File
@@ -10,7 +10,7 @@ on:
jobs:
find-and-checkout-latest-branch:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
outputs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
@@ -36,7 +36,7 @@ jobs:
release:
needs: find-and-checkout-latest-branch
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- name: Checkout Repo
+1 -1
View File
@@ -8,7 +8,7 @@ jobs:
close_stale_issues:
name: Close Stale Issues
if: github.repository == 'huggingface/diffusers'
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
+1 -1
View File
@@ -5,7 +5,7 @@ name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
+1 -1
View File
@@ -5,7 +5,7 @@ on:
jobs:
build:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
+1 -1
View File
@@ -3,7 +3,7 @@ import sys
import pandas as pd
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub.utils import EntryNotFoundError
from huggingface_hub.utils._errors import EntryNotFoundError
sys.path.append(".")
-8
View File
@@ -75,8 +75,6 @@
title: Outpainting
title: Advanced inference
- sections:
- local: using-diffusers/cogvideox
title: CogVideoX
- local: using-diffusers/sdxl
title: Stable Diffusion XL
- local: using-diffusers/sdxl_turbo
@@ -131,8 +129,6 @@
title: T2I-Adapters
- local: training/instructpix2pix
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
title: Models
- isExpanded: false
sections:
@@ -246,8 +242,6 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
@@ -326,8 +320,6 @@
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
@@ -1,30 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# CogView3PlusTransformer2DModel
A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import CogView3PlusTransformer2DModel
vae = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## CogView3PlusTransformer2DModel
[[autodoc]] CogView3PlusTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
-10
View File
@@ -36,10 +36,6 @@ There are two models available that can be used with the text-to-video and video
There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -122,12 +118,6 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__
## CogVideoXFunControlPipeline
[[autodoc]] CogVideoXFunControlPipeline
- all
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
-40
View File
@@ -1,40 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# CogView3Plus
[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
The abstract from the paper is:
*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
## CogView3PlusPipeline
[[autodoc]] CogView3PlusPipeline
- all
- __call__
## CogView3PipelineOutput
[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
-8
View File
@@ -53,16 +53,8 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__
## StableDiffusionPAGImg2ImgPipeline
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
- all
- __call__
## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline
## StableDiffusionControlNetPAGInpaintPipeline
[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline
- all
- __call__
@@ -52,7 +52,6 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| sgm_uniform | init with `timestep_spacing="trailing"` |
| simple | init with `timestep_spacing="trailing"` |
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |
All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
+9 -8
View File
@@ -9,7 +9,7 @@ Optimization orthogonal to parallelization focuses on accelerating single GPU pe
The overview of xDiT is shown as follows.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/methods/xdit_overview.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/methods/xdit_overview.png">
</div>
You can install xDiT using the following command:
@@ -78,36 +78,37 @@ A subset of Diffusers models are supported in xDiT, such as Flux.1, Stable Diffu
## Benchmark
We tested different models on various machines, and here is some of the benchmark data.
### Flux.1-schnell
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2k-L40.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2k-L40.png">
</div>
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2K-A100.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2K-A100.png">
</div>
### Stable Diffusion 3
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/L40-SD3.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/L40-SD3.png">
</div>
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/A100-SD3.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/A100-SD3.png">
</div>
### HunyuanDiT
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/L40-HunyuanDiT.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/L40-HunyuanDiT.png">
</div>
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/V100-HunyuanDiT.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/A100-HunyuanDiT.png">
</div>
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/T4-HunyuanDiT.png">
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/T4-HunyuanDiT.png">
</div>
More detailed performance metric can be found on our [github page](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf).
-291
View File
@@ -1,291 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# CogVideoX
CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.
- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.
- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
The actual test of the video instruction dimension found that CogVideoX has good effects on consistent theme, dynamic information, consistent background, object information, smooth motion, color, scene, appearance style, and temporal style but cannot achieve good results with human action, spatial relationship, and multiple objects.
Finetuning with Diffusers can help make up for these poor results.
## Data Preparation
The training scripts accepts data in two formats.
The first format is suited for small-scale training, and the second format uses a CSV format, which is more appropriate for streaming data for large-scale training. In the future, Diffusers will support the `<Video>` tag.
### Small format
Two files where one file contains line-separated prompts and another file contains line-separated paths to video data (the path to video files must be relative to the path you pass when specifying `--instance_data_root`). Let's take a look at an example to understand this better!
Assume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.
The `prompts.txt` file should contain line-separated prompts:
```
A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.
...
```
The `videos.txt` file should contain line-separate paths to video files. Note that the path should be _relative_ to the `--instance_data_root` directory.
```
videos/00000.mp4
videos/00001.mp4
...
```
Overall, this is how your dataset would look like if you ran the `tree` command on the dataset root directory:
```
/dataset
├── prompts.txt
├── videos.txt
├── videos
├── videos/00000.mp4
├── videos/00001.mp4
├── ...
```
When using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.
### Stream format
You could use a single CSV file. For the sake of this example, assume you have a `metadata.csv` file. The expected format is:
```
<CAPTION_COLUMN>,<PATH_TO_VIDEO_COLUMN>
"""A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.""","""00000.mp4"""
"""A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.""","""00001.mp4"""
...
```
In this case, the `--instance_data_root` should be the location where the videos are stored and `--dataset_name` should be either a path to local folder or a [`~datasets.load_dataset`] compatible dataset hosted on the Hub. Assuming you have videos of Minecraft gameplay at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would have to specify `my-awesome-username/minecraft-videos`.
When using this format, the `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<PATH_TO_VIDEO_COLUMN>`.
You are not strictly restricted to the CSV format. Any format works as long as the `load_dataset` method supports the file format to load a basic `<PATH_TO_VIDEO_COLUMN>` and `<CAPTION_COLUMN>`. The reason for going through these dataset organization gymnastics for loading video data is because `load_dataset` does not fully support all kinds of video formats.
> [!NOTE]
> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-26](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example for this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm) and a length of 50-100 words is considered good.
>![NOTE]
> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:
> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.
> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.
<!-- TODO: Implement frame packing in future to address above issue. -->
## Training
You need to setup your development environment by installing the necessary requirements. The following packages are required:
- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)
- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related
- `pip install datasets decord` for loading video training data
- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training
- `pip install wandb` optionally for monitoring training logs
- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training
- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
Before running the script, make sure you install the library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
- PyTorch
```bash
cd examples/cogvideo
pip install -r requirements.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if you use torch.compile, there can be dramatic speedups. The PEFT library is used as a backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
If you would like to push your model to the Hub after training is completed with a neat model card, make sure you're logged in:
```bash
huggingface-cli login
# Alternatively, you could upload your model manually using:
# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
```
Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
Assuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1`:
- 1500 steps on 50 videos would correspond to `30` training epochs
- 4000 steps on 100 videos would correspond to `40` training epochs
```bash
#!/bin/bash
GPU_IDS="0"
accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
--pretrained_model_name_or_path THUDM/CogVideoX-2b \
--cache_dir <CACHE_DIR> \
--instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \
--dataset_name my-awesome-name/my-awesome-dataset \
--caption_column <CAPTION_COLUMN> \
--video_column <PATH_TO_VIDEO_COLUMN> \
--id_token <ID_TOKEN> \
--validation_prompt "<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 10 \
--seed 42 \
--rank 64 \
--lora_alpha 64 \
--mixed_precision fp16 \
--output_dir /raid/aryan/cogvideox-lora \
--height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--optimizer Adam \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--report_to wandb
```
To better track our training experiments, we're using the following flags in the command above:
* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found it works better (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) training) than without. When provided, the `<ID_TOKEN>` is appended to the beginning of each prompt. So, if your `<ID_TOKEN>` was `"DISNEY"` and your prompt was `"Spiderman swinging over buildings"`, the effective prompt used in training would be `"DISNEY Spiderman swinging over buildings"`. When not provided, you would either be training without any additional token or could augment your dataset to apply the token where you wish before starting the training.
> [!NOTE]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon
> - One dataset comprised of 70 training videos of resolutions `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation and because the maximum limit recommended by the CogVideoX team is 49 frames. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in nature of the concept being trained.
> - 25+ videos worked best for training new concepts and styles.
> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training but normal finetuning without such a token works too.
> - Trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.
> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to difference in modeling backends and training settings. Our recommendation is to set to the `lora_alpha` to either `rank` or `rank // 2`.
> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
> - When using the Prodigy opitimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.
>
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->
## Inference
Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`.
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
# pipe.load_lora_weights("/path/to/lora/weights", adapter_name="cogvideox-lora") # Or,
pipe.load_lora_weights("my-awesome-hf-username/my-awesome-lora-name", adapter_name="cogvideox-lora") # If loading from the HF Hub
pipe.to("cuda")
# Assuming lora_alpha=32 and rank=64 for training. If different, set accordingly
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```
## Reduce memory usage
While testing using the diffusers library, all optimizations included in the diffusers library were enabled. This
scheme has not been tested for actual memory usage on devices outside of **NVIDIA A100 / H100** architectures.
Generally, this scheme can be adapted to all **NVIDIA Ampere architecture** and above devices. If optimizations are
disabled, memory consumption will multiply, with peak memory usage being about 3 times the value in the table.
However, speed will increase by about 3-4 times. You can selectively disable some optimizations, including:
```
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```
+ For multi-GPU inference, the `enable_sequential_cpu_offload()` optimization needs to be disabled.
+ Using INT8 models will slow down inference, which is done to accommodate lower-memory GPUs while maintaining minimal
video quality loss, though inference speed will significantly decrease.
+ The CogVideoX-2B model was trained in `FP16` precision, and all CogVideoX-5B models were trained in `BF16` precision.
We recommend using the precision in which the model was trained for inference.
+ [PytorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be
used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This
allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully
compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on
devices with NVIDIA H100 and above, requiring source installation of `torch`, `torchao`, `diffusers`, and `accelerate`
Python packages. CUDA 12.4 is recommended.
+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed
increases by about 10%. Only the `diffusers` version of the model supports quantization.
+ The model only supports English input; other languages can be translated into English for use via large model
refinement.
+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically
uses `Zero 2` optimization. If a specific number of GPUs is marked in the table, that number or more GPUs must be used
for fine-tuning.
| **Attribute** | **CogVideoX-2B** | **CogVideoX-5B** |
| ------------------------------------ | ---------------------------------------------------------------------- | ---------------------------------------------------------------------- |
| **Model Name** | CogVideoX-2B | CogVideoX-5B |
| **Inference Precision** | FP16* (Recommended), BF16, FP32, FP8*, INT8, Not supported INT4 | BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported INT4 |
| **Single GPU Inference VRAM** | FP16: Using diffusers 12.5GB* INT8: Using diffusers with torchao 7.8GB* | BF16: Using diffusers 20.7GB* INT8: Using diffusers with torchao 11.4GB* |
| **Multi GPU Inference VRAM** | FP16: Using diffusers 10GB* | BF16: Using diffusers 15GB* |
| **Inference Speed** | Single A100: ~90 seconds, Single H100: ~45 seconds | Single A100: ~180 seconds, Single H100: ~90 seconds |
| **Fine-tuning Precision** | FP16 | BF16 |
| **Fine-tuning VRAM Consumption** | 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT) | 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT) |
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
```
> [!TIP]
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
@@ -75,12 +75,6 @@ image
![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
<Tip>
By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
</Tip>
## Merge adapters
You can also merge different adapter checkpoints for inference to blend their styles together.
-120
View File
@@ -1,120 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# CogVideoX
CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.
- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.
- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```py
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
pipe = CogVideoXPipeline.from_pretrained(
"THUDM/CogVideoX-2b",
torch_dtype=torch.float16
)
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V",
torch_dtype=torch.bfloat16
)
```
## Text-to-Video
For text-to-video, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
pipe = CogVideoXPipeline.from_pretrained(
"THUDM/CogVideoX-5b",
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
video = pipe(
prompt=prompt,
num_videos_per_prompt=1,
num_inference_steps=50,
num_frames=49,
guidance_scale=6,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_out.gif" alt="generated image of an astronaut in a jungle"/>
</div>
## Image-to-Video
You'll use the [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) checkpoint for this guide.
```py
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
image = load_image(image="cogvideox_rocket.png")
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V",
torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
video = pipe(
prompt=prompt,
image=image,
num_videos_per_prompt=1,
num_inference_steps=50,
num_frames=49,
guidance_scale=6,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
</div>
</div>
@@ -23,59 +23,6 @@ This guide will show you how to generate videos, how to configure video model pa
[Stable Video Diffusions (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), [I2VGen-XL](https://huggingface.co/ali-vilab/i2vgen-xl/), [AnimateDiff](https://huggingface.co/guoyww/animatediff), and [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) are popular models used for video diffusion. Each model is distinct. For example, AnimateDiff inserts a motion modeling module into a frozen text-to-image model to generate personalized animated images, whereas SVD is entirely pretrained from scratch with a three-stage training process to generate short high-quality videos.
[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) is another popular video generation model. The model is a multidimensional transformer that integrates text, time, and space. It employs full attention in the attention module and includes an expert block at the layer level to spatially align text and video.
### CogVideoX
[CogVideoX](../api/pipelines/cogvideox) uses a 3D Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions.
Begin by loading the [`CogVideoXPipeline`] and passing an initial text or image to generate a video.
<Tip>
CogVideoX is available for image-to-video and text-to-video. [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) uses the [`CogVideoXImageToVideoPipeline`] for image-to-video. [THUDM/CogVideoX-5b](https://huggingface.co/THUDM/CogVideoX-5b) and [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) are available for text-to-video with the [`CogVideoXPipeline`].
</Tip>
```py
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
image = load_image(image="cogvideox_rocket.png")
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V",
torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
video = pipe(
prompt=prompt,
image=image,
num_videos_per_prompt=1,
num_inference_steps=50,
num_frames=49,
guidance_scale=6,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
</div>
</div>
### Stable Video Diffusion
[SVD](../api/pipelines/svd) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image. You can learn more details about model, like micro-conditioning, in the [Stable Video Diffusion](../using-diffusers/svd) guide.
@@ -1,353 +0,0 @@
# Advanced diffusion training examples
## Train Dreambooth LoRA with Flux.1 Dev
> [!TIP]
> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject.
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/advanced_diffusion_training` folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Pivotal Tuning (and more)
**Training with text encoder(s)**
Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported.
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
Please keep the following points in mind:
* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results!
* **token initializer** - similar to the original textual inversion work, you can specify a token of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. You can specify a token in `--initializer_token` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_token`.
## Training examples
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./3d_icon"
snapshot_download(
"LinoyTsaban/3d_icon",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
```bash
pip install datasets
```
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
```
--dataset_name=./3d_icon
--caption_column=prompt
```
You can also load a dataset straight from by specifying it's name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
### Example #1: Pivotal tuning
**Now, we can launch training:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
### Example #2: Pivotal tuning with T5
Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--enable_t5_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
### Example #3: Textual Inversion
To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we
can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is
trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train
run.
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--enable_t5_ti\
--train_text_encoder_ti_frac=0.5\
--train_transformer_frac=0\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
### Inference - pivotal tuning
Once training is done, we can perform inference like so:
1. starting with loading the transformer lora weights
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-Flux-LoRA"
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. now we load the pivotal tuning embeddings
💡note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
3. let's generate images
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Inference - pure textual inversion
In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a
`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5)
1. starting with loading the embeddings.
💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-Flux-LoRA"
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
2. let's generate images
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
@@ -1,8 +0,0 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
@@ -1,283 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_text_encoder_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
starts_with_expected_prefix = all(
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_expected_prefix)
def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder_ti
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure embeddings were also saved
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# make sure the state_dict has the correct naming in the parameters.
textual_inversion_state_dict = safetensors.torch.load_file(
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
)
is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
self.assertTrue(is_clip)
# when performing pivotal tuning, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder_ti
--enable_t5_ti
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure embeddings were also saved
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# make sure the state_dict has the correct naming in the parameters.
textual_inversion_state_dict = safetensors.torch.load_file(
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
)
is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
self.assertTrue(is_te)
# when performing pivotal tuning, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
File diff suppressed because it is too large Load Diff
+2 -12
View File
@@ -10,11 +10,6 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de
At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).
> [!NOTE]
> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve.
>
> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers team.
## Data Preparation
The training scripts accepts data in two formats.
@@ -137,8 +132,6 @@ Assuming you are training on 50 videos of a similar concept, we have found 1500-
- 1500 steps on 50 videos would correspond to `30` training epochs
- 4000 steps on 100 videos would correspond to `40` training epochs
The following bash script launches training for text-to-video lora.
```bash
#!/bin/bash
@@ -179,8 +172,6 @@ accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
--report_to wandb
```
For launching image-to-video finetuning instead, run the `train_cogvideox_image_to_video_lora.py` file instead. Additionally, you will have to pass `--validation_images` as paths to initial images corresponding to `--validation_prompts` for I2V validation to work.
To better track our training experiments, we're using the following flags in the command above:
* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
@@ -189,7 +180,6 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` video cropping functionality, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
@@ -206,6 +196,8 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
>
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->
## Inference
Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`.
@@ -234,5 +226,3 @@ prompt = (
frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```
If you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except you must also pass an image as initial condition for generation.
File diff suppressed because it is too large Load Diff
+25 -90
View File
@@ -21,28 +21,27 @@ import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.training_utils import (
cast_training_params,
clear_objs_and_retain_memory,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
@@ -218,12 +217,6 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -423,7 +416,6 @@ class VideoDataset(Dataset):
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
@@ -440,7 +432,6 @@ class VideoDataset(Dataset):
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
@@ -544,38 +535,6 @@ class VideoDataset(Dataset):
return instance_prompts, instance_videos
def _resize_for_rectangle_crop(self, arr):
image_size = self.height, self.width
reshape_mode = self.video_reshape_mode
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
def _preprocess_data(self):
try:
import decord
@@ -586,14 +545,15 @@ class VideoDataset(Dataset):
decord.bridge.set_bridge("torch")
progress_dataset_bar = tqdm(
range(0, len(self.instance_video_paths)),
desc="Loading progress resize and crop videos",
)
videos = []
train_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
]
)
for filename in self.instance_video_paths:
video_reader = decord.VideoReader(uri=filename.as_posix())
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
video_num_frames = len(video_reader)
start_frame = min(self.skip_frames_start, video_num_frames)
@@ -619,16 +579,10 @@ class VideoDataset(Dataset):
assert (selected_num_frames - 1) % 4 == 0
# Training transforms
frames = (frames - 127.5) / 127.5
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
progress_dataset_bar.set_description(
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
)
frames = self._resize_for_rectangle_crop(frames)
videos.append(frames.contiguous()) # [F, C, H, W]
progress_dataset_bar.update(1)
frames = frames.float()
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
progress_dataset_bar.close()
return videos
@@ -743,13 +697,8 @@ def log_validation(
videos = []
for _ in range(args.num_validation_videos):
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
videos.append(image_pil)
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
videos.append(video)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -777,8 +726,7 @@ def log_validation(
}
)
del pipe
free_memory()
clear_objs_and_retain_memory([pipe])
return videos
@@ -922,7 +870,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
)
args.optimizer = "adamw"
if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]):
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
@@ -1211,7 +1159,7 @@ def main(args):
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@@ -1225,7 +1173,6 @@ def main(args):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
@@ -1234,28 +1181,19 @@ def main(args):
id_token=args.id_token,
)
def encode_video(video, bar):
bar.update(1)
def encode_video(video):
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
progress_encode_bar.close()
train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
prompts = [example["instance_prompt"] for example in examples]
videos = torch.cat(videos)
videos = videos.permute(0, 2, 1, 3, 4)
videos = videos.to(memory_format=torch.contiguous_format).float()
return {
@@ -1377,7 +1315,7 @@ def main(args):
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
model_input = batch["videos"].to(dtype=weight_dtype) # [B, F, C, H, W]
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
prompts = batch["prompts"]
# encode prompts
@@ -1456,7 +1394,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1496,6 +1434,7 @@ def main(args):
args.pretrained_model_name_or_path,
transformer=unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
scheduler=scheduler,
revision=args.revision,
variant=args.variant,
@@ -1539,10 +1478,6 @@ def main(args):
transformer_lora_layers=transformer_lora_layers,
)
# Cleanup trained models to save memory
del transformer
free_memory()
# Final test inference
pipe = CogVideoXPipeline.from_pretrained(
args.pretrained_model_name_or_path,
+10 -56
View File
@@ -73,8 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -86,17 +85,17 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
### Flux with CFG
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
Example usage:
```py
from diffusers import DiffusionPipeline
import torch
import torch
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
custom_pipeline="pipeline_flux_with_cfg"
)
pipeline.enable_model_cpu_offload()
@@ -104,10 +103,10 @@ prompt = "a watercolor painting of a unicorn"
negative_prompt = "pink"
img = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
true_cfg=1.5,
guidance_scale=3.5,
prompt=prompt,
negative_prompt=negative_prompt,
true_cfg=1.5,
guidance_scale=3.5,
num_images_per_prompt=1,
generator=torch.manual_seed(0)
).images[0]
@@ -2657,7 +2656,7 @@ image with mask mech_painted.png
<img src=https://github.com/noskill/diffusers/assets/733626/c334466a-67fe-4377-9ff7-f46021b9c224 width="25%" >
result:
result:
<img src=https://github.com/noskill/diffusers/assets/733626/5043fb57-a785-4606-a5ba-a36704f7cb42 width="25%" >
@@ -4325,51 +4324,6 @@ image = pipe(
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
### 🪆Matryoshka Diffusion Models
![🪆Matryoshka Diffusion Models](https://github.com/user-attachments/assets/bf90b53b-48c3-4769-a805-d9dfe4a7c572)
The Abstract of the paper:
>Diffusion models are the _de-facto_ approach for generating high-quality images and videos but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space, or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion (MDM), **a novel framework for high-resolution image and video synthesis**. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a **NestedUNet** architecture where features and parameters for small scale inputs are nested within those of the large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions which leads to significant improvements in optimization for high-resolution generation. We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a **_single pixel-space model_ at resolutions of up to 1024 × 1024 pixels**, demonstrating strong zero shot generalization using the **CC12M dataset, which contains only 12 million images**. Code and pre-trained checkpoints are released at https://github.com/apple/ml-mdm.
- `64×64, nesting_level=0`: 1.719 GiB. With `50` DDIM inference steps:
**64x64**
:-------------------------:
| <img src="https://github.com/user-attachments/assets/9e7bb2cd-45a0-4bd1-adb8-23e283baed39" width="222" height="222" alt="bird_64"> |
- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:
**64x64** | **256x256**
:-------------------------:|:-------------------------:
| <img src="https://github.com/user-attachments/assets/6b724c2e-5e6a-4b63-9b65-c1182cbb67e0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/7dbab2ad-bf40-4a73-ab04-f178347cb7d5" width="222" height="222" alt="256x256"> |
- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
**64x64** | **256x256** | **1024x1024**
:-------------------------:|:-------------------------:|:-------------------------:
| <img src="https://github.com/user-attachments/assets/4a9454e4-e20a-4736-a196-270e2ae796c0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/4a96555d-0fda-4303-82b1-a4d886f770b9" width="222" height="222" alt="256x256"> | <img src="https://github.com/user-attachments/assets/e0239b7a-ab73-4d45-8f3e-b4e6b4b50abe" width="222" height="222" alt="1024x1024"> |
```py
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid
# nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
nesting_level=0,
trust_remote_code=False, # One needs to give permission for this code to run
).to("cuda")
prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
make_image_grid(image, rows=1, cols=len(image))
# pipe.change_nesting_level(<int>) # 0, 1, or 2
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
+10 -17
View File
@@ -898,16 +898,13 @@ class GaussianSmoothing(nn.Module):
Apply gaussian smoothing on a
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
in the input using a depthwise convolution.
Args:
channels (`int` or `sequence`):
Number of channels of the input tensors. The output will have this number of channels as well.
kernel_size (`int` or `sequence`):
Size of the Gaussian kernel.
sigma (`float` or `sequence`):
Standard deviation of the Gaussian kernel.
dim (`int`, *optional*, defaults to `2`):
The number of dimensions of the data. Default is 2 (spatial dimensions).
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=2):
@@ -947,14 +944,10 @@ class GaussianSmoothing(nn.Module):
def forward(self, input):
"""
Apply gaussian filter to input.
Args:
input (`torch.Tensor` of shape `(N, C, H, W)`):
Input to apply Gaussian filter on.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
`torch.Tensor`:
The filtered output tensor with the same shape as the input.
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-430
View File
@@ -1,430 +0,0 @@
# ControlNet training example for FLUX
The `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux).
Training script provided by LibAI, which is an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/).
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
> **Gated access**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in: `huggingface-cli login`
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/controlnet` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
## Custom Datasets
We support dataset formats:
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \` to the script and remove line `--jsonl_for_train` mentioned below.
We also support importing data from jsonl(xxx.jsonl),using `--jsonl_for_train` to enable it, here is a brief example of jsonl files:
```sh
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
```
## Training
Our training examples use two test conditioning images. They can be downloaded by running
```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
we can define the num_layers, num_single_layers, which determines the size of the control(default values are num_layers=4, num_single_layers=10)
```bash
accelerate launch train_controlnet_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--dataset_name=fusing/fill50k \
--conditioning_image_column=conditioning_image \
--image_column=image \
--caption_column=text \
--output_dir="path to save model" \
--mixed_precision="bf16" \
--resolution=512 \
--learning_rate=1e-5 \
--max_train_steps=15000 \
--validation_steps=100 \
--checkpointing_steps=200 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--num_double_layers=4 \
--num_single_layers=0 \
--seed=42 \
--push_to_hub \
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases.
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 80GB A100 GPU.
### Inference
Once training is done, we can perform inference like so:
```python
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=torch.bfloat16
)
# enable memory optimizations
pipe.enable_model_cpu_offload()
control_image = load_image("https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg")resize((1024, 1024))
prompt = "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere"
image = pipe(
prompt,
control_image=control_image,
controlnet_conditioning_scale=0.6,
num_inference_steps=28,
guidance_scale=3.5,
).images[0]
image.save("./output.png")
```
## Apply Deepspeed Zero3
This is an experimental process, I am not sure if it is suitable for everyone, we used this process to successfully train 512 resolution on A100(40g) * 8.
Please modify some of the code in the script.
### 1.Customize zero3 settings
Copy the **accelerate_config_zero3.yaml**,modify `num_processes` according to the number of gpus you want to use:
```bash
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 8
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
### 2.Precompute all inputs (latent, embeddings)
In the train_controlnet_flux.py, We need to pre-calculate all parameters and put them into batches.So we first need to rewrite the `compute_embeddings` function.
```python
def compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True):
### compute text embeddings
prompt_batch = batch[args.caption_column]
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
prompt_batch = captions
prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
prompt_batch, prompt_2=prompt_batch
)
prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
text_ids = text_ids.to(dtype=weight_dtype)
# text_ids [512,3] to [bs,512,3]
text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)
### compute latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
# vae encode
pixel_values = batch["pixel_values"]
pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device)
pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
pixel_latents = _pack_latents(
pixel_latents_tmp,
pixel_values.shape[0],
pixel_latents_tmp.shape[1],
pixel_latents_tmp.shape[2],
pixel_latents_tmp.shape[3],
)
control_values = batch["conditioning_pixel_values"]
control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device)
control_latents = vae.encode(control_values).latent_dist.sample()
control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
control_latents = _pack_latents(
control_latents,
control_values.shape[0],
control_latents.shape[1],
control_latents.shape[2],
control_latents.shape[3],
)
# copied from pipeline_flux_controlnet
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
latent_image_ids = _prepare_latent_image_ids(
batch_size=pixel_latents_tmp.shape[0],
height=pixel_latents_tmp.shape[2],
width=pixel_latents_tmp.shape[3],
device=pixel_values.device,
dtype=pixel_values.dtype,
)
# unet_added_cond_kwargs = {"pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids, "pixel_latents": pixel_latents, "control_latents": control_latents, "latent_image_ids": latent_image_ids}
```
Because we need images to pass through vae, we need to preprocess the images in the dataset first. At the same time, vae requires more gpu memory, so you may need to modify the `batch_size` below
```diff
+train_dataset = prepare_train_dataset(train_dataset, accelerator)
with accelerator.main_process_first():
from datasets.fingerprint import Hasher
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(
- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10
)
del text_encoders, tokenizers
gc.collect()
torch.cuda.empty_cache()
# Then get the training dataset ready to be passed to the dataloader.
-train_dataset = prepare_train_dataset(train_dataset, accelerator)
```
### 3.Redefine the behavior of getting batchsize
Now that we have all the preprocessing done, we need to modify the `collate_fn` function.
```python
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
pixel_latents = torch.stack([torch.tensor(example["pixel_latents"]) for example in examples])
pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float()
control_latents = torch.stack([torch.tensor(example["control_latents"]) for example in examples])
control_latents = control_latents.to(memory_format=torch.contiguous_format).float()
latent_image_ids= torch.stack([torch.tensor(example["latent_image_ids"]) for example in examples])
prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])
return {
"pixel_values": pixel_values,
"conditioning_pixel_values": conditioning_pixel_values,
"pixel_latents": pixel_latents,
"control_latents": control_latents,
"latent_image_ids": latent_image_ids,
"prompt_ids": prompt_ids,
"unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
}
```
Finally, we just need to modify the way of obtaining various parameters during training.
```python
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(flux_controlnet):
# Convert images to latent space
pixel_latents = batch["pixel_latents"].to(dtype=weight_dtype)
control_image = batch["control_latents"].to(dtype=weight_dtype)
latent_image_ids = batch["latent_image_ids"].to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
bsz = pixel_latents.shape[0]
# Sample a random timestep for each image
t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype))
# apply flow matching
noisy_latents = (
1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2])
) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat(
1, pixel_latents.shape[1], pixel_latents.shape[2]
) * noise
guidance_vec = torch.full(
(noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype
)
controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
hidden_states=noisy_latents,
controlnet_cond=control_image,
timestep=t,
guidance=guidance_vec,
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
img_ids=latent_image_ids[0],
return_dict=False,
)
noise_pred = flux_transformer(
hidden_states=noisy_latents,
timestep=t,
guidance=guidance_vec,
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
if controlnet_block_samples is not None
else None,
controlnet_single_block_samples=[
sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
]
if controlnet_single_block_samples is not None
else None,
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
img_ids=latent_image_ids[0],
return_dict=False,
)[0]
```
Congratulations! You have completed all the required code modifications required for deepspeedzero3.
### 4.Training with deepspeedzero3
Start!!!
```bash
export pretrained_model_name_or_path='flux-dev-model-path'
export MODEL_TYPE='train_model_type'
export TRAIN_JSON_FILE="your_json_file"
export CONTROL_TYPE='control_preprocessor_type'
export CAPTION_COLUMN='caption_column'
export CACHE_DIR="/data/train_csr/.cache/huggingface/"
export OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE
# The first step is to use Python to precompute all caches.Replace the first line below with this line. (I am not sure why using acclerate would cause problems.)
CUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \
# The second step is to use the above accelerate config to train
accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux.py \
--pretrained_model_name_or_path=$pretrained_model_name_or_path \
--jsonl_for_train=$TRAIN_JSON_FILE \
--conditioning_image_column=$CONTROL_TYPE \
--image_column=image \
--caption_column=$CAPTION_COLUMN\
--cache_dir=$CACHE_DIR \
--tracker_project_name=$MODEL_TYPE \
--output_dir=$OUTPUT_DIR \
--max_train_steps=500000 \
--mixed_precision bf16 \
--checkpointing_steps=1000 \
--gradient_accumulation_steps=8 \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=1e-5 \
--num_double_layers=4 \
--num_single_layers=0 \
--gradient_checkpointing \
--resume_from_checkpoint="latest" \
# --use_adafactor \ dont use
# --validation_steps=3 \ not support
# --validation_image $VALIDATION_IMAGE \ not support
# --validation_prompt "xxx" \ not support
```
@@ -1,9 +0,0 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
datasets
wandb
SentencePiece
-25
View File
@@ -136,28 +136,3 @@ class ControlNetSD3(ExamplesTestsAccelerate):
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
class ControlNetflux(ExamplesTestsAccelerate):
def test_controlnet_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet_flux.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
--output_dir={tmpdir}
--dataset_name=hf-internal-testing/fill10
--conditioning_image_column=conditioning_image
--image_column=image
--caption_column=text
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--num_double_layers=1
--num_single_layers=1
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
File diff suppressed because it is too large Load Diff
+8 -15
View File
@@ -49,7 +49,11 @@ from diffusers import (
StableDiffusion3ControlNetPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.training_utils import (
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
@@ -170,8 +174,7 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
else:
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
free_memory()
clear_objs_and_retain_memory(pipeline)
if not is_final_validation:
controlnet.to(accelerator.device)
@@ -357,11 +360,6 @@ def parse_args(input_args=None):
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--upcast_vae",
action="store_true",
help="Whether or not to upcast vae to fp32",
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -1099,10 +1097,7 @@ def main(args):
weight_dtype = torch.bfloat16
# Move vae, transformer and text_encoder to device and cast to weight_dtype
if args.upcast_vae:
vae.to(accelerator.device, dtype=torch.float32)
else:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
@@ -1136,9 +1131,7 @@ def main(args):
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
del text_encoder_one, text_encoder_two, text_encoder_three
del tokenizer_one, tokenizer_two, tokenizer_three
free_memory()
clear_objs_and_retain_memory(text_encoders + tokenizers)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
+1 -1
View File
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=4e-4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
@@ -103,39 +103,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
)
self.assertTrue(starts_with_expected_prefix)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
@@ -985,6 +985,7 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
@@ -1006,7 +1007,8 @@ def encode_prompt(
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, pooled_prompt_embeds, text_ids
@@ -1435,8 +1437,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1479,8 +1480,7 @@ def main(args):
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.validation_prompt is None:
del vae
free_memory()
clear_objs_and_retain_memory([vae])
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1817,8 +1817,7 @@ def main(args):
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
# Save the lora layers
accelerator.wait_for_everyone()
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
@@ -140,7 +140,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="other",
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -211,8 +211,7 @@ def log_validation(
}
)
del pipeline
free_memory()
clear_objs_and_retain_memory(objs=[pipeline])
return images
@@ -608,12 +607,6 @@ def parse_args(input_args=None):
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
@@ -634,15 +627,6 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
@@ -1122,8 +1106,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
free_memory()
clear_objs_and_retain_memory(objs=[pipeline])
# Handle the repository creation
if accelerator.is_main_process:
@@ -1409,16 +1392,6 @@ def main(args):
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
@@ -1465,9 +1438,6 @@ def main(args):
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
@@ -1483,9 +1453,9 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del tokenizers, text_encoders
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1512,21 +1482,6 @@ def main(args):
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.validation_prompt is None:
del vae
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1543,6 +1498,7 @@ def main(args):
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
@@ -1649,9 +1605,8 @@ def main(args):
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder_one, text_encoder_two])
with accelerator.accumulate(models_to_accumulate):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -1682,13 +1637,8 @@ def main(args):
)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
@@ -1821,8 +1771,6 @@ def main(args):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1843,18 +1791,17 @@ def main(args):
epoch=epoch,
torch_dtype=weight_dtype,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
clear_objs_and_retain_memory(objs=objs)
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer = transformer.to(torch.float32)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
+13 -10
View File
@@ -15,6 +15,7 @@
import argparse
import copy
import gc
import itertools
import logging
import math
@@ -50,7 +51,7 @@ from diffusers import (
StableDiffusion3Pipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.utils import (
check_min_version,
is_wandb_available,
@@ -118,7 +119,7 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="other",
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
@@ -163,7 +164,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -189,7 +190,8 @@ def log_validation(
)
del pipeline
free_memory()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return images
@@ -1063,7 +1065,8 @@ def main(args):
image.save(image_filename)
del pipeline
free_memory()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1383,7 +1386,9 @@ def main(args):
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1703,9 +1708,6 @@ def main(args):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
text_encoder_three.to(weight_dtype)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1728,7 +1730,8 @@ def main(args):
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()
torch.cuda.empty_cache()
gc.collect()
# Save the lora layers
accelerator.wait_for_everyone()
@@ -747,22 +747,17 @@ def main():
)
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
@@ -787,14 +782,8 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -196,7 +196,7 @@
},
"outputs": [],
"source": [
"!rm -rf dog/.cache"
"!rm -rf dog/.huggingface"
]
},
{
@@ -59,8 +59,6 @@ check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False
DATASET_NAME_MAPPING = {
@@ -542,9 +540,6 @@ def compute_vae_encodings(batch, vae):
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
# There might have slightly performance improvement
# by changing model_input.cpu() to accelerator.gather(model_input)
return {"model_input": model_input.cpu()}
@@ -940,10 +935,7 @@ def main(args):
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
del text_encoders, tokenizers, vae
gc.collect()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.empty_cache()
def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
@@ -1099,7 +1091,8 @@ def main(args):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
return add_time_ids
add_time_ids = torch.cat(
@@ -1268,10 +1261,7 @@ def main(args):
)
del pipeline
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
-242
View File
@@ -1,242 +0,0 @@
"""
Convert a CogView3 checkpoint to the Diffusers format.
This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.
Example usage:
python scripts/convert_cogview3_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"
Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
Default is "bf16" because CogView3 uses bfloat16 for Training.
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
TOKENIZER_MAX_LENGTH = 224
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# diffusers implementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = torch.load(ckpt_path, map_location="cpu")
original_state_dict = original_state_dict["module"]
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
new_state_dict = {}
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
# Convert time_condition_embed
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
"label_emb.0.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
"label_emb.0.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
"label_emb.0.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
"label_emb.0.2.bias"
)
# Convert transformer blocks
for i in range(30):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
return new_state_dict
def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
def main(args):
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
# Original checkpoint data type will be preserved
transformer = transformer.to(dtype=dtype)
if args.vae_checkpoint_path is not None:
vae_config = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": ("DownEncoderBlock2D",) * 4,
"up_block_types": ("UpDecoderBlock2D",) * 4,
"block_out_channels": (128, 512, 1024, 1024),
"layers_per_block": 3,
"act_fn": "silu",
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
"mid_block_add_attention": False,
}
converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 4.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "trailing",
}
)
pipe = CogView3PlusPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
# save some memory used for model loading.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
if __name__ == "__main__":
main(args)
-10
View File
@@ -84,7 +84,6 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
@@ -256,11 +255,9 @@ else:
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
@@ -331,7 +328,6 @@ else:
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPAGInpaintPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
@@ -347,7 +343,6 @@ else:
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
@@ -562,7 +557,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
@@ -712,11 +706,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
@@ -786,7 +778,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
@@ -802,7 +793,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
+53 -256
View File
@@ -38,44 +38,16 @@ PipelineImageInput = Union[
PipelineDepthInput = PipelineImageInput
def is_valid_image(image) -> bool:
r"""
Checks if the input is a valid image.
A valid image can be:
- A `PIL.Image.Image`.
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
Returns:
`bool`:
`True` if the input is a valid image, `False` otherwise.
"""
def is_valid_image(image):
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
def is_valid_image_imagelist(images):
r"""
Checks if the input is a valid image or list of images.
The input can be one of the following formats:
- A 4D tensor or numpy array (batch of images).
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
`torch.Tensor`.
- A list of valid images.
Args:
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
images.
Returns:
`bool`:
`True` if the input is valid, `False` otherwise.
"""
# check if the image input is one of the supported formats for image and image list:
# it can be either one of below 3
# (1) a 4d pytorch tensor or numpy array,
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
# (3) a list of valid image
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
@@ -131,16 +103,8 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
"""
Convert a numpy image or a batch of images to a PIL image.
Args:
images (`np.ndarray`):
The image array to convert to PIL format.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images.
"""
if images.ndim == 3:
images = images[None, ...]
@@ -155,16 +119,8 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The PIL image or list of images to convert to NumPy format.
Returns:
`np.ndarray`:
A NumPy array representation of the images.
"""
if not isinstance(images, list):
images = [images]
@@ -175,16 +131,8 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
r"""
"""
Convert a NumPy image to a PyTorch tensor.
Args:
images (`np.ndarray`):
The NumPy image array to convert to PyTorch format.
Returns:
`torch.Tensor`:
A PyTorch tensor representation of the images.
"""
if images.ndim == 3:
images = images[..., None]
@@ -194,62 +142,30 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
r"""
"""
Convert a PyTorch tensor to a NumPy image.
Args:
images (`torch.Tensor`):
The PyTorch tensor to convert to NumPy format.
Returns:
`np.ndarray`:
A NumPy array representation of the images.
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
@staticmethod
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
"""
Normalize an image array to [-1,1].
Args:
images (`np.ndarray` or `torch.Tensor`):
The image array to normalize.
Returns:
`np.ndarray` or `torch.Tensor`:
The normalized image array.
"""
return 2.0 * images - 1.0
@staticmethod
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
"""
Denormalize an image array to [0,1].
Args:
images (`np.ndarray` or `torch.Tensor`):
The image array to denormalize.
Returns:
`np.ndarray` or `torch.Tensor`:
The denormalized image array.
"""
return (images / 2 + 0.5).clamp(0, 1)
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
r"""
"""
Converts a PIL image to RGB format.
Args:
image (`PIL.Image.Image`):
The PIL image to convert to RGB.
Returns:
`PIL.Image.Image`:
The RGB-converted PIL image.
"""
image = image.convert("RGB")
@@ -257,16 +173,8 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
r"""
Converts a given PIL image to grayscale.
Args:
image (`PIL.Image.Image`):
The input image to convert.
Returns:
`PIL.Image.Image`:
The image converted to grayscale.
"""
Converts a PIL image to grayscale format.
"""
image = image.convert("L")
@@ -274,16 +182,8 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
r"""
"""
Applies Gaussian blur to an image.
Args:
image (`PIL.Image.Image`):
The PIL image to convert to grayscale.
Returns:
`PIL.Image.Image`:
The grayscale-converted PIL image.
"""
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
@@ -291,7 +191,7 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
r"""
"""
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
processing are 512x512, the region will be expanded to 128x128.
@@ -385,21 +285,14 @@ class VaeImageProcessor(ConfigMixin):
width: int,
height: int,
) -> PIL.Image.Image:
r"""
"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling empty with data from image.
Args:
image (`PIL.Image.Image`):
The image to resize and fill.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and filled image.
image: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
"""
ratio = width / height
@@ -437,21 +330,14 @@ class VaeImageProcessor(ConfigMixin):
width: int,
height: int,
) -> PIL.Image.Image:
r"""
"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, cropping the excess.
Args:
image (`PIL.Image.Image`):
The image to resize and crop.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and cropped image.
image: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
"""
ratio = width / height
src_ratio = image.width / image.height
@@ -543,23 +429,19 @@ class VaeImageProcessor(ConfigMixin):
height: Optional[int] = None,
width: Optional[int] = None,
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
"""
This function return the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor`.
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
have shape `[batch, channel, height, width]`.
height (`int`, *optional*, defaults to `None`):
The height in preprocessed image. If `None`, will use the height of `image` input.
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed. If `None`, will use the width of the `image` input.
"""
if height is None:
@@ -596,13 +478,13 @@ class VaeImageProcessor(ConfigMixin):
Preprocess the image input.
Args:
image (`PipelineImageInput`):
image (`pipeline_image_input`):
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
supported formats.
height (`int`, *optional*):
height (`int`, *optional*, defaults to `None`):
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
height.
width (`int`, *optional*):
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
@@ -614,10 +496,6 @@ class VaeImageProcessor(ConfigMixin):
supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
`torch.Tensor`:
The preprocessed image.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -777,22 +655,8 @@ class VaeImageProcessor(ConfigMixin):
image: PIL.Image.Image,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
r"""
Applies an overlay of the mask and the inpainted image on the original image.
Args:
mask (`PIL.Image.Image`):
The mask image that highlights regions to overlay.
init_image (`PIL.Image.Image`):
The original image to which the overlay is applied.
image (`PIL.Image.Image`):
The image to overlay onto the original.
crop_coords (`Tuple[int, int, int, int]`, *optional*):
Coordinates to crop the image. If provided, the image will be cropped accordingly.
Returns:
`PIL.Image.Image`:
The final image with the overlay applied.
"""
overlay the inpaint output to the original image
"""
width, height = image.width, image.height
@@ -849,16 +713,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy image or a batch of images to a list of PIL images.
Args:
images (`np.ndarray`):
The input NumPy array of images, which can be a single image or a batch.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy array.
"""
Convert a NumPy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
@@ -873,16 +729,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
The input image or list of images to be converted.
Returns:
`np.ndarray`:
A NumPy array of the converted images.
"""
if not isinstance(images, list):
images = [images]
@@ -893,30 +741,18 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
@staticmethod
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
"""
Args:
image (`Union[np.ndarray, torch.Tensor]`):
The RGB-like depth image to convert.
image: RGB-like depth image
Returns: depth map
Returns:
`Union[np.ndarray, torch.Tensor]`:
The corresponding depth map.
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy depth image or a batch of images to a list of PIL images.
Args:
images (`np.ndarray`):
The input NumPy array of depth images, which can be a single image or a batch.
Returns:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy depth images.
"""
Convert a NumPy depth image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
@@ -997,24 +833,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
width: Optional[int] = None,
target_res: Optional[int] = None,
) -> torch.Tensor:
r"""
Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
Args:
rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
The RGB input image, which can be a single image or a batch.
depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
The depth input image, which can be a single image or a batch.
height (`Optional[int]`, *optional*, defaults to `None`):
The desired height of the processed image. If `None`, defaults to the height of the input image.
width (`Optional[int]`, *optional*, defaults to `None`):
The desired width of the processed image. If `None`, defaults to the width of the input image.
target_res (`Optional[int]`, *optional*, defaults to `None`):
Target resolution for resizing the images. If specified, overrides height and width.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing the processed RGB and depth images as PyTorch tensors.
"""
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -1252,17 +1072,7 @@ class PixArtImageProcessor(VaeImageProcessor):
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
r"""
Returns the binned height and width based on the aspect ratio.
Args:
height (`int`): The height of the image.
width (`int`): The width of the image.
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
Returns:
`Tuple[int, int]`: The closest binned height and width.
"""
"""Returns binned height and width."""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
default_hw = ratios[closest_ratio]
@@ -1270,19 +1080,6 @@ class PixArtImageProcessor(VaeImageProcessor):
@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
r"""
Resizes and crops a tensor of images to the specified dimensions.
Args:
samples (`torch.Tensor`):
A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
and W is the width.
new_width (`int`): The desired width of the output images.
new_height (`int`): The desired height of the output images.
Returns:
`torch.Tensor`: A tensor containing the resized and cropped images.
"""
orig_height, orig_width = samples.shape[2], samples.shape[3]
# Check if resizing is needed
+4 -10
View File
@@ -532,19 +532,13 @@ class LoraBaseMixin:
)
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
# eg ["adapter1", "adapter2"]
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
missing_adapters = set(adapter_names) - all_adapters
if len(missing_adapters) > 0:
raise ValueError(
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
)
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
all_adapters = {
adapter for adapters in list_adapters.values() for adapter in adapters
} # eg ["adapter1", "adapter2"]
invert_list_adapters = {
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
for adapter in all_adapters
}
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
+3 -40
View File
@@ -516,47 +516,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue
lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
if len(sds_sd) > 0:
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict
return ait_sd
return _convert_sd_scripts_to_ai_toolkit(state_dict)
@@ -632,7 +595,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
new_key += ".lora_B.weight"
# Handle single_blocks
elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"
+46 -348
View File
@@ -25,11 +25,8 @@ from ..utils import (
deprecate,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
scale_lora_layers,
)
@@ -42,17 +39,6 @@ from .lora_conversion_utils import (
)
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
if (
is_peft_available()
and is_peft_version(">=", "0.13.1")
and is_transformers_available()
and is_transformers_version(">", "4.45.2")
):
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
if is_transformers_available():
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
@@ -97,24 +83,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -122,7 +99,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -132,7 +109,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -143,7 +119,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
@@ -236,11 +211,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
@@ -262,9 +232,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
return state_dict, network_alphas
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -282,16 +250,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
@@ -301,11 +263,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
@classmethod
@@ -318,7 +276,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -341,25 +298,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -430,7 +372,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
@@ -601,19 +542,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
@@ -628,18 +562,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
unet_config=self.unet.config,
**kwargs,
)
is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -651,7 +579,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -664,7 +591,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
@@ -758,11 +684,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
@@ -785,9 +706,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -805,16 +724,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
@@ -824,11 +737,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
@classmethod
@@ -842,7 +751,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -865,25 +773,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -954,7 +847,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
@@ -1197,12 +1089,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
def load_lora_weights(
@@ -1223,22 +1109,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1246,7 +1125,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -1255,7 +1134,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1268,7 +1146,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -1281,13 +1158,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1301,13 +1175,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
@@ -1351,37 +1219,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
@@ -1401,7 +1249,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1424,25 +1271,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1513,7 +1345,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
@@ -1756,13 +1587,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
@@ -1819,17 +1646,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1839,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -1849,7 +1669,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1862,13 +1681,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1886,13 +1702,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
@@ -1941,37 +1751,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
@@ -1991,7 +1781,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2014,25 +1803,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2103,7 +1877,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
@@ -2311,30 +2084,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
@@ -2354,7 +2111,6 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2377,25 +2133,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2466,7 +2207,6 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
@@ -2634,12 +2374,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
def load_lora_weights(
@@ -2655,22 +2389,15 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -2678,7 +2405,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -2687,14 +2414,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2708,13 +2432,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
@@ -2758,37 +2476,17 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
+1 -1
View File
@@ -1513,7 +1513,7 @@ def _legacy_load_scheduler(
)
deprecate("prediction_type", "1.0.0", deprecation_message)
scheduler_config = SCHEDULER_DEFAULT_CONFIG
scheduler_config = copy.deepcopy(SCHEDULER_DEFAULT_CONFIG)
model_type = infer_diffusers_model_type(checkpoint=checkpoint)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
+8 -39
View File
@@ -115,9 +115,6 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Example:
@@ -145,14 +142,8 @@ class UNet2DConditionLoadersMixin:
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
@@ -218,7 +209,6 @@ class UNet2DConditionLoadersMixin:
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(
@@ -278,9 +268,7 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(
self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
):
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -347,37 +335,18 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
return is_model_cpu_offload, is_sequential_cpu_offload
-2
View File
@@ -54,7 +54,6 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -99,7 +98,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
+53 -53
View File
@@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin):
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
user-assigned weighting.
This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
or saving.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Args:
Parameters:
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
A list of `T2IAdapter` model instances.
"""
@@ -77,13 +77,11 @@ class MultiAdapter(ModelMixin):
r"""
Args:
xs (`torch.Tensor`):
A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
`num_adapter` * number of channel per image.
(batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
`channel` should equal to `num_adapter` * "number of channel of image".
adapter_weights (`List[float]`, *optional*, defaults to None):
A list of floats representing the weights which will be multiplied by each adapter's output before
summing them together. If `None`, equal weights will be used for all adapters.
List of floats representing the weight which will be multiply to each adapter's output before adding
them together.
"""
if adapter_weights is None:
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
@@ -111,24 +109,24 @@ class MultiAdapter(ModelMixin):
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
Args:
Arguments:
save_directory (`str` or `os.PathLike`):
The directory where the model will be saved. If the directory does not exist, it will be created.
is_main_process (`bool`, optional, defaults=True):
Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
for the main process to avoid race conditions.
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
save_function (`Callable`):
Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
`torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment
variable.
safe_serialization (`bool`, optional, defaults=True):
If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
idx = 0
model_path_to_save = save_directory
@@ -147,17 +145,19 @@ class MultiAdapter(ModelMixin):
@classmethod
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, set it back to training mode using `model.train()`.
the model, you should first set it back in training mode with `model.train()`.
Warnings:
*Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
Args:
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
@@ -175,20 +175,20 @@ class MultiAdapter(ModelMixin):
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
available for each GPU and the available CPU RAM if unset.
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
be ignored when using `from_flax`.
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is
installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`,
`safetensors` is not used.
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
"""
idx = 0
adapters = []
@@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin):
and
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
downloading or saving.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Args:
in_channels (`int`, *optional*, defaults to `3`):
The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
image.
Parameters:
in_channels (`int`, *optional*, defaults to 3):
Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
image as *control image*.
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
determines the number of downsample blocks in the adapter.
num_res_blocks (`int`, *optional*, defaults to `2`):
The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
also determine the number of downsample blocks in the Adapter.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of ResNet blocks in each downsample block.
downscale_factor (`int`, *optional*, defaults to `8`):
downscale_factor (`int`, *optional*, defaults to 8):
A factor that determines the total downscale factor of the Adapter.
adapter_type (`str`, *optional*, defaults to `full_adapter`):
Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
"""
@register_to_config
@@ -393,7 +393,7 @@ class AdapterBlock(nn.Module):
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
`FullAdapterXL` models.
Args:
Parameters:
in_channels (`int`):
Number of channels of AdapterBlock's input.
out_channels (`int`):
@@ -401,7 +401,7 @@ class AdapterBlock(nn.Module):
num_res_blocks (`int`):
Number of ResNet blocks in the AdapterBlock.
down (`bool`, *optional*, defaults to `False`):
If `True`, perform downsampling on AdapterBlock's input.
Whether to perform downsampling on AdapterBlock's input.
"""
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module):
r"""
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
Args:
Parameters:
channels (`int`):
Number of channels of AdapterResnetBlock's input and output.
"""
@@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module):
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
`LightAdapter` model.
Args:
Parameters:
in_channels (`int`):
Number of channels of LightAdapterBlock's input.
out_channels (`int`):
@@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module):
num_res_blocks (`int`):
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
down (`bool`, *optional*, defaults to `False`):
If `True`, perform downsampling on LightAdapterBlock's input.
Whether to perform downsampling on LightAdapterBlock's input.
"""
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module):
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
architecture than `AdapterResnetBlock`.
Args:
Parameters:
channels (`int`):
Number of channels of LightAdapterResnetBlock's input and output.
"""
+2 -3
View File
@@ -122,7 +122,6 @@ class Attention(nn.Module):
out_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
):
super().__init__()
@@ -180,8 +179,8 @@ class Attention(nn.Module):
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
@@ -18,7 +18,6 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -246,18 +245,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.set_attn_processor(processor)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
if self.quant_conv is not None:
enc = self.quant_conv(enc)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -274,13 +261,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
h = self.encoder(x)
posterior = DiagonalGaussianDistribution(h)
if self.quant_conv is not None:
moments = self.quant_conv(h)
else:
moments = h
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
@@ -342,54 +337,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
@@ -409,13 +356,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
deprecation_message = (
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
)
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -41,9 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
memory_count = (
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
)
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
# Set to 2GB, suitable for CuDNN
if memory_count > 2:
@@ -117,24 +115,34 @@ class CogVideoXCausalConv3d(nn.Module):
dilation=dilation,
)
def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
def _clear_fake_context_parallel_cache(self):
del self.conv_cache
self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
return output
class CogVideoXSpatialNorm3D(nn.Module):
@@ -164,12 +172,7 @@ class CogVideoXSpatialNorm3D(nn.Module):
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -180,12 +183,9 @@ class CogVideoXSpatialNorm3D(nn.Module):
else:
zq = F.interpolate(zq, size=f.shape[-3:])
conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
norm_f = self.norm_layer(f)
new_f = norm_f * conv_y + conv_b
return new_f, new_conv_cache
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class CogVideoXResnetBlock3D(nn.Module):
@@ -236,7 +236,6 @@ class CogVideoXResnetBlock3D(nn.Module):
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
self.use_conv_shortcut = conv_shortcut
self.spatial_norm_dim = spatial_norm_dim
if spatial_norm_dim is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -280,43 +279,34 @@ class CogVideoXResnetBlock3D(nn.Module):
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states = inputs
if zq is not None:
hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
hidden_states = self.norm1(hidden_states, zq)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
hidden_states = self.conv1(hidden_states)
if temb is not None:
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is not None:
hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
hidden_states = self.norm2(hidden_states, zq)
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
inputs, conv_cache=conv_cache.get("conv_shortcut")
)
else:
inputs = self.conv_shortcut(inputs)
inputs = self.conv_shortcut(inputs)
hidden_states = hidden_states + inputs
return hidden_states, new_conv_cache
return hidden_states
class CogVideoXDownBlock3D(nn.Module):
@@ -402,16 +392,8 @@ class CogVideoXDownBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXDownBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
@@ -420,23 +402,17 @@ class CogVideoXDownBlock3D(nn.Module):
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
zq,
conv_cache=conv_cache.get(conv_cache_key),
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
)
hidden_states = resnet(hidden_states, temb, zq)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states, new_conv_cache
return hidden_states
class CogVideoXMidBlock3D(nn.Module):
@@ -504,16 +480,8 @@ class CogVideoXMidBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXMidBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
@@ -522,15 +490,13 @@ class CogVideoXMidBlock3D(nn.Module):
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
)
hidden_states = resnet(hidden_states, temb, zq)
return hidden_states, new_conv_cache
return hidden_states
class CogVideoXUpBlock3D(nn.Module):
@@ -618,16 +584,9 @@ class CogVideoXUpBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXUpBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
@@ -636,23 +595,17 @@ class CogVideoXUpBlock3D(nn.Module):
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
zq,
conv_cache=conv_cache.get(conv_cache_key),
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
)
hidden_states = resnet(hidden_states, temb, zq)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states, new_conv_cache
return hidden_states
class CogVideoXEncoder3D(nn.Module):
@@ -752,18 +705,9 @@ class CogVideoXEncoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXEncoder3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
@@ -774,44 +718,28 @@ class CogVideoXEncoder3D(nn.Module):
return custom_forward
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states,
temb,
None,
conv_cache=conv_cache.get(conv_cache_key),
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), hidden_states, temb, None
)
# 2. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
temb,
None,
conv_cache=conv_cache.get("mid_block"),
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, None
)
else:
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = down_block(
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states, temb, None)
# 2. Mid
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
)
hidden_states = self.mid_block(hidden_states, temb, None)
# 3. Post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
return hidden_states, new_conv_cache
hidden_states = self.conv_out(hidden_states)
return hidden_states
class CogVideoXDecoder3D(nn.Module):
@@ -918,18 +846,9 @@ class CogVideoXDecoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXDecoder3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
@@ -940,45 +859,28 @@ class CogVideoXDecoder3D(nn.Module):
return custom_forward
# 1. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
temb,
sample,
conv_cache=conv_cache.get("mid_block"),
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, sample
)
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states,
temb,
sample,
conv_cache=conv_cache.get(conv_cache_key),
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), hidden_states, temb, sample
)
else:
# 1. Mid
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
)
hidden_states = self.mid_block(hidden_states, temb, sample)
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = up_block(
hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states, temb, sample)
# 3. Post-process
hidden_states, new_conv_cache["norm_out"] = self.norm_out(
hidden_states, sample, conv_cache=conv_cache.get("norm_out")
)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
return hidden_states, new_conv_cache
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1117,6 +1019,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -1182,22 +1090,21 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
enc = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
x_intermediate = self.encoder(x_intermediate)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
self._clear_fake_context_parallel_cache()
enc = torch.cat(enc, dim=2)
return enc
@apply_forward_hook
@@ -1235,10 +1142,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
num_batches = num_frames // frame_batch_size
dec = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1246,9 +1151,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)
self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)
if not return_dict:
@@ -1331,11 +1237,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1347,11 +1250,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
tile = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1411,10 +1314,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
num_batches = num_frames // frame_batch_size
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1428,9 +1329,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
+13 -32
View File
@@ -23,7 +23,7 @@ from ..loaders import PeftAdapterMixin
from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -55,7 +55,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
super().__init__()
self.out_channels = in_channels
@@ -107,14 +106,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
if conditioning_embedding_channels is not None:
self.input_hint_block = ControlNetConditioningEmbedding(
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
)
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
else:
self.input_hint_block = None
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@@ -277,16 +269,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
controlnet_cond = self.input_hint_block(controlnet_cond)
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
height = height_pw // self.config.patch_size
width = width_pw // self.config.patch_size
controlnet_cond = controlnet_cond.reshape(
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
@@ -520,17 +502,16 @@ class FluxMultiControlNetModel(ModelMixin):
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
return control_block_samples, control_single_block_samples
-87
View File
@@ -442,60 +442,6 @@ class CogVideoXPatchEmbed(nn.Module):
return embeds
class CogView3PlusPatchEmbed(nn.Module):
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
text_hidden_size: int = 4096,
pos_embed_max_size: int = 128,
):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.text_hidden_size = text_hidden_size
self.pos_embed_max_size = pos_embed_max_size
# Linear projection for image patches
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
if height % self.patch_size != 0 or width % self.patch_size != 0:
raise ValueError("Height and width must be divisible by patch size")
height = height // self.patch_size
width = width // self.patch_size
hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
# Project the patches
hidden_states = self.proj(hidden_states)
encoder_hidden_states = self.text_proj(encoder_hidden_states)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# Calculate text_length
text_length = encoder_hidden_states.shape[1]
image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
text_pos_embed = torch.zeros(
(text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
)
pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
return (hidden_states + pos_embed).to(hidden_states.dtype)
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -1134,39 +1080,6 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
return conditioning
class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(
self,
timestep: torch.Tensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
hidden_dtype: torch.dtype,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
# (B, 3 * condition_dim)
condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
conditioning = timesteps_emb + condition_emb
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+9 -23
View File
@@ -29,21 +29,11 @@ def get_sinusoidal_embeddings(
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps (`jnp.ndarray` of shape `(N,)`):
A 1-D array of N indices, one per batch element. These may be fractional.
embedding_dim (`int`):
The number of output channels.
freq_shift (`float`, *optional*, defaults to `1`):
Shift applied to the frequency scaling of the embeddings.
min_timescale (`float`, *optional*, defaults to `1`):
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
max_timescale (`float`, *optional*, defaults to `1.0e4`):
The largest time unit used in the sinusoidal calculation.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the order of sinusoidal components to cosine first.
scale (`float`, *optional*, defaults to `1.0`):
A scaling factor applied to the positional embeddings.
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
"""
@@ -71,9 +61,9 @@ class FlaxTimestepEmbedding(nn.Module):
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The data type for the embedding parameters.
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
time_embed_dim: int = 32
@@ -93,11 +83,7 @@ class FlaxTimesteps(nn.Module):
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sinusoidal function from sine to cosine.
freq_shift (`float`, *optional*, defaults to `1`):
Frequency shift applied to the sinusoidal embeddings.
Time step embedding dimension
"""
dim: int = 32
@@ -31,7 +31,6 @@ from ..utils import (
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_version,
logging,
@@ -229,67 +228,3 @@ def _fetch_index_file(
index_file = None
return index_file
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file.split(".")
split_index = -3 if ".cache" in index_file else -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file = ".".join(splits)
if os.path.exists(index_file):
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
index_file = Path(index_file)
else:
index_file = None
else:
if variant is not None:
index_file_in_repo = Path(
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file_in_repo.split(".")
split_index = -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file_in_repo = ".".join(splits)
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
except (EntryNotFoundError, EnvironmentError):
index_file = None
return index_file
+20 -24
View File
@@ -54,7 +54,6 @@ from ..utils.hub_utils import (
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -310,9 +309,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
else:
raise ValueError(f"Invalid {weights_name} provided.")
os.makedirs(save_directory, exist_ok=True)
@@ -623,26 +624,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file_kwargs = {
"is_local": is_local,
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"subfolder": subfolder or "",
"use_safetensors": use_safetensors,
"cache_dir": cache_dir,
"variant": variant,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
if index_file is not None and index_file.is_file():
is_sharded = True
-45
View File
@@ -355,51 +355,6 @@ class LuminaLayerNormContinuous(nn.Module):
return x
class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, dim: int):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
c_shift_msa,
c_scale_msa,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = emb.chunk(12, dim=1)
normed_x = self.norm_x(x)
normed_context = self.norm_c(context)
x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
@@ -14,7 +14,6 @@ if is_torch_available():
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -1,386 +0,0 @@
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogView3PlusTransformerBlock(nn.Module):
r"""
Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
"""
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
):
super().__init__()
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-6,
processor=CogVideoXAttnProcessor2_0(),
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, emb)
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
r"""
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
Diffusion](https://huggingface.co/papers/2403.05121).
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `40`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
The base resolution of input latents. If height/width is not provided during generation, this value is used
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
out_channels: int = 16,
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
):
super().__init__()
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
# CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
self.pooled_projection_dim = 3 * 2 * condition_dim
self.patch_embed = CogView3PlusPatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
patch_size=patch_size,
text_hidden_size=text_embed_dim,
pos_embed_max_size=pos_embed_max_size,
)
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=self.pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
CogView3PlusTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
eps=1e-6,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor`):
Input `hidden_states` of shape `(batch size, channel, height, width)`.
encoder_hidden_states (`torch.Tensor`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
`(batch_size, sequence_len, text_embed_dim)`
timestep (`torch.LongTensor`):
Used to indicate denoising step.
original_size (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crop_coords (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
The denoised latents using provided inputs as conditioning.
"""
height, width = hidden_states.shape[-2:]
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = self.patch_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
emb=emb,
)
hidden_states = self.norm_out(hidden_states, emb)
hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
)
hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -83,16 +83,14 @@ class FluxSingleTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
joint_attention_kwargs=None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -163,20 +161,18 @@ class FluxTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
joint_attention_kwargs=None,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -402,7 +398,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
@@ -502,20 +497,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -545,7 +533,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
+6 -6
View File
@@ -19,7 +19,6 @@ import torch.nn as nn
import torch.nn.functional as F
from ..utils import deprecate
from ..utils.import_utils import is_torch_version
from .normalization import RMSNorm
@@ -152,10 +151,11 @@ class Upsample2D(nn.Module):
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
# https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@ class Upsample2D(nn.Module):
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# Cast back to original dtype
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+1 -13
View File
@@ -144,9 +144,7 @@ else:
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -160,13 +158,11 @@ else:
)
_import_structure["pag"].extend(
[
"StableDiffusionControlNetPAGInpaintPipeline",
"AnimateDiffPAGPipeline",
"KolorsPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
@@ -471,13 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -576,9 +566,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KolorsPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
-6
View File
@@ -20,7 +20,6 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -62,9 +61,7 @@ from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -120,7 +117,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
]
)
@@ -134,7 +130,6 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
@@ -153,7 +148,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("kandinsky", KandinskyInpaintCombinedPipeline),
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
@@ -36,7 +35,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
@@ -1,794 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import CogVideoXLoraLoaderMixin
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_video, load_video
>>> pipe = CogVideoXFunControlPipeline.from_pretrained(
... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> pipe.to("cuda")
>>> control_video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
... )
>>> prompt = (
... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
... "moons, but the remainder of the scene is mostly realistic."
... )
>>> video = pipe(prompt=prompt, control_video=control_video).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for controlled text-to-video generation using CogVideoX Fun.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366
def prepare_control_latents(
self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if mask is not None:
masks = []
for i in range(mask.size(0)):
current_mask = mask[i].unsqueeze(0)
current_mask = self.vae.encode(current_mask)[0]
current_mask = current_mask.mode()
masks.append(current_mask)
mask = torch.cat(masks, dim=0)
mask = mask * self.vae.config.scaling_factor
if masked_image is not None:
mask_pixel_values = []
for i in range(masked_image.size(0)):
mask_pixel_value = masked_image[i].unsqueeze(0)
mask_pixel_value = self.vae.encode(mask_pixel_value)[0]
mask_pixel_value = mask_pixel_value.mode()
mask_pixel_values.append(mask_pixel_value)
masked_image_latents = torch.cat(mask_pixel_values, dim=0)
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
else:
masked_image_latents = None
return mask, masked_image_latents
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae_scaling_factor_image * latents
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
control_video=None,
control_video_latents=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if control_video is not None and control_video_latents is not None:
raise ValueError(
"Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters."
)
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
control_video: Optional[List[Image.Image]] = None,
height: int = 480,
width: int = 720,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
control_video_latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
control_video (`List[PIL.Image.Image]`):
The control video to condition the generation on. Must be a list of images/frames of the video. If not
provided, `control_video_latents` must be provided.
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 6.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
control_video_latents (`torch.Tensor`, *optional*):
Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
controlled video generation. If not provided, `control_video` must be provided.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
control_video,
control_video_latents,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if control_video is not None and isinstance(control_video[0], Image.Image):
control_video = [control_video]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels // 2
num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if control_video_latents is None:
control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)
_, control_video_latents = self.prepare_control_latents(None, control_video)
control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_control_input = (
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
)
latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
@@ -15,7 +15,7 @@
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import PIL
import torch
@@ -23,7 +23,6 @@ from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import CogVideoXLoraLoaderMixin
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -153,7 +152,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXImageToVideoPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-video generation using CogVideoX.
@@ -547,10 +546,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@@ -577,7 +572,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -641,10 +635,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -689,7 +679,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default call parameters
@@ -779,7 +768,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
@@ -1,47 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_cogview3plus import CogView3PlusPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
@@ -1,674 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView3PipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogView3PlusPipeline
>>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
>>> image.save("output.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogView3PlusPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using CogView3Plus.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogView3Plus uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogView3PlusTransformer2DModel`]):
A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: CogView3PlusTransformer2DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 224,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, defaults to `224`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt is None:
negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 5.0,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 224,
) -> Union[CogView3PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. If not provided, it is set to 1024.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. If not provided it is set to 1024.
num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `224`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples:
Returns:
[`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`:
[`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
if self.do_classifier_free_guidance:
original_size = torch.cat([original_size, original_size])
target_size = torch.cat([target_size, target_size])
crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return CogView3PipelineOutput(images=image)
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class CogView3PipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@@ -893,10 +893,6 @@ class StableDiffusionControlNetPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1093,7 +1089,6 @@ class StableDiffusionControlNetPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1240,9 +1235,6 @@ class StableDiffusionControlNetPipeline(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -891,10 +891,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1085,7 +1081,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1216,9 +1211,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -976,10 +976,6 @@ class StableDiffusionControlNetInpaintPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1195,7 +1191,6 @@ class StableDiffusionControlNetInpaintPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1380,9 +1375,6 @@ class StableDiffusionControlNetInpaintPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1145,10 +1145,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1431,7 +1427,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1700,9 +1695,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -990,10 +990,6 @@ class StableDiffusionXLControlNetPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1249,7 +1245,6 @@ class StableDiffusionXLControlNetPipeline(
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1447,9 +1442,6 @@ class StableDiffusionXLControlNetPipeline(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -1070,10 +1070,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1342,7 +1338,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1515,9 +1510,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -225,8 +225,6 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
requires_safety_checker: bool = True,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = HunyuanDiT2DMultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
@@ -192,8 +192,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
],
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = SD3MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
@@ -251,9 +251,6 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
self.patch_size = (
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -580,14 +577,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if (
height % (self.vae_scale_factor * self.patch_size) != 0
or width % (self.vae_scale_factor * self.patch_size) != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -9,17 +9,16 @@ from ...utils import BaseOutput
@dataclass
class IFPipelineOutput(BaseOutput):
r"""
Output class for Stable Diffusion pipelines.
"""
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`):
Output class for Stable Diffusion pipelines.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_detected (`List[bool]`):
nsfw_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content or a watermark. `None` if safety checking could not be performed.
watermark_detected (`List[bool]`):
watermark_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
checking could not be performed.
"""
+2 -13
View File
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -137,12 +137,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class FluxPipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
r"""
The Flux pipeline for text-to-image generation.
@@ -217,9 +212,6 @@ class FluxPipeline(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -263,9 +255,6 @@ class FluxPipeline(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -25,7 +25,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -72,9 +72,7 @@ EXAMPLE_DOC_STRING = """
>>> image = pipe(
... prompt,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... controlnet_conditioning_scale=0.6,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
@@ -202,8 +200,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
],
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = FluxMultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
@@ -238,9 +234,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -284,9 +277,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -582,8 +572,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
@@ -626,10 +614,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
@@ -690,17 +674,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -760,35 +733,29 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
)
height, width = control_image.shape[-2:]
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# Here we ensure that `control_mode` has the same length as the control_image.
# set control mode
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
control_mode = control_mode.reshape([-1, 1])
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
for control_image_ in control_image:
control_image_ = self.prepare_image(
image=control_image_,
width=width,
@@ -800,40 +767,34 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
)
height, width = control_image_.shape[-2:]
if self.controlnet.nets[0].input_hint_block is None:
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -870,16 +831,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
# 7. Denoising loop
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -888,28 +840,17 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
@@ -937,7 +878,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
# compute the previous noisy sample x_t -> x_t-1
@@ -11,7 +11,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -69,9 +69,7 @@ EXAMPLE_DOC_STRING = """
... prompt,
... image=init_image,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... controlnet_conditioning_scale=0.6,
... strength=0.7,
... num_inference_steps=2,
... guidance_scale=3.5,
@@ -214,8 +212,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
],
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = FluxMultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
@@ -251,9 +247,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -298,9 +291,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -641,8 +631,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
@@ -722,17 +710,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
self.check_inputs(
prompt,
prompt_2,
@@ -885,14 +862,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
latents,
)
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -908,19 +877,11 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
@@ -12,7 +12,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -71,8 +71,6 @@ EXAMPLE_DOC_STRING = """
... image=init_image,
... mask_image=mask_image,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=0.7,
... strength=0.7,
... num_inference_steps=28,
@@ -216,8 +214,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
],
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = FluxMultiControlNetModel(controlnet)
self.register_modules(
scheduler=scheduler,
@@ -261,9 +257,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -308,9 +301,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -747,8 +737,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
timesteps: List[int] = None,
num_inference_steps: int = 28,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
@@ -795,10 +783,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
Custom timesteps to use for the denoising process.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_mode (`int` or `List[int]`, *optional*):
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -842,17 +826,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
global_height = height
global_width = width
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
# 1. Check inputs
self.check_inputs(
prompt,
@@ -1058,14 +1031,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
generator,
)
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -1084,19 +1049,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
else:
guidance = None
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -235,9 +235,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -282,9 +279,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -21,7 +21,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -239,9 +239,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -286,9 +283,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -277,7 +277,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
padding_side: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -299,9 +298,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
padding_side (`str`, *optional*):
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
-4
View File
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
@@ -32,7 +31,6 @@ else:
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -46,7 +44,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
@@ -55,7 +52,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
+2 -8
View File
@@ -98,9 +98,7 @@ class PAGMixin:
else:
return self.pag_scale
def _apply_perturbed_attention_guidance(
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
):
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
r"""
Apply perturbed attention guidance to the noise prediction.
@@ -109,11 +107,9 @@ class PAGMixin:
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
return_pred_text (bool): Whether to return the text noise prediction.
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
perturbed attention guidance and the text noise prediction.
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
@@ -126,8 +122,6 @@ class PAGMixin:
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if return_pred_text:
return noise_pred, noise_pred_text
return noise_pred
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More