Compare commits

..

11 Commits

Author SHA1 Message Date
DN6 d9915a7d65 update 2025-03-12 11:44:40 +05:30
DN6 b7a795dbeb update 2025-03-12 11:40:40 +05:30
DN6 438905d63e update 2025-03-12 11:37:27 +05:30
DN6 904f24de5a update 2025-03-12 11:35:18 +05:30
DN6 e123bbcbc4 memmap 2025-03-12 11:23:14 +05:30
DN6 b3fa8c695d remove cpu param dict 2025-03-12 09:02:04 +05:30
DN6 720be2bac5 update 2025-03-12 08:49:45 +05:30
DN6 e74b782aac update 2025-03-12 08:45:09 +05:30
DN6 d6392b4b49 update 2025-03-12 08:18:19 +05:30
DN6 1475026960 sliding-window 2025-03-11 13:56:39 +05:30
DN6 878eb4ce35 update 2025-03-11 13:21:09 +05:30
200 changed files with 1361 additions and 16445 deletions
-1
View File
@@ -38,7 +38,6 @@ jobs:
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
- name: Environment
run: |
python utils/print_env.py
-7
View File
@@ -414,16 +414,12 @@ jobs:
config:
- backend: "bitsandbytes"
test_location: "bnb"
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
- backend: "torchao"
test_location: "torchao"
additional_deps: []
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []
runs-on:
group: aws-g6e-xlarge-plus
container:
@@ -441,9 +437,6 @@ jobs:
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U ${{ matrix.config.backend }}
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
python -m uv pip install pytest-reportlog
- name: Environment
run: |
+34
View File
@@ -13,5 +13,39 @@ jobs:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
with:
python_quality_dependencies: "[quality]"
pre_commit_script_name: "Download and Compare files from the main branch"
pre_commit_script: |
echo "Downloading the files from the main branch"
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
echo "Compare the files and raise error if needed"
diff_failed=0
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_setup.py setup.py; then
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if [ $diff_failed -eq 1 ]; then
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
exit 1
fi
echo "No changes in the files. Proceeding..."
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
style_command: "make style && make quality"
secrets:
bot_token: ${{ secrets.GITHUB_TOKEN }}
+1 -47
View File
@@ -28,51 +28,7 @@ env:
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
runs-on:
group: aws-general-8-plus
@@ -177,7 +133,6 @@ jobs:
torch_cuda_tests:
name: Torch CUDA Tests
needs: [check_code_quality, check_repository_consistency]
runs-on:
group: aws-g4dn-2xlarge
container:
@@ -246,7 +201,7 @@ jobs:
run_examples_tests:
name: Examples PyTorch CUDA tests on Ubuntu
needs: [check_code_quality, check_repository_consistency]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
runs-on:
group: aws-g4dn-2xlarge
@@ -265,7 +220,6 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install -e [quality,test,training]
- name: Environment
-4
View File
@@ -81,8 +81,6 @@
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/vae_encode
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
@@ -496,8 +494,6 @@
title: PixArt-Σ
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
-56
View File
@@ -11,50 +11,6 @@ specific language governing permissions and limitations under the License. -->
# Caching methods
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
FasterCache is a method that speeds up inference in diffusion transformers by:
- Reusing attention states between successive inference steps, due to high similarity between them
- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 681),
current_timestep_callback=lambda: pipe.current_timestep,
attention_weight_callback=lambda _: 0.3,
unconditional_batch_skip_range=5,
unconditional_batch_timestep_skip_range=(-1, 781),
tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```
## First Block Cache
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
@@ -86,18 +42,6 @@ pipe.transformer.enable_cache(config)
[[autodoc]] CacheMixin
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
@@ -50,8 +50,7 @@ The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
## Quantization
@@ -32,7 +32,6 @@ Available models:
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
@@ -197,12 +196,6 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXConditionPipeline
[[autodoc]] LTXConditionPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
+7 -7
View File
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
First, load the pipeline:
```python
from diffusers import LuminaPipeline
from diffusers import LuminaText2ImgPipeline
import torch
pipeline = LuminaPipeline.from_pretrained(
pipeline = LuminaText2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```
@@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained(
torch_dtype=torch.float16,
)
pipeline = LuminaPipeline.from_pretrained(
pipeline = LuminaText2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
@@ -122,9 +122,9 @@ image = pipeline(prompt).images[0]
image.save("lumina.png")
```
## LuminaPipeline
## LuminaText2ImgPipeline
[[autodoc]] LuminaPipeline
[[autodoc]] LuminaText2ImgPipeline
- all
- __call__
+6 -6
View File
@@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme
```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
ckpt_path, torch_dtype=torch.bfloat16
)
pipe = Lumina2Pipeline.from_pretrained(
pipe = Lumina2Text2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
@@ -60,7 +60,7 @@ image.save("lumina-single-file.png")
GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
```python
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
transformer = Lumina2Transformer2DModel.from_single_file(
@@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file(
torch_dtype=torch.bfloat16,
)
pipe = Lumina2Pipeline.from_pretrained(
pipe = Lumina2Text2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
@@ -80,8 +80,8 @@ image = pipe(
image.save("lumina-gguf.png")
```
## Lumina2Pipeline
## Lumina2Text2ImgPipeline
[[autodoc]] Lumina2Pipeline
[[autodoc]] Lumina2Text2ImgPipeline
- all
- __call__
-100
View File
@@ -1,100 +0,0 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaSprintPipeline
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
| Model | Recommended dtype |
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
pipeline = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.bfloat16,
device_map="balanced",
)
prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
```
## Setting `max_timesteps`
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
## SanaSprintPipeline
[[autodoc]] SanaSprintPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
+12 -399
View File
@@ -14,405 +14,22 @@
# Wan
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
<!-- TODO(aryan): update abstract once paper is out -->
## Generating Videos with Wan 2.1
We will first need to install some addtional dependencies.
```shell
pip install -u ftfy imageio-ffmpeg imageio
```
### Text to Video Generation
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
```python
from diffusers import WanPipeline
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
<Tip>
You can improve the quality of the generated video by running the decoding step in full precision.
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
```python
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
Recommendations for inference:
- VAE in `torch.float32` for better decoding quality.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
### Image to Video Generation
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
35GB of VRAM to run.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Video to Video Generation
```python
import torch
from diffusers.utils import load_video, export_to_video
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
pipe.scheduler.config, flow_shift=flow_shift
)
# change to pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
prompt = "A robot standing on a mountain top. The sun is setting in the background"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
output = pipe(
video=video,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=512,
guidance_scale=7.0,
strength=0.7,
).frames[0]
export_to_video(output, "wan-v2v.mp4", fps=16)
```
## Memory Optimizations for Wan 2.1
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
### Group Offloading the Transformer and UMT5 Text Encoder
Find more information about group offloading [here](../optimization/memory.md)
#### Block Level Group Offloading
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
#### Block Level Group Offloading with CUDA Streams
We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Applying Layerwise Casting to the Transformer
Find more information about layerwise casting [here](../optimization/memory.md)
In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
This example will require 20GB of VRAM.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
## Using a Custom Scheduler
### Using a custom scheduler
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
@@ -428,10 +45,11 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
## Using Single File Loading with Wan 2.1
### Using single file loading with Wan
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
```python
import torch
@@ -443,11 +61,6 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
## Recommendations for Inference
- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
## WanPipeline
[[autodoc]] WanPipeline
@@ -3,7 +3,3 @@
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
## Remote Encode
[[autodoc]] utils.remote_utils.remote_encode
+2 -8
View File
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
---
@@ -46,15 +46,9 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
## Changelog
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding
## Contents
The documentation is organized into three sections:
The documentation is organized into two sections:
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
@@ -1,183 +0,0 @@
# Getting Started: VAE Encode with Hybrid Inference
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
## Memory
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_encode
```
### Basic example
Let's encode an image, then decode it to demonstrate.
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>
```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
latent = remote_encode(
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
scaling_factor=0.3611,
shift_factor=0.1159,
)
decoded = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
### Generation
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
<details><summary>Code</summary>
```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
init_image = load_image(
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
image=init_image,
scaling_factor=0.18215,
)
prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
prompt=prompt,
image=init_latent,
strength=0.75,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+5 -7
View File
@@ -161,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
```shell
export HF_HUB_OFFLINE=1
export HF_HUB_OFFLINE=True
```
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
@@ -179,16 +179,14 @@ Telemetry is only sent when loading models and pipelines from the Hub,
and it is not collected if you're loading local files.
We understand that not everyone wants to share additional information,and we respect your privacy.
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
On Linux/MacOS:
```bash
export HF_HUB_DISABLE_TELEMETRY=1
export DISABLE_TELEMETRY=YES
```
On Windows:
```bash
set HF_HUB_DISABLE_TELEMETRY=1
set DISABLE_TELEMETRY=YES
```
-20
View File
@@ -198,18 +198,6 @@ export_to_video(video, "output.mp4", fps=8)
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
<Tip>
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
</Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
@@ -247,14 +235,6 @@ In the above example, layerwise casting is enabled on the transformer component
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
<Tip>
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
</Tip>
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
-17
View File
@@ -95,23 +95,6 @@ Use the Space below to gauge a pipeline's memory requirements before you downloa
></iframe>
</div>
### Specifying Component-Specific Data Types
You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
```python
from diffusers import HunyuanVideoPipeline
import torch
pipe = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
### Local pipeline
To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
+6
View File
@@ -66,6 +66,12 @@ from accelerate.utils import write_basic_config
write_basic_config()
```
## 원을 채우는 데이터셋
원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
## 학습
@@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
@@ -378,7 +378,7 @@ def parse_args(input_args=None):
default=None,
help="the concept to use to initialize the new inserted tokens when training with "
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
)
parser.add_argument(
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
type=str,
default=None,
help=(
"The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -71,7 +71,6 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -102,7 +101,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
def save_model_card(
repo_id: str,
use_dora: bool,
images: list = None,
images=None,
base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
@@ -112,17 +111,20 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
else:
widget_dict.append({"text": instance_prompt})
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -167,7 +169,23 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_en
to trigger concept `{key}` use `{tokens}` in your prompt \n
"""
model_description = f"""
yaml = f"""---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id}
<Gallery />
@@ -216,25 +234,8 @@ Special VAE used for training: {vae_path}.
{license}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
lora,
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(
@@ -772,7 +773,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -1874,7 +1875,7 @@ def main(args):
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
add_special_tokens = True if args.train_text_encoder_ti else False
if not train_dataset.custom_instance_prompts:
-201
View File
@@ -1,201 +0,0 @@
# Training CogView4 Control
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:
To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
```bash
accelerate launch train_control_lora_cogview4.py \
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control-lora" \
--mixed_precision="bf16" \
--train_batch_size=1 \
--rank=64 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--offload \
--seed="0" \
--push_to_hub
```
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
The training script exposes additional CLI args that might be useful to experiment with:
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
### Training with DeepSpeed
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
And then while launching training, pass the config file:
```bash
accelerate launch --config_file=CONFIG_FILE.yaml ...
```
### Inference
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
```bash
pip install controlnet_aux
```
And then we are ready:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("...") # change this.
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))
prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
num_inference_steps=50,
joint_attention_kwargs={"scale": 0.9},
guidance_scale=25.,
).images[0]
gen_images.save("output.png")
```
## Full fine-tuning
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
```bash
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control" \
--mixed_precision="bf16" \
--train_batch_size=2 \
--dataloader_num_workers=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--proportion_empty_prompts=0.2 \
--learning_rate=5e-5 \
--adam_weight_decay=1e-4 \
--report_to="wandb" \
--lr_scheduler="cosine" \
--lr_warmup_steps=1000 \
--checkpointing_steps=1000 \
--max_train_steps=10000 \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
--offload \
--seed="0" \
--push_to_hub
```
Change the `validation_image` and `validation_prompt` as needed.
For inference, this time, we will run:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch
transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
pipe = CogView4ControlPipeline.from_pretrained(
"THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))
prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
num_inference_steps=50,
guidance_scale=25.,
).images[0]
gen_images.save("output.png")
```
## Things to note
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
@@ -1,6 +0,0 @@
transformers==4.47.0
wandb
torch
torchvision
accelerate==1.2.0
peft>=0.14.0
File diff suppressed because it is too large Load Diff
+23 -147
View File
@@ -10,7 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
@@ -24,12 +24,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
@@ -41,7 +41,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
@@ -58,7 +58,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
@@ -85,7 +85,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -124,6 +124,7 @@ pipe = pipe.to("cuda")
#--------Option--------#
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
stg_applied_layers_idx = [34]
stg_mode = "STG"
stg_scale = 1.0 # 0.0 for CFG
#----------------------#
@@ -953,7 +954,6 @@ for i in range(args.num_images):
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
print("Image saved successfully!")
```
### Imagic Stable Diffusion
@@ -1269,39 +1269,28 @@ The aim is to overlay two images, then mask out the boundary between `image` and
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
```python
import PIL
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline
image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
image_path = "./path-to-image.png"
inner_image_path = "./path-to-inner-image.png"
mask_path = "./path-to-mask.png"
def load_image(url, mode="RGB"):
response = requests.get(url)
if response.status_code == 200:
return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
else:
raise FileNotFoundError(f"Could not retrieve image from {url}")
init_image = load_image(image_url, mode="RGB")
inner_image = load_image(inner_image_url, mode="RGBA")
mask_image = load_image(mask_url, mode="RGB")
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-inpainting",
"runwayml/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "a mecha robot sitting on a bench"
prompt = "Your prompt here!"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
image.save("output.png")
```
![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
@@ -3263,19 +3252,14 @@ Here's a full example for `ReplaceEdit``:
```python
import torch
from diffusers import DiffusionPipeline
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
custom_pipeline="pipeline_prompt2prompt"
).to("cuda")
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
prompts = [
"A turtle playing with a ball",
"A monkey playing with a ball"
]
prompts = ["A turtle playing with a ball",
"A monkey playing with a ball"]
cross_attention_kwargs = {
"edit_type": "replace",
@@ -3283,15 +3267,7 @@ cross_attention_kwargs = {
"self_replace_steps": 0.4
}
outputs = pipe(
prompt=prompts,
height=512,
width=512,
num_inference_steps=50,
cross_attention_kwargs=cross_attention_kwargs
)
outputs.images[0].save("output_image_0.png")
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
```
And abbreviated examples for the other edits:
@@ -5333,103 +5309,3 @@ output = pipeline_for_inversion(
pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
```
# FaithDiff Stable Diffusion XL Pipeline
[Project](https://jychen9811.github.io/FaithDiff_page/) / [GitHub](https://github.com/JyChen9811/FaithDiff/)
This the implementation of the FaithDiff pipeline for SDXL, adapted to use the HuggingFace Diffusers.
For more details see the project links above.
## Example Usage
This example upscale and restores a low-quality image. The input image has a resolution of 512x512 and will be upscaled at a scale of 2x, to a final resolution of 1024x1024. It is possible to upscale to a larger scale, but it is recommended that the input image be at least 1024x1024 in these cases. To upscale this image by 4x, for example, it would be recommended to re-input the result into a new 2x processing, thus performing progressive scaling.
````py
import random
import numpy as np
import torch
from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from PIL import Image
device = "cuda"
dtype = torch.float16
MAX_SEED = np.iinfo(np.int32).max
# Download weights for additional unet layers
model_file = hf_hub_download(
"jychen9811/FaithDiff",
filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
)
# Initialize the models and pipeline
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
model_id = "SG161222/RealVisXL_V4.0"
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=dtype,
vae=vae,
unet=None, #<- Do not load with original model.
custom_pipeline="pipeline_faithdiff_stable_diffusion_xl",
use_safetensors=True,
variant="fp16",
).to(device)
# Here we need use pipeline internal unet model
pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
# Load aditional layers to the model
pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
# Enable vae tiling
pipe.set_encoder_tile_settings()
pipe.enable_vae_tiling()
# Optimization
pipe.enable_model_cpu_offload()
# Set selected scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
#input params
prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
upscale = 2 # scale here
start_point = "lr" # or "noise"
latent_tiled_overlap = 0.5
latent_tiled_size = 1024
# Load image
lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
original_height = lq_image.height
original_width = lq_image.width
print(f"Current resolution: H:{original_height} x W:{original_width}")
width = original_width * int(upscale)
height = original_height * int(upscale)
print(f"Final resolution: H:{height} x W:{width}")
# Restoration
image = lq_image.resize((width, height), Image.LANCZOS)
input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)
generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
gen_image = pipe(lr_img=input_image,
prompt = prompt,
num_inference_steps=20,
guidance_scale=5,
generator=generator,
start_point=start_point,
height = height_now,
width=width_now,
overlap=latent_tiled_overlap,
target_size=(latent_tiled_size, latent_tiled_size)
).images[0]
cropped_image = gen_image.crop((0, 0, width_init, height_init))
cropped_image.save("data/result.png")
````
### Result
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
+2 -17
View File
@@ -1773,7 +1773,7 @@ class SDXLLongPromptWeightingPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1924,22 +1924,7 @@ class SDXLLongPromptWeightingPipeline(
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
+22 -22
View File
@@ -1,4 +1,4 @@
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1070,32 +1070,32 @@ class StableDiffusionXLTilingPipeline(
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left[row][col],
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left[row][col],
negative_target_size,
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left[row][col],
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left[row][col],
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
embeddings_and_added_time.append(addition_embed_type_row)
File diff suppressed because it is too large Load Diff
-661
View File
@@ -1,661 +0,0 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import types
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
>>> from examples.community.pipeline_stg_wan import WanSTGPipeline
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
>>> pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=720,
... width=1280,
... num_frames=81,
... guidance_scale=5.0,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=16)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
return hidden_states
def forward_without_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using Wan.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],
stg_scale: Optional[float] = 0.0,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `81`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
Examples:
Returns:
[`~WanPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
block.forward = types.MethodType(forward_without_stg, block)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
if idx in stg_applied_layers_idx:
block.forward = types.MethodType(forward_with_stg, block)
noise_perturb = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = (
noise_uncond
+ guidance_scale * (noise_pred - noise_uncond)
+ self._stg_scale * (noise_pred - noise_perturb)
)
else:
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)
+3 -1
View File
@@ -152,7 +152,9 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
+3 -1
View File
@@ -166,7 +166,9 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
+2 -2
View File
@@ -1283,8 +1283,8 @@ def main(args):
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# Get the text embedding for conditioning
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
prompt_embeds = batch["prompt_embeds"]
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+3 -1
View File
@@ -157,7 +157,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
@@ -49,7 +49,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -419,7 +418,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
+4 -12
View File
@@ -1,27 +1,20 @@
# AnyTextPipeline
# AnyTextPipeline Pipeline
Project page: https://aigcdesigngroup.github.io/homepage_anytext
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
```py
# This example requires the `anytext_controlnet.py` file:
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
# %cd diffusers/examples/research_projects/anytext
# Let's choose a font file shared by an HF staff:
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
import torch
from diffusers import DiffusionPipeline
from anytext_controlnet import AnyTextControlNetModel
from diffusers.utils import load_image
# I chose a font file shared by an HF staff:
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
variant="fp16",)
@@ -33,7 +26,6 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial
# generate image
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
).images[0]
image
+5 -11
View File
@@ -146,17 +146,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # This example requires the `anytext_controlnet.py` file:
>>> # !git clone --depth 1 https://github.com/huggingface/diffusers.git
>>> # %cd diffusers/examples/research_projects/anytext
>>> # Let's choose a font file shared by an HF staff:
>>> # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from anytext_controlnet import AnyTextControlNetModel
>>> from diffusers.utils import load_image
>>> # I chose a font file shared by an HF staff:
>>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
>>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
... variant="fp16",)
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
@@ -168,7 +165,6 @@ EXAMPLE_DOC_STRING = """
>>> # generate image
>>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
>>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
>>> # There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
>>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
... ).images[0]
>>> image
@@ -261,11 +257,11 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
idx = tokenized_text[i] == self.placeholder_token.to(device)
if sum(idx) > 0:
if i >= len(self.text_embs_all):
logger.warning("truncation for log images...")
print("truncation for log images...")
break
text_emb = torch.cat(self.text_embs_all[i], dim=0)
if sum(idx) != len(text_emb):
logger.warning("truncation for long caption...")
print("truncation for long caption...")
text_emb = text_emb.to(embedded_text.device)
embedded_text[i][idx] = text_emb[: sum(idx)]
return embedded_text
@@ -1062,8 +1058,6 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
raise ValueError(f"Can't read ori_image image from {ori_image}!")
elif isinstance(ori_image, torch.Tensor):
ori_image = ori_image.cpu().numpy()
elif isinstance(ori_image, PIL.Image.Image):
ori_image = np.array(ori_image.convert("RGB"))
else:
if not isinstance(ori_image, np.ndarray):
raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
@@ -627,7 +627,6 @@ def main(args):
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
perceptual_loss = lpips.LPIPS(net="vgg").eval()
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
def unwrap_model(model):
@@ -952,20 +951,13 @@ def main(args):
logits_fake = discriminator(reconstructions)
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
logs = {
"disc_loss": d_loss.detach().mean().item(),
"disc_loss": disc_loss.detach().mean().item(),
"logits_real": logits_real.detach().mean().item(),
"logits_fake": logits_fake.detach().mean().item(),
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
}
accelerator.backward(d_loss)
if accelerator.sync_gradients:
params_to_clip = discriminator.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
disc_optimizer.step()
disc_lr_scheduler.step()
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -381,7 +381,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
@@ -54,7 +54,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -476,7 +475,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
@@ -164,7 +164,9 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
@@ -59,7 +59,6 @@ from diffusers.schedulers import (
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, logging
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
if is_accelerate_available():
@@ -1436,7 +1435,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
original_config_file = BytesIO(requests.get(config_url).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
@@ -1,6 +1,8 @@
# Generating images using Flux and PyTorch/XLA
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
## Create TPU
@@ -21,23 +23,20 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
python3 -c "import torch; import torch_xla;"
```
Clone the diffusers repo and install dependencies
Install dependencies
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
cd examples/research_projects/pytorch_xla/inference/flux/
popd
```
## Run the inference job
### Authenticate
**Gated Model**
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
Run the following command to authenticate your token in order to download Flux weights.
```bash
huggingface-cli login
@@ -51,116 +50,51 @@ python flux_inference.py
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):
```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
2025-03-14 21:18:34 [info ] starting compilation run...
2025-03-14 21:18:37 [info ] starting compilation run...
2025-03-14 21:18:38 [info ] starting compilation run...
2025-03-14 21:18:39 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:42 [info ] starting compilation run...
2025-03-14 21:18:43 [info ] starting compilation run...
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:40 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
2025-03-14 21:36:50 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
2025-03-14 21:36:56 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
2025-03-14 21:37:08 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
2025-03-14 21:37:22 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
Loading pipeline components...: 40%|██████████ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
2025-01-10 00:51:34 [info ] starting compilation run...
2025-01-10 00:51:35 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
2025-01-10 00:52:53 [info ] starting inference run...
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:58 [info ] starting inference run...
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
```
@@ -9,7 +9,6 @@ import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import FlashAttention
from diffusers import FluxPipeline
@@ -37,19 +36,6 @@ def _main(index, args, text_pipe, ckpt_id):
ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
).to(device0)
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
FlashAttention.DEFAULT_BLOCK_SIZES = {
"block_q": 1536,
"block_k_major": 1536,
"block_k": 1536,
"block_b": 1536,
"block_q_major_dkv": 1536,
"block_k_major_dkv": 1536,
"block_q_dkv": 1536,
"block_k_dkv": 1536,
"block_q_dq": 1536,
"block_k_dq": 1536,
"block_k_major_dq": 1536,
}
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
width = args.width
@@ -83,14 +69,14 @@ def _main(index, args, text_pipe, ckpt_id):
xm.set_rng_state(seed=unique_seed, device=device0)
times = []
logger.info("starting inference run...")
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
prompt_embeds = prompt_embeds.to(device0)
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
for _ in range(args.itters):
ts = perf_counter()
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
prompt_embeds = prompt_embeds.to(device0)
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
if args.profile:
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
@@ -106,7 +92,7 @@ def _main(index, args, text_pipe, ckpt_id):
if index == 0:
logger.info(f"inference time: {inference_time}")
times.append(inference_time)
logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
image.save(f"/tmp/inference_out-{index}.png")
if index == 0:
metrics_report = met.metrics_report()
@@ -141,7 +141,9 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = [np.asarray(validation_image)]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
+2 -13
View File
@@ -53,18 +53,8 @@ args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
def swap_scale_shift(weight, dim):
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
@@ -210,7 +200,6 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
@@ -25,15 +25,9 @@ import argparse
import torch
from tqdm import tqdm
from transformers import GlmModel, PreTrainedTokenizerFast
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from diffusers import (
AutoencoderKL,
CogView4ControlPipeline,
CogView4Pipeline,
CogView4Transformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -118,12 +112,6 @@ parser.add_argument(
default=128,
help="Maximum size for positional embeddings.",
)
parser.add_argument(
"--control",
action="store_true",
default=False,
help="Whether to use control model.",
)
args = parser.parse_args()
@@ -162,15 +150,13 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
ckpt = torch.load(ckpt_path, map_location="cpu")
mega = ckpt["model"]
new_state_dict = {}
# Patch Embedding
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
hidden_size, 128 if args.control else 64
)
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
@@ -203,8 +189,14 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
)
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
)
# QKV
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
@@ -229,7 +221,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.weight"
]
].T
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.bias"
]
@@ -260,7 +252,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
Returns:
dict: The converted VAE state dictionary compatible with Diffusers.
"""
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
@@ -294,7 +286,7 @@ def main(args):
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=32 if args.control else 16,
in_channels=16,
num_layers=args.num_layers,
attention_head_dim=args.attention_head_dim,
num_attention_heads=args.num_heads,
@@ -325,7 +317,6 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
@@ -340,7 +331,7 @@ def main(args):
# Load the text encoder and tokenizer
text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmModel.from_pretrained(
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
@@ -354,22 +345,13 @@ def main(args):
)
# Create the pipeline
if args.control:
pipe = CogView4ControlPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
else:
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# Save the converted pipeline
pipe.save_pretrained(
@@ -11,7 +11,6 @@ from diffusion import sampling
from torch import nn
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
MODELS_MAP = {
@@ -75,7 +74,7 @@ class DiffusionUncond(nn.Module):
def download(model_name):
url = MODELS_MAP[model_name]["url"]
r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
r = requests.get(url, stream=True)
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
+1 -22
View File
@@ -160,9 +160,8 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": None,
},
"HYVideo-T/2-I2V-33ch": {
"HYVideo-T/2-I2V": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
@@ -179,26 +178,6 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "latent_concat",
},
"HYVideo-T/2-I2V-16ch": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "token_replace",
},
}
+15 -89
View File
@@ -74,32 +74,6 @@ VAE_091_RENAME_DICT = {
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -107,6 +81,10 @@ VAE_SPECIAL_KEYS_REMAP = {
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
@@ -126,16 +104,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel(**config)
transformer = LTXVideoTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -187,19 +161,12 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
@@ -216,19 +183,12 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
@@ -240,38 +200,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
elif version == "0.9.5":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
return config
@@ -294,7 +223,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
)
return parser.parse_args()
@@ -348,17 +277,14 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
+2 -2
View File
@@ -5,7 +5,7 @@ import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
def main(args):
@@ -115,7 +115,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
pipeline = LuminaPipeline(
pipeline = LuminaText2ImgPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)
+60 -198
View File
@@ -16,9 +16,7 @@ from diffusers import (
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaSprintPipeline,
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -27,10 +25,6 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
"Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -78,42 +72,15 @@ def main(args):
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# Handle different time embedding structure based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# For Sana Sprint, the time embedding structure is different
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Guidance embedder for Sana Sprint
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
"cfg_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
"cfg_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
else:
# Original Sana time embedding structure
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
"t_embedder.mlp.0.bias"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
"t_embedder.mlp.2.bias"
)
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
@@ -129,22 +96,14 @@ def main(args):
flow_shift = 3.0
# model config
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
elif args.model_type == "SanaMS_4800M_P1_D60":
layer_num = 60
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
qk_norm = (
"rms_norm_across_heads"
if args.model_type
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
else None
)
for depth in range(layer_num):
# Transformer blocks.
@@ -158,14 +117,6 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
@@ -203,14 +154,6 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
@@ -226,37 +169,24 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 32,
"out_channels": 32,
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
"num_layers": model_kwargs[args.model_type]["num_layers"],
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
"caption_channels": 2304,
"mlp_ratio": 2.5,
"attention_bias": False,
"sample_size": args.image_size // 32,
"patch_size": 1,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"interpolation_scale": interpolation_scale[args.image_size],
}
# Add qk_norm parameter for Sana Sprint
if args.model_type in [
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_600M_P1_D28",
"SanaSprint_1600M_P1_D20",
]:
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
transformer_kwargs["guidance_embeds"] = True
transformer = SanaTransformer2DModel(**transformer_kwargs)
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
@@ -266,8 +196,6 @@ def main(args):
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
@@ -282,74 +210,47 @@ def main(args):
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole Pipeline",
f"Set --save_full_pipeline to save the whole SanaPipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
)
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
text_encoder_model_path = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Choose the appropriate pipeline and scheduler based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
if args.scheduler_type != "scm":
print(
colored(
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
"yellow",
attrs=["bold"],
)
)
# SCM Scheduler for Sana Sprint
scheduler_config = {
"prediction_type": "trigflow",
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)
pipe = SanaSprintPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
DTYPE_MAPPING = {
@@ -358,6 +259,12 @@ DTYPE_MAPPING = {
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -374,24 +281,10 @@ if __name__ == "__main__":
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type",
default="SanaMS_1600M_P1_D20",
type=str,
choices=[
"SanaMS_1600M_P1_D20",
"SanaMS_600M_P1_D28",
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_1600M_P1_D20",
"SanaSprint_600M_P1_D28",
],
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
)
parser.add_argument(
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "scm"],
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
@@ -416,41 +309,10 @@ if __name__ == "__main__":
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaMS1.5_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS1.5_4800M_P1_D60": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 60,
},
"SanaSprint_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaSprint_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)
+1 -3
View File
@@ -13,7 +13,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
renew_vae_attention_paths,
renew_vae_resnet_paths,
)
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
@@ -123,8 +122,7 @@ def vae_pt_to_vae_diffuser(
):
# Only support V1
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
timeout=DIFFUSERS_REQUEST_TIMEOUT,
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
io_obj = io.BytesIO(r.content)
+1 -48
View File
@@ -33,7 +33,6 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
@@ -130,25 +129,10 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -287,7 +271,6 @@ else:
"RePaintScheduler",
"SASolverScheduler",
"SchedulerMixin",
"SCMScheduler",
"ScoreSdeVeScheduler",
"TCDScheduler",
"UnCLIPScheduler",
@@ -362,7 +345,6 @@ else:
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"CycleDiffusionPipeline",
@@ -419,12 +401,9 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
@@ -440,7 +419,6 @@ else:
"ReduxImageEncoder",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -523,7 +501,6 @@ else:
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -722,24 +699,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -872,7 +832,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
RePaintScheduler,
SASolverScheduler,
SchedulerMixin,
SCMScheduler,
ScoreSdeVeScheduler,
TCDScheduler,
UnCLIPScheduler,
@@ -928,7 +887,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
CycleDiffusionPipeline,
@@ -985,12 +943,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
@@ -1006,7 +961,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
SanaPAGPipeline,
SanaPipeline,
SanaSprintPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -1088,7 +1042,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
+1 -4
View File
@@ -35,7 +35,6 @@ from huggingface_hub.utils import (
validate_hf_hub_args,
)
from requests import HTTPError
from typing_extensions import Self
from . import __version__
from .utils import (
@@ -186,9 +185,7 @@ class ConfigMixin:
)
@classmethod
def from_config(
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
r"""
Instantiate a Python class from a config dictionary.
-24
View File
@@ -1,24 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
@@ -1,145 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class AdaptiveProjectedGuidance(GuidanceMixin):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, *args):
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return super().prepare_inputs(*args)
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred
@@ -1,98 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeGuidance(GuidanceMixin):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
):
super().__init__()
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
@@ -1,110 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.type_as(cond)
-213
View File
@@ -1,213 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..utils import deprecate, get_logger
if TYPE_CHECKING:
from ..models.attention_processor import AttentionProcessor
logger = get_logger(__name__) # pylint: disable=invalid-name
class GuidanceMixin:
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
def __init__(self):
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._preds: Dict[str, torch.Tensor] = {}
self._num_outputs_prepared: int = 0
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._preds = {}
self._num_outputs_prepared = 0
def prepare_models(self, denoiser: torch.nn.Module) -> None:
pass
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
# Alternating conditional and unconditional inputs as batches
inputs = [arg[i % 2] for i in range(num_conditions)]
list_of_inputs.append(inputs)
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
pass
def __call__(self, **kwargs) -> Any:
if len(kwargs) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
)
return self.forward(**kwargs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
@property
def num_conditions(self) -> int:
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
@property
def outputs(self) -> Dict[str, torch.Tensor]:
return self._preds
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
def _replace_attention_processors(
module: torch.nn.Module,
pag_applied_layers: Optional[Union[str, List[str]]] = None,
skip_context_attention: bool = False,
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
metadata_name: Optional[str] = None,
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
if processors is not None and metadata_name is not None:
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
if metadata_name is not None:
if isinstance(pag_applied_layers, str):
pag_applied_layers = [pag_applied_layers]
return _replace_layers_with_guidance_processors(
module, pag_applied_layers, skip_context_attention, metadata_name
)
if processors is not None:
_replace_layers_with_existing_processors(processors)
def _replace_layers_with_guidance_processors(
module: torch.nn.Module,
pag_applied_layers: List[str],
skip_context_attention: bool,
metadata_name: str,
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
from ..hooks._common import _ATTENTION_CLASSES
from ..hooks._helpers import GuidanceMetadataRegistry
processors = []
for name, submodule in module.named_modules():
if (
(not isinstance(submodule, _ATTENTION_CLASSES))
or (getattr(submodule, "processor", None) is None)
or not (
any(
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
for pag_layer in pag_applied_layers
)
)
):
continue
old_attention_processor = submodule.processor
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
new_attention_processor_cls = getattr(metadata, metadata_name)
new_attention_processor = new_attention_processor_cls()
# !!! dunder methods cannot be replaced on instances !!!
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
# new_attention_processor.__call__ = partial(
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
# )
submodule.processor = new_attention_processor
processors.append((submodule, old_attention_processor))
return processors
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
for module, proc in processors:
module.processor = proc
def _is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
def _raise_guidance_deprecation_warning(
*,
guidance_scale: Optional[float] = None,
guidance_rescale: Optional[float] = None,
) -> None:
if guidance_scale is not None:
msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
if guidance_rescale is not None:
msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)
@@ -1,180 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
class PerturbedAttentionGuidance(GuidanceMixin):
"""
Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377
Args:
pag_applied_layers (`str` or `List[str]`):
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
pag_scale (`float`, defaults to `3.0`):
The scale parameter for perturbed attention guidance.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
def __init__(
self,
pag_applied_layers: Union[str, List[str]],
guidance_scale: float = 7.5,
pag_scale: float = 3.0,
skip_context_attention: bool = False,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.pag_applied_layers = pag_applied_layers
self.guidance_scale = guidance_scale
self.pag_scale = pag_scale
self.skip_context_attention = skip_context_attention
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self._is_pag_batch = False
self._original_processors = None
self._denoiser = None
def prepare_models(self, denoiser: torch.nn.Module):
self._denoiser = denoiser
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_pag_enabled():
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_perturbed"
self._preds[key] = pred
# Restore the original attention processors if previously replaced
if self._is_pag_batch:
_replace_attention_processors(self._denoiser, processors=self._original_processors)
self._is_pag_batch = False
self._original_processors = None
# Prepare denoiser for perturbed attention prediction if needed
if self._is_pag_enabled():
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
)
if should_register_pag:
self._is_pag_batch = True
self._original_processors = _replace_attention_processors(
self._denoiser,
self.pag_applied_layers,
skip_context_attention=self.skip_context_attention,
metadata_name="perturbed_attention_guidance_processor_cls",
)
def cleanup_models(self, denoiser: torch.nn.Module):
self._denoiser = None
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_perturbed: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_pag_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_perturbed
pred = pred_cond + self.pag_scale * shift
elif not self._is_pag_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_perturbed = pred_cond - pred_perturbed
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_pag_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_pag_enabled(self) -> bool:
is_zero = math.isclose(self.pag_scale, 0.0)
return not is_zero
@@ -1,235 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class SkipLayerGuidance(GuidanceMixin):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (0.0 < skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module):
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
# Register the hooks for layer skipping if the step is within the specified range
if skip_start_step < self._step < skip_stop_step:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_slg_enabled():
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_cond_skip"
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_slg_enabled(self) -> bool:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
-17
View File
@@ -1,25 +1,8 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
-32
View File
@@ -1,32 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
-276
View File
@@ -1,276 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Type
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import (
CogView4AttnProcessor,
CogView4PAGAttnProcessor,
CogView4TransformerBlock,
)
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class GuidanceMetadata:
perturbed_attention_guidance_processor_cls: Type = None
@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
class AttentionProcessorRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class GuidanceMetadataRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: GuidanceMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> GuidanceMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class TransformerBlockRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
def _register_attention_processors_metadata():
# CogView4
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_guidance_metadata():
# CogView4
GuidanceMetadataRegistry.register(
model_class=CogView4AttnProcessor,
metadata=GuidanceMetadata(
perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on
_register_attention_processors_metadata()
_register_guidance_metadata()
_register_transformer_blocks_metadata()
-653
View File
@@ -1,653 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
"^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
"hidden_states",
"encoder_hidden_states",
"timestep",
"attention_mask",
"encoder_attention_mask",
)
@dataclass
class FasterCacheConfig:
r"""
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
Attributes:
spatial_attention_block_skip_range (`int`, defaults to `2`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
process.
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
The timestep range within which the temporal attention computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0).
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
The timestep range within which the low frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
The timestep range within which the high frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
alpha_low_frequency (`float`, defaults to `1.1`):
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
the conditional branch outputs.
alpha_high_frequency (`float`, defaults to `1.1`):
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
from the conditional branch outputs.
unconditional_batch_skip_range (`int`, defaults to `5`):
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the attention outputs by. This function should take
the attention module as input and return a float value. This is used to approximate the unconditional
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
the number of inference steps and underlying model behaviour as denoising progresses.
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
default weight is 1.1 for timesteps within the range specified (as described in the paper).
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
default weight is 1.1 for timesteps within the range specified (as described in the paper).
tensor_format (`str`, defaults to `"BCFHW"`):
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
used to split individual latent frames in order for low and high frequency components to be computed.
is_guidance_distilled (`bool`, defaults to `False`):
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
names that contain the batchwise-concatenated unconditional and conditional inputs.
"""
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
# after some testing. We default to 2 if these parameters are not provided.
spatial_attention_block_skip_range: int = 2
temporal_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
# 1 and 2 as mentioned in Equation 11 of the paper
alpha_low_frequency: float = 1.1
alpha_high_frequency: float = 1.1
# n as described in CFG-Cache explanation in the paper - dependant on the model
unconditional_batch_skip_range: int = 5
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
attention_weight_callback: Callable[[torch.nn.Module], float] = None
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
tensor_format: str = "BCFHW"
is_guidance_distilled: bool = False
current_timestep_callback: Callable[[], int] = None
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
def __repr__(self) -> str:
return (
f"FasterCacheConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
f" alpha_low_frequency={self.alpha_low_frequency},\n"
f" alpha_high_frequency={self.alpha_high_frequency},\n"
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" tensor_format={self.tensor_format},\n"
f")"
)
class FasterCacheDenoiserState:
r"""
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
"""
def __init__(self) -> None:
self.iteration: int = 0
self.low_frequency_delta: torch.Tensor = None
self.high_frequency_delta: torch.Tensor = None
def reset(self):
self.iteration = 0
self.low_frequency_delta = None
self.high_frequency_delta = None
class FasterCacheBlockState:
r"""
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
applied to will have an instance of this state.
"""
def __init__(self) -> None:
self.iteration: int = 0
self.batch_size: int = None
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
def reset(self):
self.iteration = 0
self.batch_size = None
self.cache = None
class FasterCacheDenoiserHook(ModelHook):
_is_stateful = True
def __init__(
self,
unconditional_batch_skip_range: int,
unconditional_batch_timestep_skip_range: Tuple[int, int],
tensor_format: str,
is_guidance_distilled: bool,
uncond_cond_input_kwargs_identifiers: List[str],
current_timestep_callback: Callable[[], int],
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
) -> None:
super().__init__()
self.unconditional_batch_skip_range = unconditional_batch_skip_range
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
# We can't easily detect what args are to be split in unconditional and conditional branches. We
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
self.tensor_format = tensor_format
self.is_guidance_distilled = is_guidance_distilled
self.current_timestep_callback = current_timestep_callback
self.low_frequency_weight_callback = low_frequency_weight_callback
self.high_frequency_weight_callback = high_frequency_weight_callback
def initialize_hook(self, module):
self.state = FasterCacheDenoiserState()
return module
@staticmethod
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
# followed by conditional inputs.
_, cond = input.chunk(2, dim=0)
return cond
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
# requirements for skipping the unconditional branch are met as described in the paper.
# We skip the unconditional branch only if the following conditions are met:
# 1. We have completed at least one iteration of the denoiser
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
# where approximating the unconditional branch from the computation of the conditional branch is possible
# without a significant loss in quality.
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
is_within_timestep_range = (
self.unconditional_batch_timestep_skip_range[0]
< self.current_timestep_callback()
< self.unconditional_batch_timestep_skip_range[1]
)
should_skip_uncond = (
self.state.iteration > 0
and is_within_timestep_range
and self.state.iteration % self.unconditional_batch_skip_range != 0
and not self.is_guidance_distilled
)
if should_skip_uncond:
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
if is_any_kwarg_uncond:
logger.debug("FasterCache - Skipping unconditional branch computation")
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
kwargs = {
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
for k, v in kwargs.items()
}
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_guidance_distilled:
self.state.iteration += 1
return output
if torch.is_tensor(output):
hidden_states = output
elif isinstance(output, (tuple, Transformer2DModelOutput)):
hidden_states = output[0]
batch_size = hidden_states.size(0)
if should_skip_uncond:
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
module
)
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
module
)
if self.tensor_format == "BCFHW":
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
hidden_states = hidden_states.flatten(0, 1)
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
uncond_freq = low_freq_uncond + high_freq_uncond
uncond_states = torch.fft.ifftshift(uncond_freq)
uncond_states = torch.fft.ifft2(uncond_states).real
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
if self.tensor_format == "BCFHW":
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
# Concatenate the approximated unconditional and predicted conditional branches
uncond_states = uncond_states.to(hidden_states.dtype)
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
else:
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
if self.tensor_format == "BCFHW":
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
cond_states = cond_states.permute(0, 2, 1, 3, 4)
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
uncond_states = uncond_states.flatten(0, 1)
cond_states = cond_states.flatten(0, 1)
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
self.state.iteration += 1
if torch.is_tensor(output):
output = hidden_states
elif isinstance(output, tuple):
output = (hidden_states, *output[1:])
else:
output.sample = hidden_states
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.state.reset()
return module
class FasterCacheBlockHook(ModelHook):
_is_stateful = True
def __init__(
self,
block_skip_range: int,
timestep_skip_range: Tuple[int, int],
is_guidance_distilled: bool,
weight_callback: Callable[[torch.nn.Module], float],
current_timestep_callback: Callable[[], int],
) -> None:
super().__init__()
self.block_skip_range = block_skip_range
self.timestep_skip_range = timestep_skip_range
self.is_guidance_distilled = is_guidance_distilled
self.weight_callback = weight_callback
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = FasterCacheBlockState()
return module
def _compute_approximated_attention_output(
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
) -> torch.Tensor:
if t_2_output.size(0) != batch_size:
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
# take the conditional branch outputs.
assert t_2_output.size(0) == 2 * batch_size
t_2_output = t_2_output[batch_size:]
if t_output.size(0) != batch_size:
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
# take the conditional branch outputs.
assert t_output.size(0) == 2 * batch_size
t_output = t_output[batch_size:]
return t_output + (t_output - t_2_output) * weight
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
batch_size = [
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
][0]
if self.state.batch_size is None:
# Will be updated on first forward pass through the denoiser
self.state.batch_size = batch_size
# If we have to skip due to the skip conditions, then let's skip as expected.
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
# is because the expected output shapes of attention layer will not match if we only return values from
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
if not is_within_timestep_range:
should_skip_attention = False
else:
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
should_skip_attention = not should_compute_attention
if should_skip_attention:
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
if should_skip_attention:
logger.debug("FasterCache - Skipping attention and using approximation")
if torch.is_tensor(self.state.cache[-1]):
t_2_output, t_output = self.state.cache
weight = self.weight_callback(module)
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
else:
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
# allows us to compute the approximated attention output for each tensor in the cache.
output = ()
for t_2_output, t_output in zip(*self.state.cache):
result = self._compute_approximated_attention_output(
t_2_output, t_output, self.weight_callback(module), batch_size
)
output += (result,)
else:
logger.debug("FasterCache - Computing attention")
output = self.fn_ref.original_forward(*args, **kwargs)
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
# both cases.
if torch.is_tensor(output):
cache_output = output
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
cache_output = cache_output.chunk(2, dim=0)[1]
else:
# Cache all return values and perform the same operation as above
cache_output = ()
for out in output:
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
out = out.chunk(2, dim=0)[1]
cache_output += (out,)
if self.state.cache is None:
self.state.cache = [cache_output, cache_output]
else:
self.state.cache = [self.state.cache[-1], cache_output]
self.state.iteration += 1
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.state.reset()
return module
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
r"""
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
Args:
pipeline (`DiffusionPipeline`):
The diffusion pipeline to apply FasterCache to.
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
The configuration to use for FasterCache.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = FasterCacheConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(-1, 681),
... low_frequency_weight_update_timestep_range=(99, 641),
... high_frequency_weight_update_timestep_range=(-1, 301),
... spatial_attention_block_identifiers=["transformer_blocks"],
... attention_weight_callback=lambda _: 0.3,
... tensor_format="BFCHW",
... )
>>> apply_faster_cache(pipe.transformer, config)
```
"""
logger.warning(
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
"https://github.com/huggingface/diffusers/issues."
)
if config.attention_weight_callback is None:
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
logger.warning(
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
)
config.attention_weight_callback = lambda _: 0.5
if config.low_frequency_weight_callback is None:
logger.debug(
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
)
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
is_within_range = (
config.low_frequency_weight_update_timestep_range[0]
< config.current_timestep_callback()
< config.low_frequency_weight_update_timestep_range[1]
)
return config.alpha_low_frequency if is_within_range else 1.0
config.low_frequency_weight_callback = low_frequency_weight_callback
if config.high_frequency_weight_callback is None:
logger.debug(
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
)
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
is_within_range = (
config.high_frequency_weight_update_timestep_range[0]
< config.current_timestep_callback()
< config.high_frequency_weight_update_timestep_range[1]
)
return config.alpha_high_frequency if is_within_range else 1.0
config.high_frequency_weight_callback = high_frequency_weight_callback
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
if config.tensor_format not in supported_tensor_formats:
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
hook = FasterCacheDenoiserHook(
config.unconditional_batch_skip_range,
config.unconditional_batch_timestep_skip_range,
config.tensor_format,
config.is_guidance_distilled,
config._unconditional_conditional_input_kwargs_identifiers,
config.current_timestep_callback,
config.low_frequency_weight_callback,
config.high_frequency_weight_callback,
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not module.is_cross_attention
)
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
if block_skip_range is None or timestep_skip_range is None:
logger.debug(
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
f"function to apply FasterCache to this layer."
)
return
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
hook = FasterCacheBlockHook(
block_skip_range,
timestep_skip_range,
config.is_guidance_distilled,
config.attention_weight_callback,
config.current_timestep_callback,
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _split_low_high_freq(x):
fft = torch.fft.fft2(x)
fft_shifted = torch.fft.fftshift(fft)
height, width = x.shape[-2:]
radius = min(height, width) // 5
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
center_x, center_y = width // 2, height // 2
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
high_freq_mask = ~low_freq_mask
low_freq_fft = fft_shifted * low_freq_mask
high_freq_fft = fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
-223
View File
@@ -1,223 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseMarkedState, HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
poorer generation quality. A lower threshold may not result in significant generation speedup. The
threshold is compared against the absmean difference of the residuals between the current and cached
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseMarkedState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
self.shared_state = shared_state
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hs_residual = output[self._metadata.return_hidden_states_index] - original_hs
else:
hs_residual = output - original_hs
hs, ehs = None, None
should_compute = self._should_compute_remaining_blocks(hs_residual)
self.shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
else:
hs = self.shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
ehs = (
self.shared_state.tail_block_residuals[1]
+ output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hs
return_output[self._metadata.return_encoder_hidden_states_index] = ehs
return_output = tuple(return_output)
else:
return_output = hs
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
self.shared_state.head_block_output = head_block_output
self.shared_state.head_block_residual = hs_residual
return output
def reset_state(self, module):
self.shared_state.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
if self.shared_state.head_block_residual is None:
return True
prev_hs_residual = self.shared_state.head_block_residual
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
prev_hs_mean = prev_hs_residual.abs().mean()
diff = (hs_absmean / prev_hs_mean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
super().__init__()
self.shared_state = shared_state
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
if not isinstance(outputs_if_skipped, tuple):
outputs_if_skipped = (outputs_if_skipped,)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
original_ehs = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
if self.shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hs_residual, ehs_residual = None, None
if isinstance(output, tuple):
hs_residual = (
output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
)
ehs_residual = (
output[self._metadata.return_encoder_hidden_states_index]
- self.shared_state.head_block_output[1]
)
else:
hs_residual = output - self.shared_state.head_block_output
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
return output
output_count = len(outputs_if_skipped)
if output_count == 1:
return_output = original_hs
else:
return_output = [None] * output_count
return_output[self._metadata.return_hidden_states_index] = original_hs
return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
shared_state = FBCSharedBlockState()
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Apply FBCBlockHook to '{name}'")
apply_fbc_block_hook(block, shared_state)
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
+155 -113
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
from typing import Dict, List, Optional, Set, Tuple
import torch
@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# Always use memory-efficient CPU offloading to minimize RAM usage
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -64,50 +68,11 @@ class ModuleGroup:
self.onload_device = onload_device
self.offload_leader = offload_leader
self.onload_leader = onload_leader
self.parameters = parameters or []
self.buffers = buffers or []
self.parameters = parameters
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
return cpu_param_dict
for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)
for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
return cpu_param_dict
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
yield pinned_dict
finally:
pinned_dict = None
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
@@ -117,52 +82,136 @@ class ModuleGroup:
self.stream.synchronize()
with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
else:
# Use the most efficient module-level transfer when possible
# This approach mirrors how PyTorch handles full model transfers
if self.modules:
for group_module in self.modules:
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
# Only onload if some parameters are not on the target device
if any(p.device != self.onload_device for p in group_module.parameters()):
try:
# Try the most efficient approach using _apply
if hasattr(group_module, "_apply"):
# This is what module.to() uses internally
def to_device(t):
if t.device != self.onload_device:
if self.onload_device.type == "cuda":
return t.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
return t.to(self.onload_device,
non_blocking=self.non_blocking)
return t
# Apply to all tensors without unnecessary copies
group_module._apply(to_device)
else:
# Fallback to direct parameter transfer
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
except Exception as e:
# If optimization fails, fall back to direct parameter transfer
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle explicit parameters
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if buffer.device != self.onload_device:
if self.onload_device.type == "cuda":
buffer.data = buffer.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
buffer.data = buffer.data.to(self.onload_device,
non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
# For CPU offloading
if self.offload_device.type == "cpu":
# Synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# Empty GPU cache before offloading to reduce memory fragmentation
if torch.cuda.is_available():
torch.cuda.empty_cache()
# For module groups, use a single, unified approach that is closest to
# the behavior of model.to("cpu")
if self.modules:
for group_module in self.modules:
# Check if we need to offload this module
if any(p.device.type != "cpu" for p in group_module.parameters()):
# Use PyTorch's built-in to() method directly, which preserves
# memory mapping when moving to CPU
try:
# Non-blocking=False for CPU transfers, as it ensures memory is
# immediately available and potentially preserves memory mapping
group_module.to("cpu", non_blocking=False)
except Exception as e:
# If there's any error, fall back to parameter-level offloading
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
for param in group_module.parameters():
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle explicit parameters - move directly to CPU with non-blocking=False
# which can preserve memory mapping in some PyTorch versions
if self.parameters is not None:
for param in self.parameters:
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
if buffer.device.type != "cpu":
buffer.data = buffer.data.to("cpu", non_blocking=False)
# Let Python's normal reference counting handle cleanup
# We don't force garbage collection to avoid slowing down inference
else:
# For non-CPU offloading, synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# For non-CPU offloading, use the regular approach
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
# After offloading, we can unpin the memory if configured to do so
# We'll keep it pinned by default for better performance
class GroupOffloadingHook(ModelHook):
@@ -185,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# Offload to CPU
self.group.offload_()
return module
@@ -228,13 +278,6 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
@@ -246,8 +289,14 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
# For the first forward pass, we have to load in a blocking manner
group_offloading_hook.group.non_blocking = False
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name)
@@ -277,7 +326,6 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
@@ -285,13 +333,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for hook in group_offloading_hooks:
hook.group.non_blocking = True
# Set required attributes for prefetching
# Apply lazy prefetching by setting required attributes
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
@@ -331,7 +374,6 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -377,11 +419,8 @@ def apply_group_offloading(
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
Example:
```python
@@ -412,17 +451,22 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
# We no longer need a pinned group manager as we're not using pinned memory
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module,
num_blocks_per_group,
offload_device,
onload_device,
non_blocking,
stream,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,7 +478,6 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -455,6 +498,8 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# We no longer need a CPU parameter dictionary
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
@@ -475,7 +520,6 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -524,7 +568,6 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -547,6 +590,8 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# We no longer need a CPU parameter dictionary
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -560,7 +605,6 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -605,7 +649,6 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -624,7 +667,6 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
+1 -92
View File
@@ -18,76 +18,11 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class BaseMarkedState(BaseState):
def __init__(self, init_args=None, init_kwargs=None):
super().__init__()
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._mark_name = None
self._state_cache = {}
def get_current_state(self) -> "BaseMarkedState":
if self._mark_name is None:
# If no mark name is set, simply return a dummy object since we're not going to be using it
return self
if self._mark_name not in self._state_cache.keys():
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
return self._state_cache[self._mark_name]
def mark_state(self, name: str) -> None:
self._mark_name = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._mark_name = None
def __getattribute__(self, name):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
return object.__getattribute__(self, name)
else:
current_state = BaseMarkedState.get_current_state(self)
return object.__getattribute__(current_state, name)
def __setattr__(self, name, value):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
object.__setattr__(self, name, value)
else:
current_state = BaseMarkedState.get_current_state(self)
object.__setattr__(current_state, name, value)
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -164,14 +99,6 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseMarkedState):
attr.mark_state(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -284,10 +211,9 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in unwrap_module(self._module_ref).named_modules():
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -297,19 +223,6 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _mark_state(self, name: str) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._mark_state(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._mark_state(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
@@ -321,7 +234,3 @@ class HookRegistry:
if i < len(self._hook_order) - 1:
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"
def _is_dunder_method(name: str) -> bool:
return name.startswith("__") and name.endswith("__") and name in dir(object)
-182
View File
@@ -1,182 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __init__(self) -> None:
super().__init__()
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
value = kwargs.get("value", None)
if value is None:
value = args[2]
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
with AttentionScoreSkipFunctionMode():
return self.fn_ref.original_forward(*args, **kwargs)
else:
return self.skip_processor_output_fn(module, *args, **kwargs)
class FeedForwardSkipHook(ModelHook):
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
return output
class TransformerBlockSkipHook(ModelHook):
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip_hook(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = getattr(module, config.fqn, None)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook()
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
registry.register_hook(hook, name)
elif config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook()
registry.register_hook(hook, name)
else:
raise ValueError(
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig(\n"
f"PyramidAttentionBroadcastConfig("
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
@@ -175,7 +175,10 @@ class PyramidAttentionBroadcastHook(ModelHook):
return module
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
@@ -308,4 +311,4 @@ def _apply_pyramid_attention_broadcast_hook(
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
registry.register_hook(hook, "pyramid_attention_broadcast")
+3 -1
View File
@@ -804,7 +804,9 @@ class SD3IPAdapterMixin:
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
image_encoder=SiglipVisionModel.from_pretrained(
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
).to(self.device),
+2 -6
View File
@@ -423,12 +423,8 @@ def _load_lora_into_text_encoder(
# Unsafe code />
if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
logger.info(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
)
@@ -1348,56 +1348,3 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict = {}
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
for i in range(num_blocks):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.lora_B.weight"
)
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
)
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
)
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.{o}.lora_B.weight"
)
if len(original_state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
+27 -107
View File
@@ -42,7 +42,6 @@ from .lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
@@ -452,11 +451,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -477,7 +472,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -896,11 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -921,7 +912,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1299,11 +1290,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
@@ -1325,7 +1312,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1841,11 +1828,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1866,7 +1849,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -2565,11 +2548,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2587,7 +2566,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2873,11 +2852,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -2896,7 +2871,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3182,11 +3157,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3205,7 +3176,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3491,11 +3462,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3514,7 +3481,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3803,11 +3770,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3826,7 +3789,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class Lumina2LoraLoaderMixin(LoraBaseMixin):
@@ -4116,11 +4079,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
@@ -4139,7 +4098,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class WanLoraLoaderMixin(LoraBaseMixin):
@@ -4152,6 +4111,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4238,8 +4198,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
@@ -4249,33 +4207,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
return state_dict
@classmethod
def _maybe_expand_t2v_lora_for_i2v(
cls,
transformer: torch.nn.Module,
state_dict,
):
if transformer.config.image_dim is None:
return state_dict
if any(k.startswith("transformer.blocks.") for k in state_dict):
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
if is_i2v_lora:
return state_dict
for i in range(num_blocks):
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
)
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
@@ -4313,11 +4245,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
state_dict = self._maybe_expand_t2v_lora_for_i2v(
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
state_dict=state_dict,
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -4456,11 +4384,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4479,7 +4403,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class CogView4LoraLoaderMixin(LoraBaseMixin):
@@ -4765,11 +4689,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4788,7 +4708,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
super().unfuse_lora(components=components)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+2 -9
View File
@@ -307,9 +307,6 @@ class PeftAdapterMixin:
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
@@ -357,12 +354,8 @@ class PeftAdapterMixin:
# Unsafe code />
if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
logger.info(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
)
def save_lora_adapter(
+2 -2
View File
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
if not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+2 -3
View File
@@ -255,12 +255,12 @@ class FromOriginalModelMixin:
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
if not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -282,7 +282,6 @@ class FromOriginalModelMixin:
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
else:
hf_quantizer = None
+2 -2
View File
@@ -44,7 +44,6 @@ from ..utils import (
is_transformers_available,
logging,
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
@@ -444,7 +443,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
"Please provide a valid local file path."
)
original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
original_config_file = BytesIO(requests.get(original_config_file).content)
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
@@ -2407,6 +2406,7 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+2 -2
View File
@@ -449,9 +449,9 @@ class TextualInversionLoaderMixin:
# 7.5 Offload the model again
if is_model_cpu_offload:
self.enable_model_cpu_offload(device=device)
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload(device=device)
self.enable_sequential_cpu_offload()
# / Unsafe Code >
+5 -18
View File
@@ -741,14 +741,10 @@ class Attention(nn.Module):
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
)
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
@@ -2339,9 +2335,7 @@ class FluxAttnProcessor2_0:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -3710,10 +3704,8 @@ class StableAudioAttnProcessor2_0:
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
value = torch.repeat_interleave(
value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
)
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
if attn.norm_q is not None:
query = attn.norm_q(query)
@@ -6020,11 +6012,6 @@ class SanaLinearAttnProcessor2_0:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = hidden_states.repeat_interleave(self.repeats, dim=1)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
@@ -361,9 +361,7 @@ class Decoder(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
if self.down_sample:
identity = hidden_states[:, :, ::2]
elif self.up_sample:
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
identity = hidden_states.repeat_interleave(2, dim=2)
else:
identity = hidden_states
@@ -105,7 +105,6 @@ class CogVideoXCausalConv3d(nn.Module):
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
@@ -118,8 +117,6 @@ class CogVideoXCausalConv3d(nn.Module):
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
padding_mode="zeros",
)
def fake_context_parallel_forward(
@@ -140,7 +137,9 @@ class CogVideoXCausalConv3d(nn.Module):
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
@@ -196,55 +196,6 @@ class LTXVideoResnetBlock3d(nn.Module):
return hidden_states
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
residual = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
hidden_states = self.conv(hidden_states)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
hidden_states = hidden_states + residual
return hidden_states
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
@@ -253,7 +204,6 @@ class LTXVideoUpsampler3d(nn.Module):
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -269,7 +219,6 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -403,118 +352,6 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
r"""
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
Whether or not to downsample across temporal dimension.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList()
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
)
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
r"""
@@ -756,15 +593,8 @@ class LTXVideoEncoder3d(nn.Module):
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -787,37 +617,20 @@ class LTXVideoEncoder3d(nn.Module):
)
# down blocks
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
num_block_out_channels = len(block_out_channels)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
if not is_ltx_095:
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
else:
output_channel = block_out_channels[i + 1]
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
if down_block_types[i] == "LTXVideoDownBlock3D":
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
elif down_block_types[i] == "LTXVideo095DownBlock3D":
down_block = LTXVideo095DownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
self.down_blocks.append(down_block)
@@ -981,9 +794,7 @@ class LTXVideoDecoder3d(nn.Module):
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -992,9 +803,6 @@ class LTXVideoDecoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -1083,19 +891,12 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
@@ -1105,8 +906,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
@@ -1114,10 +913,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
downsample_type=downsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
@@ -1144,16 +941,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -426,9 +426,7 @@ class FourierFeatures(nn.Module):
w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w
h = inputs.repeat_interleave(
num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
) # [B, C * num_freqs, T, H, W]
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
+6 -53
View File
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -26,8 +24,6 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -63,25 +59,9 @@ class CacheMixin:
```
"""
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
if self.is_cache_enabled:
raise ValueError(
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
@@ -89,24 +69,15 @@ class CacheMixin:
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -116,21 +87,3 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def _cache_context(self):
r"""Context manager that provides additional methods for cache management."""
cache_context = _CacheContextManager(self)
yield cache_context
class _CacheContextManager:
def __init__(self, model: CacheMixin):
self.model = model
def mark_state(self, name: str) -> None:
from ..hooks import HookRegistry
if self.model.is_cache_enabled:
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry._mark_state(name)
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states,)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
emb = emb.repeat_interleave(sample_num_frames, dim=0)
# 2. pre-process
batch_size, channels, num_frames, height, width = sample.shape
+4 -9
View File
@@ -139,9 +139,7 @@ def get_3d_sincos_pos_embed(
# 3. Concat
pos_embed_spatial = pos_embed_spatial[None, :, :]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
) # [T, H*W, D // 4 * 3]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
pos_embed_temporal = pos_embed_temporal[:, None, :]
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -336,7 +334,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
@@ -1154,13 +1152,10 @@ def get_1d_rotary_pos_embed(
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
is_npu = freqs.device.type == "npu"
if is_npu:
freqs = freqs.float()
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
+4 -14
View File
@@ -37,6 +37,7 @@ from torch import Tensor, nn
from typing_extensions import Self
from .. import __version__
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -503,7 +504,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
from ..hooks import apply_layerwise_casting
user_provided_patterns = True
if skip_modules_pattern is None:
@@ -546,7 +546,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -570,8 +569,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
... )
```
"""
from ..hooks import apply_group_offloading
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
@@ -587,14 +584,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
)
def save_pretrained(
@@ -880,7 +870,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
@@ -893,7 +883,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
if not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+10
View File
@@ -550,6 +550,16 @@ class RMSNorm(nn.Module):
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
if self.bias is not None:
hidden_states = hidden_states + self.bias
elif is_torch_version(">=", "2.4"):
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = nn.functional.rms_norm(
hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
)
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+1 -1
View File
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -227,17 +227,13 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
# Prepare text embeddings for spatial block
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
)
# Prepare timesteps for spatial and temporal block
timestep_spatial = timestep.repeat_interleave(
num_frame, dim=0, output_size=timestep.shape[0] * num_frame
).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(
num_patches, dim=0, output_size=timestep.shape[0] * num_patches
).view(-1, timestep.shape[-1])
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
@@ -273,7 +269,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
if i == 0 and num_frame > 1:
hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
hidden_states = hidden_states + self.temp_pos_embed
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
@@ -303,9 +299,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
).view(-1, embedded_timestep.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
@@ -353,11 +353,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(
self.config.num_attention_heads,
dim=0,
output_size=attention_mask.shape[0] * self.config.num_attention_heads,
)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
@@ -15,7 +15,6 @@
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
@@ -24,9 +23,10 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ..attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor2_0,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -96,95 +96,6 @@ class SanaModulatedNorm(nn.Module):
return hidden_states
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
guidance_proj = self.guidance_condition_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
conditioning = timesteps_emb + guidance_emb
return self.linear(self.silu(conditioning)), conditioning
class SanaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -204,7 +115,6 @@ class SanaTransformerBlock(nn.Module):
norm_eps: float = 1e-6,
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
qk_norm: Optional[str] = None,
) -> None:
super().__init__()
@@ -214,8 +124,6 @@ class SanaTransformerBlock(nn.Module):
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
kv_heads=num_attention_heads if qk_norm is not None else None,
qk_norm=qk_norm,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
@@ -227,15 +135,13 @@ class SanaTransformerBlock(nn.Module):
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
qk_norm=qk_norm,
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
cross_attention_dim=cross_attention_dim,
heads=num_cross_attention_heads,
dim_head=cross_attention_head_dim,
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=SanaAttnProcessor2_0(),
processor=AttnProcessor2_0(),
)
# 3. Feed-forward
@@ -326,10 +232,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
Whether to use elementwise affinity in the normalization layer.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value for the normalization layer.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for the query and key.
timestep_scale (`float`, defaults to `1.0`):
The scale to use for the timesteps.
"""
_supports_gradient_checkpointing = True
@@ -356,10 +258,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
guidance_embeds: bool = False,
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
timestep_scale: float = 1.0,
) -> None:
super().__init__()
@@ -378,10 +276,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
# 2. Additional condition embeddings
if guidance_embeds:
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
else:
self.time_embed = AdaLayerNormSingle(inner_dim)
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -401,7 +296,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
)
for _ in range(num_layers)
]
@@ -478,8 +372,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -530,14 +423,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.patch_embed(hidden_states)
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
@@ -23,7 +23,6 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -127,8 +126,7 @@ class CogView4AttnProcessor:
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
@@ -158,15 +156,6 @@ class CogView4AttnProcessor:
)
# 4. Attention
if attention_mask is not None:
text_attention_mask = attention_mask.float().to(query.device)
actual_text_seq_length = text_attention_mask.size(1)
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
new_attention_mask = new_attention_mask.unsqueeze(2)
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -214,8 +203,6 @@ class CogView4TransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# 1. Timestep conditioning
(
@@ -236,8 +223,6 @@ class CogView4TransformerBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -304,7 +289,7 @@ class CogView4RotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
@@ -401,8 +386,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
@@ -438,11 +421,11 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
# 4. Output norm & projection
@@ -460,84 +443,3 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
### ===== Custom attention processors for guidance methods ===== ###
class CogView4PAGAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
skip_context_attention: bool = False,
) -> torch.Tensor:
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
# 4. Attention
if skip_context_attention:
hidden_states = value
else:
# PAG uses a custom attention mask for perturbed attention path:
# - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens
# - Set diagonal to `0.0` for attention between image tokens
seq_length = text_seq_length + image_seq_length
perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf"))
perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0
perturbed_attention_mask.fill_diagonal_(0.0)
perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@@ -79,14 +79,10 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -104,8 +100,7 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
return hidden_states
@maybe_allow_in_graph
@@ -513,21 +508,20 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -537,7 +531,12 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
@@ -27,15 +27,13 @@ from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
PixArtAlphaTextProjection,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -175,141 +173,6 @@ class HunyuanVideoAdaNorm(nn.Module):
return gate_msa, gate_mlp
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
6, dim=1
)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
)
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return hidden_states, gate_msa, tr_gate_msa
class HunyuanVideoConditionEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
pooled_projection_dim: int,
guidance_embeds: bool,
image_condition_type: Optional[str] = None,
):
super().__init__()
self.image_condition_type = image_condition_type
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
token_replace_emb = None
if self.image_condition_type == "token_replace":
token_replace_timestep = torch.zeros_like(timestep)
token_replace_proj = self.time_proj(token_replace_timestep)
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
token_replace_emb = token_replace_emb + pooled_projections
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
conditioning = conditioning + guidance_emb
return conditioning, token_replace_emb
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
@@ -527,8 +390,6 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -607,8 +468,6 @@ class HunyuanVideoTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -644,181 +503,6 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_output = self.proj_out(hidden_states)
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
# 3. Modulation and residual connection
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -856,10 +540,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
image_condition_type (`str`, *optional*, defaults to `None`):
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
tokens in the latent stream and apply conditioning.
"""
_supports_gradient_checkpointing = True
@@ -890,16 +570,9 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
supported_image_condition_types = ["latent_concat", "token_replace"]
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
raise ValueError(
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
)
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
@@ -909,52 +582,33 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
if image_condition_type == "token_replace":
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
if image_condition_type == "token_replace":
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
else:
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -1053,13 +707,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -1090,8 +746,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
for block in self.single_transformer_blocks:
@@ -1102,31 +756,17 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
# 5. Output projection

Some files were not shown because too many files have changed in this diff Show More