Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ca5cfbd37e | |||
| 5ac65c4513 | |||
| 8c2b2cdc52 | |||
| df1d7b01f1 | |||
| 5a6edac087 | |||
| e8fc8b1f81 | |||
| d6f4774c1c | |||
| eb50defff2 | |||
| 2c59af7222 | |||
| 75d7e5cc45 | |||
| 617c208bb4 | |||
| 5d970a4aa9 | |||
| de6a88c2d7 | |||
| 7dc52ea769 | |||
| 739d6ec731 | |||
| 1ddf3f3a19 | |||
| 7aac77affa | |||
| 8907a70a36 | |||
| 5dbe4f5de6 | |||
| 1d37f42055 | |||
| 0213179ba8 | |||
| a7d53a5939 | |||
| 8a63aa5e4f | |||
| 844221ae4e | |||
| 9b2c0a7dbe | |||
| f424b1b062 | |||
| e9fda3924f | |||
| 2c1ed50fc5 | |||
| 15ad97f782 | |||
| 9f2d5c9ee9 | |||
| dc62e6931e | |||
| 56f740051d | |||
| a34d97cef0 | |||
| fc28791fc8 | |||
| ae14612673 | |||
| 0ab8fe49bf | |||
| 3be6706018 | |||
| cb1b8b21b8 | |||
| 27916822b2 | |||
| 3fe3bc0642 | |||
| 813d42cc96 | |||
| b4d7e9c632 | |||
| 2e83cbbb6d | |||
| 33d10af28f | |||
| 100142586f | |||
| 82188cef04 | |||
| cc19726f3d | |||
| be54a95b93 | |||
| 6b9a3334db | |||
| 8ead643bb7 | |||
| 124ac3e81f | |||
| 2f0f281b0d | |||
| ccc8321651 | |||
| 5e48cd27d4 | |||
| 5551506b29 | |||
| 20e4b6a628 | |||
| 4ea9f89b8e | |||
| 733b44ac82 | |||
| 8b4f8ba764 | |||
| 5428046437 | |||
| e7ffeae0a1 | |||
| d87ce2cefc | |||
| 36d0553af2 | |||
| 7e0db46f73 | |||
| e4b056fe65 | |||
| 4e3ddd5afa | |||
| 9add071592 | |||
| b88fef4785 | |||
| e7e6d85282 | |||
| 8eefed65bd | |||
| 26149c0ecd | |||
| 0703ce8800 | |||
| f5edaa7894 | |||
| 9a1810f0de | |||
| 1fddee211e | |||
| b38450d5d2 | |||
| 1357931d74 | |||
| a2d3d6af44 | |||
| 363d1ab7e2 |
@@ -38,6 +38,7 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install pandas peft
|
||||
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
@@ -414,10 +414,16 @@ jobs:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
additional_deps: ["peft"]
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
additional_deps: []
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
additional_deps: []
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
@@ -435,6 +441,9 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U ${{ matrix.config.backend }}
|
||||
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
|
||||
python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
|
||||
fi
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -28,7 +28,51 @@ env:
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
@@ -133,6 +177,7 @@ jobs:
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
@@ -201,7 +246,7 @@ jobs:
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
@@ -220,6 +265,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -81,6 +81,8 @@
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/vae_encode
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
@@ -173,6 +175,8 @@
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
- local: quantization/quanto
|
||||
title: quanto
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
@@ -492,6 +496,8 @@
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
|
||||
@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
|
||||
FasterCache is a method that speeds up inference in diffusion transformers by:
|
||||
- Reusing attention states between successive inference steps, due to high similarity between them
|
||||
- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FasterCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 681),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
attention_weight_callback=lambda _: 0.3,
|
||||
unconditional_batch_skip_range=5,
|
||||
unconditional_batch_timestep_skip_range=(-1, 781),
|
||||
tensor_format="BFCHW",
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
|
||||
## Quantization
|
||||
|
||||
|
||||
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXConditionPipeline
|
||||
|
||||
[[autodoc]] LTXConditionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
|
||||
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
from diffusers import LuminaPipeline
|
||||
import torch
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
```
|
||||
@@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
@@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained(
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
@@ -122,9 +122,9 @@ image = pipeline(prompt).images[0]
|
||||
image.save("lumina.png")
|
||||
```
|
||||
|
||||
## LuminaText2ImgPipeline
|
||||
## LuminaPipeline
|
||||
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
[[autodoc]] LuminaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
ckpt_path, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -60,7 +60,7 @@ image.save("lumina-single-file.png")
|
||||
GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
|
||||
|
||||
```python
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
@@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -80,8 +80,8 @@ image = pipe(
|
||||
image.save("lumina-gguf.png")
|
||||
```
|
||||
|
||||
## Lumina2Text2ImgPipeline
|
||||
## Lumina2Pipeline
|
||||
|
||||
[[autodoc]] Lumina2Text2ImgPipeline
|
||||
[[autodoc]] Lumina2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaSprintPipeline
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipeline = SanaSprintPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("sana.png")
|
||||
```
|
||||
|
||||
## Setting `max_timesteps`
|
||||
|
||||
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
|
||||
|
||||
## SanaSprintPipeline
|
||||
|
||||
[[autodoc]] SanaSprintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
|
||||
@@ -14,22 +14,405 @@
|
||||
|
||||
# Wan
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
<!-- TODO(aryan): update abstract once paper is out -->
|
||||
|
||||
## Generating Videos with Wan 2.1
|
||||
|
||||
We will first need to install some addtional dependencies.
|
||||
|
||||
```shell
|
||||
pip install -u ftfy imageio-ffmpeg imageio
|
||||
```
|
||||
|
||||
### Text to Video Generation
|
||||
|
||||
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
|
||||
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
|
||||
|
||||
```python
|
||||
from diffusers import WanPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
You can improve the quality of the generated video by running the decoding step in full precision.
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- VAE in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
```python
|
||||
from diffusers import WanPipeline, AutoencoderKLWan
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
### Using a custom scheduler
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Image to Video Generation
|
||||
|
||||
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
|
||||
35GB of VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 480 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Video to Video Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.utils import load_video, export_to_video
|
||||
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
model_id, subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
pipe = WanVideoToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, torch_dtype=torch.bfloat16
|
||||
)
|
||||
flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(
|
||||
pipe.scheduler.config, flow_shift=flow_shift
|
||||
)
|
||||
# change to pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A robot standing on a mountain top. The sun is setting in the background"
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
video = load_video(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
|
||||
)
|
||||
output = pipe(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=512,
|
||||
guidance_scale=7.0,
|
||||
strength=0.7,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-v2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Memory Optimizations for Wan 2.1
|
||||
|
||||
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
|
||||
|
||||
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
|
||||
|
||||
### Group Offloading the Transformer and UMT5 Text Encoder
|
||||
|
||||
Find more information about group offloading [here](../optimization/memory.md)
|
||||
|
||||
#### Block Level Group Offloading
|
||||
|
||||
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
|
||||
|
||||
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4,
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
#### Block Level Group Offloading with CUDA Streams
|
||||
|
||||
We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
|
||||
|
||||
In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Applying Layerwise Casting to the Transformer
|
||||
|
||||
Find more information about layerwise casting [here](../optimization/memory.md)
|
||||
|
||||
In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
|
||||
|
||||
This example will require 20GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Using a Custom Scheduler
|
||||
|
||||
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
|
||||
|
||||
@@ -45,6 +428,26 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
|
||||
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
|
||||
```
|
||||
|
||||
## Using Single File Loading with Wan 2.1
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import WanPipeline, WanTransformer3DModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
|
||||
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
|
||||
```
|
||||
|
||||
## Recommendations for Inference
|
||||
- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
## GGUFQuantizationConfig
|
||||
|
||||
[[autodoc]] GGUFQuantizationConfig
|
||||
|
||||
## QuantoConfig
|
||||
|
||||
[[autodoc]] QuantoConfig
|
||||
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
@@ -3,3 +3,7 @@
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
|
||||
## Remote Encode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_encode
|
||||
|
||||
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
## Available Models
|
||||
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
|
||||
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
---
|
||||
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
## Changelog
|
||||
|
||||
- March 10 2025: Added VAE encode
|
||||
- March 2 2025: Initial release with VAE decoding
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into two sections:
|
||||
The documentation is organized into three sections:
|
||||
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Getting Started: VAE Encode with Hybrid Inference
|
||||
|
||||
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Let's encode an image, then decode it to demonstrate.
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
|
||||
</figure>
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
|
||||
|
||||
latent = remote_encode(
|
||||
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
decoded = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode, remote_encode
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
image=init_image,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
image=init_latent,
|
||||
strength=0.75,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("fantasy_landscape.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -161,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
|
||||
|
||||
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
|
||||
```shell
|
||||
export HF_HUB_OFFLINE=True
|
||||
export HF_HUB_OFFLINE=1
|
||||
```
|
||||
|
||||
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
@@ -179,14 +179,16 @@ Telemetry is only sent when loading models and pipelines from the Hub,
|
||||
and it is not collected if you're loading local files.
|
||||
|
||||
We understand that not everyone wants to share additional information,and we respect your privacy.
|
||||
You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
|
||||
On Linux/MacOS:
|
||||
|
||||
```bash
|
||||
export DISABLE_TELEMETRY=YES
|
||||
export HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```bash
|
||||
set DISABLE_TELEMETRY=YES
|
||||
set HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
@@ -198,6 +198,18 @@ export_to_video(video, "output.mp4", fps=8)
|
||||
|
||||
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
|
||||
|
||||
<Tip>
|
||||
|
||||
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
|
||||
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
|
||||
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
|
||||
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
|
||||
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
|
||||
|
||||
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
|
||||
|
||||
</Tip>
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
@@ -235,6 +247,14 @@ In the above example, layerwise casting is enabled on the transformer component
|
||||
|
||||
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
|
||||
|
||||
<Tip>
|
||||
|
||||
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
|
||||
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
|
||||
- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Channels-last memory format
|
||||
|
||||
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
|
||||
|
||||
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
- [Quanto](./quanto.md)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
-->
|
||||
|
||||
# Quanto
|
||||
|
||||
[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
|
||||
|
||||
- All features are available in eager mode (works with non-traceable models)
|
||||
- Supports quantization aware training
|
||||
- Quantized models are compatible with `torch.compile`
|
||||
- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)
|
||||
|
||||
In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`
|
||||
|
||||
```shell
|
||||
pip install optimum-quanto accelerate
|
||||
```
|
||||
|
||||
Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Skipping Quantization on specific modules
|
||||
|
||||
It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
|
||||
## Using `from_single_file` with the Quanto Backend
|
||||
|
||||
`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## Saving Quantized models
|
||||
|
||||
Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
|
||||
|
||||
The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
|
||||
with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# save quantized model to reuse
|
||||
transformer.save_pretrained("<your quantized model save path>")
|
||||
|
||||
# you can reload your quantized model with
|
||||
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
|
||||
```
|
||||
|
||||
## Using `torch.compile` with Quanto
|
||||
|
||||
Currently the Quanto backend supports `torch.compile` for the following quantization types:
|
||||
|
||||
- `int8` weights
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="int8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch_dtype
|
||||
)
|
||||
pipe.to("cuda")
|
||||
images = pipe("A cat holding a sign that says hello").images[0]
|
||||
images.save("flux-quanto-compile.png")
|
||||
```
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
### Weights
|
||||
|
||||
- float8
|
||||
- int8
|
||||
- int4
|
||||
- int2
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@@ -66,12 +66,6 @@ from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
## 원을 채우는 데이터셋
|
||||
|
||||
원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
|
||||
|
||||
우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
|
||||
|
||||
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
|
||||
|
||||
## 학습
|
||||
|
||||
@@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
|
||||
@@ -378,7 +378,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="the concept to use to initialize the new inserted tokens when training with "
|
||||
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
|
||||
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
|
||||
"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
|
||||
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
|
||||
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
|
||||
),
|
||||
)
|
||||
|
||||
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -773,7 +773,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
@@ -1875,7 +1875,7 @@ def main(args):
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
# Training CogView4 Control
|
||||
|
||||
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:
|
||||
|
||||
To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
|
||||
|
||||
```bash
|
||||
accelerate launch train_control_lora_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control-lora" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=1 \
|
||||
--rank=64 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=5000 \
|
||||
--validation_image="openpose.png" \
|
||||
--validation_prompt="A couple, 4k photo, highly detailed" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
|
||||
|
||||
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
|
||||
|
||||
The training script exposes additional CLI args that might be useful to experiment with:
|
||||
|
||||
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
|
||||
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
|
||||
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
|
||||
|
||||
### Training with DeepSpeed
|
||||
|
||||
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
And then while launching training, pass the config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=CONFIG_FILE.yaml ...
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
|
||||
|
||||
```bash
|
||||
pip install controlnet_aux
|
||||
```
|
||||
|
||||
And then we are ready:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("...") # change this.
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
joint_attention_kwargs={"scale": 0.9},
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Full fine-tuning
|
||||
|
||||
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=2 \
|
||||
--dataloader_num_workers=4 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--proportion_empty_prompts=0.2 \
|
||||
--learning_rate=5e-5 \
|
||||
--adam_weight_decay=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=1000 \
|
||||
--checkpointing_steps=1000 \
|
||||
--max_train_steps=10000 \
|
||||
--validation_steps=200 \
|
||||
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
|
||||
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Change the `validation_image` and `validation_prompt` as needed.
|
||||
|
||||
For inference, this time, we will run:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
|
||||
pipe = CogView4ControlPipeline.from_pretrained(
|
||||
"THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Things to note
|
||||
|
||||
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
|
||||
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
|
||||
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
|
||||
@@ -0,0 +1,6 @@
|
||||
transformers==4.47.0
|
||||
wandb
|
||||
torch
|
||||
torchvision
|
||||
accelerate==1.2.0
|
||||
peft>=0.14.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
@@ -23,12 +24,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
@@ -40,7 +41,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
|
||||
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
@@ -57,7 +58,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
@@ -93,6 +94,55 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
|
||||
|
||||
## Example usages
|
||||
|
||||
### Spatiotemporal Skip Guidance
|
||||
|
||||
**Junha Hyung\*, Kinam Kim\*, Susung Hong, Min-Jung Kim, Jaegul Choo**
|
||||
|
||||
**KAIST AI, University of Washington**
|
||||
|
||||
[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://arxiv.org/abs/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.
|
||||
|
||||
Following is the example video of STG applied to Mochi.
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3
|
||||
|
||||
More examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).
|
||||
|
||||
#### Usage example
|
||||
```python
|
||||
import torch
|
||||
from pipeline_stg_mochi import MochiSTGPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
#--------Option--------#
|
||||
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
stg_applied_layers_idx = [34]
|
||||
stg_mode = "STG"
|
||||
stg_scale = 1.0 # 0.0 for CFG
|
||||
#----------------------#
|
||||
|
||||
# Generate video frames
|
||||
frames = pipe(
|
||||
prompt,
|
||||
height=480,
|
||||
width=480,
|
||||
num_frames=81,
|
||||
stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
stg_scale=stg_scale,
|
||||
generator = torch.Generator().manual_seed(42),
|
||||
do_rescaling=do_rescaling,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
### Adaptive Mask Inpainting
|
||||
|
||||
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
|
||||
@@ -904,6 +954,7 @@ for i in range(args.num_images):
|
||||
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
|
||||
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
|
||||
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
|
||||
print("Image saved successfully!")
|
||||
```
|
||||
|
||||
### Imagic Stable Diffusion
|
||||
@@ -1219,28 +1270,39 @@ The aim is to overlay two images, then mask out the boundary between `image` and
|
||||
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
image_path = "./path-to-image.png"
|
||||
inner_image_path = "./path-to-inner-image.png"
|
||||
mask_path = "./path-to-mask.png"
|
||||
image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
|
||||
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
|
||||
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
def load_image(url, mode="RGB"):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
|
||||
else:
|
||||
raise FileNotFoundError(f"Could not retrieve image from {url}")
|
||||
|
||||
|
||||
init_image = load_image(image_url, mode="RGB")
|
||||
inner_image = load_image(inner_image_url, mode="RGBA")
|
||||
mask_image = load_image(mask_url, mode="RGB")
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
custom_pipeline="img2img_inpainting",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Your prompt here!"
|
||||
prompt = "a mecha robot sitting on a bench"
|
||||
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
|
||||
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||

|
||||
@@ -3202,14 +3264,19 @@ Here's a full example for `ReplaceEdit``:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from diffusers import DiffusionPipeline
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="pipeline_prompt2prompt"
|
||||
).to("cuda")
|
||||
|
||||
prompts = ["A turtle playing with a ball",
|
||||
"A monkey playing with a ball"]
|
||||
prompts = [
|
||||
"A turtle playing with a ball",
|
||||
"A monkey playing with a ball"
|
||||
]
|
||||
|
||||
cross_attention_kwargs = {
|
||||
"edit_type": "replace",
|
||||
@@ -3217,7 +3284,15 @@ cross_attention_kwargs = {
|
||||
"self_replace_steps": 0.4
|
||||
}
|
||||
|
||||
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
|
||||
outputs = pipe(
|
||||
prompt=prompts,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=50,
|
||||
cross_attention_kwargs=cross_attention_kwargs
|
||||
)
|
||||
|
||||
outputs.images[0].save("output_image_0.png")
|
||||
```
|
||||
|
||||
And abbreviated examples for the other edits:
|
||||
|
||||
@@ -1773,7 +1773,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
@@ -1924,7 +1924,22 @@ class SDXLLongPromptWeightingPipeline(
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -1070,32 +1070,32 @@ class StableDiffusionXLTilingPipeline(
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
||||
embeddings_and_added_time.append(addition_embed_type_row)
|
||||
|
||||
|
||||
@@ -0,0 +1,876 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline
|
||||
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXSTGPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style."
|
||||
... )
|
||||
>>> pipe.transformer.to(memory_format=torch.channels_last)
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [11] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set to 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> # Generate video frames with STG parameters
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(frames, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogVideoX uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogVideoXTransformer3DModel`]):
|
||||
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
)
|
||||
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae_scaling_factor_image * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
p = self.transformer.config.patch_size
|
||||
p_t = self.transformer.config.patch_size_t
|
||||
|
||||
base_size_width = self.transformer.config.sample_width // p
|
||||
base_size_height = self.transformer.config.sample_height // p
|
||||
|
||||
if p_t is None:
|
||||
# CogVideoX 1.0
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# CogVideoX 1.5
|
||||
base_num_frames = (num_frames + p_t - 1) // p_t
|
||||
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=None,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=base_num_frames,
|
||||
grid_type="slice",
|
||||
max_size=(base_size_height, base_size_width),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [11],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
||||
num_frames (`int`, defaults to `48`):
|
||||
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
||||
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
||||
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
||||
needs to be satisfied is that of divisibility mentioned above.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
||||
num_frames = num_frames or self.transformer.config.sample_frames
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents
|
||||
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
||||
patch_size_t = self.transformer.config.patch_size_t
|
||||
additional_frames = 0
|
||||
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
||||
additional_frames = patch_size_t - latent_frames % patch_size_t
|
||||
num_frames += additional_frames * self.vae_scale_factor_temporal
|
||||
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
else:
|
||||
latents, old_pred_original_sample = self.scheduler.step(
|
||||
noise_pred,
|
||||
old_pred_original_sample,
|
||||
t,
|
||||
timesteps[i - 1] if i > 0 else None,
|
||||
latents,
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CogVideoXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,794 @@
|
||||
# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
|
||||
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import HunyuanVideoTransformer3DModel
|
||||
>>> from examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline
|
||||
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
|
||||
>>> pipe.vae.enable_tiling()
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [2] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt="A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.",
|
||||
... height=320,
|
||||
... width=512,
|
||||
... num_frames=61,
|
||||
... num_inference_steps=30,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = {
|
||||
"template": (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
),
|
||||
"crop_start": 95,
|
||||
}
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def forward_without_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using HunyuanVideo.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`LlamaModel`]):
|
||||
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
tokenizer (`LlamaTokenizer`):
|
||||
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
transformer ([`HunyuanVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder_2 ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: LlamaModel,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
transformer: HunyuanVideoTransformer3DModel,
|
||||
vae: AutoencoderKLHunyuanVideo,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
text_encoder_2: CLIPTextModel,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_llama_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_template: Dict[str, Any],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
num_hidden_layers_to_skip: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = [prompt_template["template"].format(p) for p in prompt]
|
||||
|
||||
crop_start = prompt_template.get("crop_start", None)
|
||||
if crop_start is None:
|
||||
prompt_template_input = self.tokenizer(
|
||||
prompt_template["template"],
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
crop_start = prompt_template_input["input_ids"].shape[-1]
|
||||
# Remove <|eot_id|> token and placeholder {}
|
||||
crop_start -= 2
|
||||
|
||||
max_sequence_length += crop_start
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
max_length=max_sequence_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device=device)
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
if crop_start is not None and crop_start > 0:
|
||||
prompt_embeds = prompt_embeds[:, crop_start:]
|
||||
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 77,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder_2.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
||||
prompt,
|
||||
prompt_template,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
if prompt_2 is None and pooled_prompt_embeds is None:
|
||||
prompt_2 = prompt
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=77,
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_template=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if prompt_template is not None:
|
||||
if not isinstance(prompt_template, dict):
|
||||
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
|
||||
if "template" not in prompt_template:
|
||||
raise ValueError(
|
||||
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 32,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [2],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
|
||||
CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
|
||||
not applied.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_template=prompt_template,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
||||
if pooled_prompt_embeds is not None:
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_latent_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare guidance condition
|
||||
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_without_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred_perturb = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return HunyuanVideoPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,886 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline
|
||||
|
||||
>>> pipe = LTXSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,985 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
>>> from examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline
|
||||
|
||||
>>> pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png"
|
||||
>>> )
|
||||
>>> prompt = "A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
self.default_height = 512
|
||||
self.default_width = 704
|
||||
self.default_frames = 121
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
|
||||
)
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype), conditioning_mask
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
|
||||
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
|
||||
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
return latents, conditioning_mask
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
image = image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
timestep, _ = timestep.chunk(2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
timestep, _, _ = timestep.chunk(3)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = self._unpack_latents(
|
||||
noise_pred,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
noise_latents = latents[:, :, 1:]
|
||||
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
|
||||
|
||||
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,843 @@
|
||||
# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import Mochi1LoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
|
||||
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline
|
||||
|
||||
>>> pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> pipe.enable_vae_tiling()
|
||||
>>> prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [34] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling).frames[0]
|
||||
>>> export_to_video(frames, "mochi.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
|
||||
if not self.context_pre_only:
|
||||
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, temb
|
||||
)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
|
||||
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
|
||||
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
)
|
||||
norm_encoder_hidden_states = self.norm3_context(
|
||||
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
|
||||
)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
)
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
if linear_steps is None:
|
||||
linear_steps = num_steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
|
||||
quadratic_steps = num_steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
|
||||
const = quadratic_coef * (linear_steps**2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return sigma_schedule
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom value")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
r"""
|
||||
The mochi pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/genmoai/models
|
||||
|
||||
Args:
|
||||
transformer ([`MochiTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLMochi`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLMochi,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: MochiTransformer3DModel,
|
||||
force_zeros_for_empty_prompt: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
# TODO: determine these scaling factors from model parameters
|
||||
self.vae_spatial_scale_factor = 8
|
||||
self.vae_temporal_scale_factor = 6
|
||||
self.patch_size = 2
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
|
||||
)
|
||||
self.default_height = 480
|
||||
self.default_width = 848
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
# The original Mochi implementation zeros out empty negative prompts
|
||||
# but this can lead to overflow when placing the entire pipeline under the autocast context
|
||||
# adding this here so that we can enable zeroing prompts if necessary
|
||||
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
|
||||
text_input_ids = torch.zeros_like(text_input_ids, device=device)
|
||||
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = height // self.vae_spatial_scale_factor
|
||||
width = width // self.vae_spatial_scale_factor
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
|
||||
latents = latents.to(dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 19,
|
||||
num_inference_steps: int = 64,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [34],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to `self.default_height`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to `self.default_width`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `19`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `4.5`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `256`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
|
||||
is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.default_height
|
||||
width = width or self.default_width
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 5. Prepare timestep
|
||||
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
threshold_noise = 0.025
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
||||
sigmas = np.array(sigmas)
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
|
||||
# to make sure we're using the correct non-reversed timestep value.
|
||||
self._current_timestep = 1000 - t
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# Mochi CFG + Sampling runs in FP32
|
||||
noise_pred = noise_pred.to(torch.float32)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return MochiPipelineOutput(frames=video)
|
||||
@@ -152,9 +152,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -166,9 +166,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -1283,8 +1283,8 @@ def main(args):
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
prompt_embeds = batch["prompt_embeds"]
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
|
||||
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
|
||||
|
||||
# controlnet(s) inference
|
||||
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -157,9 +157,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# AnyTextPipeline
|
||||
|
||||
Project page: https://aigcdesigngroup.github.io/homepage_anytext
|
||||
|
||||
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
|
||||
|
||||
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
|
||||
|
||||
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
|
||||
|
||||
[](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
|
||||
|
||||
```py
|
||||
# This example requires the `anytext_controlnet.py` file:
|
||||
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
|
||||
# %cd diffusers/examples/research_projects/anytext
|
||||
# Let's choose a font file shared by an HF staff:
|
||||
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from anytext_controlnet import AnyTextControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
||||
variant="fp16",)
|
||||
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
|
||||
controlnet=anytext_controlnet, torch_dtype=torch.float16,
|
||||
trust_remote_code=False, # One needs to give permission to run this pipeline's code
|
||||
).to("cuda")
|
||||
|
||||
# generate image
|
||||
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
||||
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
||||
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
|
||||
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,463 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
|
||||
# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
|
||||
# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
|
||||
#
|
||||
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.controlnets.controlnet import (
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class AnyTextControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
glyph_channels=1,
|
||||
position_channels=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.glyph_block = nn.Sequential(
|
||||
nn.Conv2d(glyph_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_block = nn.Sequential(
|
||||
nn.Conv2d(position_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 64, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
|
||||
|
||||
def forward(self, glyphs, positions, text_info):
|
||||
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
|
||||
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
|
||||
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
|
||||
|
||||
return guided_hint
|
||||
|
||||
|
||||
class AnyTextControlNetModel(ControlNetModel):
|
||||
"""
|
||||
A AnyTextControlNetModel model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 1,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
conditioning_channels,
|
||||
flip_sin_to_cos,
|
||||
freq_shift,
|
||||
down_block_types,
|
||||
mid_block_type,
|
||||
only_cross_attention,
|
||||
block_out_channels,
|
||||
layers_per_block,
|
||||
downsample_padding,
|
||||
mid_block_scale_factor,
|
||||
act_fn,
|
||||
norm_num_groups,
|
||||
norm_eps,
|
||||
cross_attention_dim,
|
||||
transformer_layers_per_block,
|
||||
encoder_hid_dim,
|
||||
encoder_hid_dim_type,
|
||||
attention_head_dim,
|
||||
num_attention_heads,
|
||||
use_linear_projection,
|
||||
class_embed_type,
|
||||
addition_embed_type,
|
||||
addition_time_embed_dim,
|
||||
num_class_embeds,
|
||||
upcast_attention,
|
||||
resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels,
|
||||
global_pool_conditions,
|
||||
addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
glyph_channels=conditioning_channels,
|
||||
position_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`~PromptDiffusionControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
#controlnet_cond (`torch.Tensor`):
|
||||
# The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
# elif channel_order == "bgr":
|
||||
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.models.controlnet.zero_module
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
+209
@@ -0,0 +1,209 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .RecSVTR import Block
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Im2Im(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Im2Seq(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# assert H == 1
|
||||
x = x.reshape(B, C, H * W)
|
||||
x = x.permute((0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
class EncoderWithRNN(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(EncoderWithRNN, self).__init__()
|
||||
hidden_size = kwargs.get("hidden_size", 256)
|
||||
self.out_channels = hidden_size * 2
|
||||
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self, in_channels, encoder_type="rnn", **kwargs):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
self.encoder_reshape = Im2Seq(in_channels)
|
||||
self.out_channels = self.encoder_reshape.out_channels
|
||||
self.encoder_type = encoder_type
|
||||
if encoder_type == "reshape":
|
||||
self.only_reshape = True
|
||||
else:
|
||||
support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
|
||||
assert encoder_type in support_encoder_dict, "{} must in {}".format(
|
||||
encoder_type, support_encoder_dict.keys()
|
||||
)
|
||||
|
||||
self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
|
||||
self.out_channels = self.encoder.out_channels
|
||||
self.only_reshape = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.encoder_type != "svtr":
|
||||
x = self.encoder_reshape(x)
|
||||
if not self.only_reshape:
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.encoder_reshape(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = Swish()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class EncoderWithSVTR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
dims=64, # XS
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=False,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
mlp_ratio=2.0,
|
||||
drop_rate=0.1,
|
||||
attn_drop_rate=0.1,
|
||||
drop_path=0.0,
|
||||
qk_scale=None,
|
||||
):
|
||||
super(EncoderWithSVTR, self).__init__()
|
||||
self.depth = depth
|
||||
self.use_guide = use_guide
|
||||
self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
|
||||
self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
|
||||
|
||||
self.svtr_block = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=hidden_dims,
|
||||
num_heads=num_heads,
|
||||
mixer="Global",
|
||||
HW=None,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer="swish",
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-05,
|
||||
prenorm=False,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
|
||||
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
|
||||
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
|
||||
self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
|
||||
|
||||
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
|
||||
self.out_channels = dims
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# weight initialization
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# for use guide
|
||||
if self.use_guide:
|
||||
z = x.clone()
|
||||
z.stop_gradient = True
|
||||
else:
|
||||
z = x
|
||||
# for short cut
|
||||
h = z
|
||||
# reduce dim
|
||||
z = self.conv1(z)
|
||||
z = self.conv2(z)
|
||||
# SVTR global block
|
||||
B, C, H, W = z.shape
|
||||
z = z.flatten(2).permute(0, 2, 1)
|
||||
|
||||
for blk in self.svtr_block:
|
||||
z = blk(z)
|
||||
|
||||
z = self.norm(z)
|
||||
# last stage
|
||||
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
|
||||
z = self.conv3(z)
|
||||
z = torch.cat((h, z), dim=1)
|
||||
z = self.conv1x1(self.conv4(z))
|
||||
|
||||
return z
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
svtrRNN = EncoderWithSVTR(56)
|
||||
print(svtrRNN)
|
||||
@@ -0,0 +1,45 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
|
||||
):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
bias=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = {}
|
||||
result["ctc"] = predicts
|
||||
result["ctc_neck"] = x
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,49 @@
|
||||
from torch import nn
|
||||
|
||||
from .RecCTCHead import CTCHead
|
||||
from .RecMv1_enhance import MobileNetV1Enhance
|
||||
from .RNN import Im2Im, Im2Seq, SequenceEncoder
|
||||
|
||||
|
||||
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
|
||||
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
|
||||
head_dict = {"CTCHead": CTCHead}
|
||||
|
||||
|
||||
class RecModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert "in_channels" in config, "in_channels must in model config"
|
||||
backbone_type = config["backbone"].pop("type")
|
||||
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
||||
self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
|
||||
|
||||
neck_type = config["neck"].pop("type")
|
||||
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
||||
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
|
||||
|
||||
head_type = config["head"].pop("type")
|
||||
assert head_type in head_dict, f"head.type must in {head_dict}"
|
||||
self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
|
||||
|
||||
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
||||
|
||||
def load_3rd_state_dict(self, _3rd_name, _state):
|
||||
self.backbone.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.neck.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.head.load_3rd_state_dict(_3rd_name, _state)
|
||||
|
||||
def forward(self, x):
|
||||
import torch
|
||||
|
||||
x = x.to(torch.float32)
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head.ctc_encoder(x)
|
||||
return x
|
||||
@@ -0,0 +1,197 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .common import Activation
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
|
||||
):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.act = act
|
||||
self._conv = nn.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self._batch_norm = nn.BatchNorm2d(
|
||||
num_filters,
|
||||
)
|
||||
if self.act is not None:
|
||||
self._act = Activation(act_type=act, inplace=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
if self.act is not None:
|
||||
y = self._act(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
|
||||
):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale),
|
||||
)
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Module):
|
||||
def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
|
||||
)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False,
|
||||
)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True,
|
||||
)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=last_conv_stride,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
if last_pool_type == "avg":
|
||||
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
|
||||
def hardsigmoid(x):
|
||||
return F.relu6(x + 3.0, inplace=True) / 6.0
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
x = torch.mul(inputs, outputs)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,570 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional
|
||||
from torch.nn.init import ones_, trunc_normal_, zeros_
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0.0, training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = torch.tensor(1 - drop_prob)
|
||||
shape = (x.size()[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
|
||||
random_tensor = torch.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
if isinstance(act_layer, str):
|
||||
self.act = Swish()
|
||||
else:
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvMixer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
HW=(8, 25),
|
||||
local_k=(3, 3),
|
||||
):
|
||||
super().__init__()
|
||||
self.HW = HW
|
||||
self.dim = dim
|
||||
self.local_mixer = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
local_k,
|
||||
1,
|
||||
(local_k[0] // 2, local_k[1] // 2),
|
||||
groups=num_heads,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.HW[0]
|
||||
w = self.HW[1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
|
||||
x = self.local_mixer(x)
|
||||
x = x.flatten(2).transpose([0, 2, 1])
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
mixer="Global",
|
||||
HW=(8, 25),
|
||||
local_k=(7, 11),
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.HW = HW
|
||||
if HW is not None:
|
||||
H = HW[0]
|
||||
W = HW[1]
|
||||
self.N = H * W
|
||||
self.C = dim
|
||||
if mixer == "Local" and HW is not None:
|
||||
hk = local_k[0]
|
||||
wk = local_k[1]
|
||||
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
|
||||
for h in range(0, H):
|
||||
for w in range(0, W):
|
||||
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
|
||||
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
|
||||
mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
|
||||
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
|
||||
self.mask = mask[None, None, :]
|
||||
# self.mask = mask.unsqueeze([0, 1])
|
||||
self.mixer = mixer
|
||||
|
||||
def forward(self, x):
|
||||
if self.HW is not None:
|
||||
N = self.N
|
||||
C = self.C
|
||||
else:
|
||||
_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = q.matmul(k.permute((0, 1, 3, 2)))
|
||||
if self.mixer == "Local":
|
||||
attn += self.mask
|
||||
attn = functional.softmax(attn, dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mixer="Global",
|
||||
local_mixer=(7, 11),
|
||||
HW=(8, 25),
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
prenorm=True,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
if mixer == "Global" or mixer == "Local":
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
mixer=mixer,
|
||||
HW=HW,
|
||||
local_k=local_mixer,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
elif mixer == "Conv":
|
||||
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
|
||||
else:
|
||||
raise TypeError("The mixer must be one of [Global, Local, Conv]")
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.prenorm = prenorm
|
||||
|
||||
def forward(self, x):
|
||||
if self.prenorm:
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
|
||||
super().__init__()
|
||||
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
|
||||
self.img_size = img_size
|
||||
self.num_patches = num_patches
|
||||
self.embed_dim = embed_dim
|
||||
self.norm = None
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class SubSample(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
|
||||
super().__init__()
|
||||
self.types = types
|
||||
if types == "Pool":
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.proj = nn.Linear(in_channels, out_channels)
|
||||
else:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
self.norm = eval(sub_norm)(out_channels)
|
||||
if act is not None:
|
||||
self.act = act()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.types == "Pool":
|
||||
x1 = self.avgpool(x)
|
||||
x2 = self.maxpool(x)
|
||||
x = (x1 + x2) * 0.5
|
||||
out = self.proj(x.flatten(2).permute((0, 2, 1)))
|
||||
else:
|
||||
x = self.conv(x)
|
||||
out = x.flatten(2).permute((0, 2, 1))
|
||||
out = self.norm(out)
|
||||
if self.act is not None:
|
||||
out = self.act(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SVTRNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=[48, 100],
|
||||
in_channels=3,
|
||||
embed_dim=[64, 128, 256],
|
||||
depth=[3, 6, 3],
|
||||
num_heads=[2, 4, 8],
|
||||
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
|
||||
local_mixer=[[7, 11], [7, 11], [7, 11]],
|
||||
patch_merging="Conv", # Conv, Pool, None
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
last_drop=0.1,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer="nn.LayerNorm",
|
||||
sub_norm="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
out_channels=192,
|
||||
out_char_num=25,
|
||||
block_unit="Block",
|
||||
act="nn.GELU",
|
||||
last_stage=True,
|
||||
sub_num=2,
|
||||
prenorm=True,
|
||||
use_lenhead=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.embed_dim = embed_dim
|
||||
self.out_channels = out_channels
|
||||
self.prenorm = prenorm
|
||||
patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
|
||||
# self.pos_embed = self.create_parameter(
|
||||
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
|
||||
|
||||
# self.add_parameter("pos_embed", self.pos_embed)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
Block_unit = eval(block_unit)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, sum(depth))
|
||||
self.blocks1 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[0],
|
||||
num_heads=num_heads[0],
|
||||
mixer=mixer[0 : depth[0]][i],
|
||||
HW=self.HW,
|
||||
local_mixer=local_mixer[0],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[0 : depth[0]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[0])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample1 = SubSample(
|
||||
embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 2, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.patch_merging = patch_merging
|
||||
self.blocks2 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[1],
|
||||
num_heads=num_heads[1],
|
||||
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[1],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[1])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample2 = SubSample(
|
||||
embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 4, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.blocks3 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[2],
|
||||
num_heads=num_heads[2],
|
||||
mixer=mixer[depth[0] + depth[1] :][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[2],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] + depth[1] :][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[2])
|
||||
]
|
||||
)
|
||||
self.last_stage = last_stage
|
||||
if last_stage:
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
|
||||
self.last_conv = nn.Conv2d(
|
||||
in_channels=embed_dim[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=last_drop)
|
||||
if not prenorm:
|
||||
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
|
||||
self.use_lenhead = use_lenhead
|
||||
if use_lenhead:
|
||||
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
|
||||
self.hardswish_len = nn.Hardswish()
|
||||
self.dropout_len = nn.Dropout(p=last_drop)
|
||||
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks1:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
|
||||
for blk in self.blocks2:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
|
||||
for blk in self.blocks3:
|
||||
x = blk(x)
|
||||
if not self.prenorm:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.use_lenhead:
|
||||
len_x = self.len_conv(x.mean(1))
|
||||
len_x = self.dropout_len(self.hardswish_len(len_x))
|
||||
if self.last_stage:
|
||||
if self.patch_merging is not None:
|
||||
h = self.HW[0] // 4
|
||||
else:
|
||||
h = self.HW[0]
|
||||
x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
if self.use_lenhead:
|
||||
return x, len_x
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = torch.rand(1, 3, 48, 100)
|
||||
svtr = SVTRNet()
|
||||
|
||||
out = svtr(a)
|
||||
print(svtr)
|
||||
print(out.size())
|
||||
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Hswish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hswish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
# out = max(0, min(1, slop*x+offset))
|
||||
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
|
||||
class Hsigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hsigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
||||
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(GELU, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
if self.inplace:
|
||||
x.mul_(torch.sigmoid(x))
|
||||
return x
|
||||
else:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Activation(nn.Module):
|
||||
def __init__(self, act_type, inplace=True):
|
||||
super(Activation, self).__init__()
|
||||
act_type = act_type.lower()
|
||||
if act_type == "relu":
|
||||
self.act = nn.ReLU(inplace=inplace)
|
||||
elif act_type == "relu6":
|
||||
self.act = nn.ReLU6(inplace=inplace)
|
||||
elif act_type == "sigmoid":
|
||||
raise NotImplementedError
|
||||
elif act_type == "hard_sigmoid":
|
||||
self.act = Hsigmoid(inplace)
|
||||
elif act_type == "hard_swish":
|
||||
self.act = Hswish(inplace=inplace)
|
||||
elif act_type == "leakyrelu":
|
||||
self.act = nn.LeakyReLU(inplace=inplace)
|
||||
elif act_type == "gelu":
|
||||
self.act = GELU(inplace=inplace)
|
||||
elif act_type == "swish":
|
||||
self.act = Swish(inplace=inplace)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.act(inputs)
|
||||
@@ -0,0 +1,95 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
|
||||
@@ -627,6 +627,7 @@ def main(args):
|
||||
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
|
||||
perceptual_loss = lpips.LPIPS(net="vgg").eval()
|
||||
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
|
||||
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
|
||||
|
||||
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
|
||||
def unwrap_model(model):
|
||||
@@ -951,13 +952,20 @@ def main(args):
|
||||
logits_fake = discriminator(reconstructions)
|
||||
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
|
||||
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
|
||||
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
logs = {
|
||||
"disc_loss": disc_loss.detach().mean().item(),
|
||||
"disc_loss": d_loss.detach().mean().item(),
|
||||
"logits_real": logits_real.detach().mean().item(),
|
||||
"logits_fake": logits_fake.detach().mean().item(),
|
||||
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
accelerator.backward(d_loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = discriminator.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
disc_optimizer.step()
|
||||
disc_lr_scheduler.step()
|
||||
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -381,9 +381,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -164,9 +164,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Generating images using Flux and PyTorch/XLA
|
||||
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
|
||||
|
||||
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
|
||||
## Create TPU
|
||||
|
||||
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
python3 -c "import torch; import torch_xla;"
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
Clone the diffusers repo and install dependencies
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install transformers accelerate sentencepiece structlog
|
||||
pushd ../../..
|
||||
pip install .
|
||||
popd
|
||||
cd examples/research_projects/pytorch_xla/inference/flux/
|
||||
```
|
||||
|
||||
## Run the inference job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token in order to download Flux weights.
|
||||
**Gated Model**
|
||||
|
||||
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -50,51 +51,116 @@ python flux_inference.py
|
||||
|
||||
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
|
||||
|
||||
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):
|
||||
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
|
||||
|
||||
```bash
|
||||
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
|
||||
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
|
||||
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
|
||||
2025-01-10 00:51:34 [info ] starting compilation run...
|
||||
2025-01-10 00:51:35 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
|
||||
2025-01-10 00:52:53 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:58 [info ] starting inference run...
|
||||
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
|
||||
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
|
||||
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
|
||||
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
|
||||
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
|
||||
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
|
||||
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
|
||||
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
|
||||
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
|
||||
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
|
||||
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
|
||||
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
|
||||
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
|
||||
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
|
||||
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
|
||||
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
|
||||
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
|
||||
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
|
||||
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
|
||||
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
|
||||
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
|
||||
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
|
||||
2025-03-14 21:18:34 [info ] starting compilation run...
|
||||
2025-03-14 21:18:37 [info ] starting compilation run...
|
||||
2025-03-14 21:18:38 [info ] starting compilation run...
|
||||
2025-03-14 21:18:39 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:42 [info ] starting compilation run...
|
||||
2025-03-14 21:18:43 [info ] starting compilation run...
|
||||
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
|
||||
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:40 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
|
||||
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
|
||||
2025-03-14 21:36:50 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
|
||||
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
|
||||
2025-03-14 21:36:56 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
|
||||
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
|
||||
2025-03-14 21:37:08 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
|
||||
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
|
||||
2025-03-14 21:37:22 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
|
||||
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
|
||||
images = (images * 255).round().astype("uint8")
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
|
||||
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
|
||||
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
|
||||
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
|
||||
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
|
||||
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
|
||||
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
|
||||
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
```
|
||||
@@ -9,6 +9,7 @@ import torch_xla.debug.metrics as met
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla.experimental.custom_kernel import FlashAttention
|
||||
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
@@ -36,6 +37,19 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
|
||||
).to(device0)
|
||||
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
|
||||
FlashAttention.DEFAULT_BLOCK_SIZES = {
|
||||
"block_q": 1536,
|
||||
"block_k_major": 1536,
|
||||
"block_k": 1536,
|
||||
"block_b": 1536,
|
||||
"block_q_major_dkv": 1536,
|
||||
"block_k_major_dkv": 1536,
|
||||
"block_q_dkv": 1536,
|
||||
"block_k_dkv": 1536,
|
||||
"block_q_dq": 1536,
|
||||
"block_k_dq": 1536,
|
||||
"block_k_major_dq": 1536,
|
||||
}
|
||||
|
||||
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
|
||||
width = args.width
|
||||
@@ -69,14 +83,14 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
xm.set_rng_state(seed=unique_seed, device=device0)
|
||||
times = []
|
||||
logger.info("starting inference run...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
for _ in range(args.itters):
|
||||
ts = perf_counter()
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
|
||||
if args.profile:
|
||||
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
|
||||
@@ -92,7 +106,7 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
if index == 0:
|
||||
logger.info(f"inference time: {inference_time}")
|
||||
times.append(inference_time)
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
|
||||
image.save(f"/tmp/inference_out-{index}.png")
|
||||
if index == 0:
|
||||
metrics_report = met.metrics_report()
|
||||
|
||||
@@ -141,9 +141,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -53,8 +53,18 @@ args = parser.parse_args()
|
||||
# this is specific to `AdaLayerNormContinuous`:
|
||||
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
"""
|
||||
Swap the scale and shift components in the weight tensor.
|
||||
|
||||
Args:
|
||||
weight (torch.Tensor): The original weight tensor.
|
||||
dim (int): The dimension along which to split.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The modified weight tensor with scale and shift swapped.
|
||||
"""
|
||||
shift, scale = weight.chunk(2, dim=dim)
|
||||
new_weight = torch.cat([scale, shift], dim=dim)
|
||||
return new_weight
|
||||
|
||||
|
||||
@@ -200,6 +210,7 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
|
||||
@@ -25,9 +25,15 @@ import argparse
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
|
||||
from transformers import GlmModel, PreTrainedTokenizerFast
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
CogView4Transformer2DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
|
||||
|
||||
@@ -112,6 +118,12 @@ parser.add_argument(
|
||||
default=128,
|
||||
help="Maximum size for positional embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use control model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
Returns:
|
||||
dict: The converted state dictionary compatible with Diffusers.
|
||||
"""
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
||||
mega = ckpt["model"]
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Patch Embedding
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
|
||||
hidden_size, 128 if args.control else 64
|
||||
)
|
||||
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
|
||||
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
|
||||
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
|
||||
@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
|
||||
# AdaLayerNorm
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
|
||||
)
|
||||
|
||||
# QKV
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
|
||||
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
|
||||
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
|
||||
|
||||
@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
# Attention Output
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.weight"
|
||||
].T
|
||||
]
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.bias"
|
||||
]
|
||||
@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
||||
Returns:
|
||||
dict: The converted VAE state dictionary compatible with Diffusers.
|
||||
"""
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
|
||||
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
||||
|
||||
|
||||
@@ -286,7 +294,7 @@ def main(args):
|
||||
)
|
||||
transformer = CogView4Transformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
in_channels=32 if args.control else 16,
|
||||
num_layers=args.num_layers,
|
||||
attention_head_dim=args.attention_head_dim,
|
||||
num_attention_heads=args.num_heads,
|
||||
@@ -317,6 +325,7 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
@@ -331,7 +340,7 @@ def main(args):
|
||||
# Load the text encoder and tokenizer
|
||||
text_encoder_id = "THUDM/glm-4-9b-hf"
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
|
||||
text_encoder = GlmForCausalLM.from_pretrained(
|
||||
text_encoder = GlmModel.from_pretrained(
|
||||
text_encoder_id,
|
||||
cache_dir=args.text_encoder_cache_dir,
|
||||
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
|
||||
@@ -345,13 +354,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Create the pipeline
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
if args.control:
|
||||
pipe = CogView4ControlPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# Save the converted pipeline
|
||||
pipe.save_pretrained(
|
||||
|
||||
@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": None,
|
||||
},
|
||||
"HYVideo-T/2-I2V": {
|
||||
"HYVideo-T/2-I2V-33ch": {
|
||||
"in_channels": 16 * 2 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "latent_concat",
|
||||
},
|
||||
"HYVideo-T/2-I2V-16ch": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "token_replace",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
elif version == "0.9.5":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
return config
|
||||
|
||||
|
||||
@@ -223,7 +294,7 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -277,14 +348,17 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
if args.version == "0.9.5":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -115,7 +115,7 @@ def main(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
||||
|
||||
pipeline = LuminaText2ImgPipeline(
|
||||
pipeline = LuminaPipeline(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path)
|
||||
|
||||
@@ -16,7 +16,9 @@ from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
@@ -25,6 +27,10 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
|
||||
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
|
||||
"Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
@@ -72,15 +78,42 @@ def main(args):
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
# Handle different time embedding structure based on model type
|
||||
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# For Sana Sprint, the time embedding structure is different
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Guidance embedder for Sana Sprint
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
|
||||
else:
|
||||
# Original Sana time embedding structure
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
@@ -96,14 +129,22 @@ def main(args):
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
|
||||
layer_num = 28
|
||||
elif args.model_type == "SanaMS_4800M_P1_D60":
|
||||
layer_num = 60
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
|
||||
qk_norm = (
|
||||
"rms_norm_across_heads"
|
||||
if args.model_type
|
||||
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
|
||||
else None
|
||||
)
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -117,6 +158,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
@@ -154,6 +203,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
@@ -169,24 +226,37 @@ def main(args):
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer = SanaTransformer2DModel(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
|
||||
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
|
||||
num_layers=model_kwargs[args.model_type]["num_layers"],
|
||||
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
caption_channels=2304,
|
||||
mlp_ratio=2.5,
|
||||
attention_bias=False,
|
||||
sample_size=args.image_size // 32,
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
transformer_kwargs = {
|
||||
"in_channels": 32,
|
||||
"out_channels": 32,
|
||||
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
|
||||
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
|
||||
"num_layers": model_kwargs[args.model_type]["num_layers"],
|
||||
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 2.5,
|
||||
"attention_bias": False,
|
||||
"sample_size": args.image_size // 32,
|
||||
"patch_size": 1,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"interpolation_scale": interpolation_scale[args.image_size],
|
||||
}
|
||||
|
||||
# Add qk_norm parameter for Sana Sprint
|
||||
if args.model_type in [
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
]:
|
||||
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
transformer_kwargs["guidance_embeds"] = True
|
||||
|
||||
transformer = SanaTransformer2DModel(**transformer_kwargs)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(transformer, converted_state_dict)
|
||||
@@ -196,6 +266,8 @@ def main(args):
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
@@ -210,47 +282,74 @@ def main(args):
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole SanaPipeline",
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "google/gemma-2-2b-it"
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
|
||||
if args.scheduler_type != "scm":
|
||||
print(
|
||||
colored(
|
||||
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
# SCM Scheduler for Sana Sprint
|
||||
scheduler_config = {
|
||||
"prediction_type": "trigflow",
|
||||
"sigma_data": 0.5,
|
||||
}
|
||||
scheduler = SCMScheduler(**scheduler_config)
|
||||
pipe = SanaSprintPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
@@ -259,12 +358,6 @@ DTYPE_MAPPING = {
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -281,10 +374,24 @@ if __name__ == "__main__":
|
||||
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
"--model_type",
|
||||
default="SanaMS_1600M_P1_D20",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaMS_1600M_P1_D20",
|
||||
"SanaMS_600M_P1_D28",
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "scm"],
|
||||
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
|
||||
@@ -309,10 +416,41 @@ if __name__ == "__main__":
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaMS1.5_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
"SanaMS1.5_4800M_P1_D60": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 60,
|
||||
},
|
||||
"SanaSprint_600M_P1_D28": {
|
||||
"num_attention_heads": 36,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 16,
|
||||
"cross_attention_head_dim": 72,
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaSprint_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -128,6 +128,10 @@ _deps = [
|
||||
"GitPython<3.1.19",
|
||||
"scipy",
|
||||
"onnx",
|
||||
"optimum_quanto>=0.2.6",
|
||||
"gguf>=0.10.0",
|
||||
"torchao>=0.7.0",
|
||||
"bitsandbytes>=0.43.3",
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
@@ -235,6 +239,11 @@ extras["test"] = deps_list(
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
|
||||
extras["gguf"] = deps_list("gguf", "accelerate")
|
||||
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
|
||||
extras["torchao"] = deps_list("torchao", "accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
else:
|
||||
|
||||
+109
-3
@@ -6,14 +6,19 @@ from .utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_optimum_quanto_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -32,7 +37,7 @@ _import_structure = {
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
|
||||
"quantizers.quantization_config": [],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
@@ -54,6 +59,54 @@ _import_structure = {
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_bitsandbytes_objects
|
||||
|
||||
_import_structure["utils.dummy_bitsandbytes_objects"] = [
|
||||
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_gguf_objects
|
||||
|
||||
_import_structure["utils.dummy_gguf_objects"] = [
|
||||
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torchao_objects
|
||||
|
||||
_import_structure["utils.dummy_torchao_objects"] = [
|
||||
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_optimum_quanto_objects
|
||||
|
||||
_import_structure["utils.dummy_optimum_quanto_objects"] = [
|
||||
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -78,8 +131,10 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -218,6 +273,7 @@ else:
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"SCMScheduler",
|
||||
"ScoreSdeVeScheduler",
|
||||
"TCDScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -292,6 +348,7 @@ else:
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
@@ -348,9 +405,12 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldIntrinsicsPipeline",
|
||||
@@ -366,6 +426,7 @@ else:
|
||||
"ReduxImageEncoder",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -448,6 +509,7 @@ else:
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVideoToVideoPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
@@ -599,7 +661,38 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_bitsandbytes_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig
|
||||
|
||||
try:
|
||||
if not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_gguf_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import GGUFQuantizationConfig
|
||||
|
||||
try:
|
||||
if not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torchao_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_optimum_quanto_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import QuantoConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
@@ -615,7 +708,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
HookRegistry,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -748,6 +847,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
SCMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
TCDScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -803,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
@@ -859,9 +960,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
@@ -877,6 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
@@ -958,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVideoToVideoPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
|
||||
@@ -35,6 +35,10 @@ deps = {
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"scipy": "scipy",
|
||||
"onnx": "onnx",
|
||||
"optimum_quanto": "optimum_quanto>=0.2.6",
|
||||
"gguf": "gguf>=0.10.0",
|
||||
"torchao": "torchao>=0.7.0",
|
||||
"bitsandbytes": "bitsandbytes>=0.43.3",
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
|
||||
@@ -2,6 +2,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..models.modeling_outputs import Transformer2DModelOutput
|
||||
from ..utils import logging
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
||||
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
||||
"^blocks.*attn",
|
||||
"^transformer_blocks.*attn",
|
||||
"^single_transformer_blocks.*attn",
|
||||
)
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
|
||||
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
|
||||
"hidden_states",
|
||||
"encoder_hidden_states",
|
||||
"timestep",
|
||||
"attention_mask",
|
||||
"encoder_attention_mask",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FasterCacheConfig:
|
||||
r"""
|
||||
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
|
||||
|
||||
Attributes:
|
||||
spatial_attention_block_skip_range (`int`, defaults to `2`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
||||
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
||||
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
||||
process.
|
||||
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
||||
The timestep range within which the temporal attention computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0).
|
||||
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
||||
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
alpha_low_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
|
||||
the conditional branch outputs.
|
||||
alpha_high_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
|
||||
from the conditional branch outputs.
|
||||
unconditional_batch_skip_range (`int`, defaults to `5`):
|
||||
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
||||
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
|
||||
computing the new unconditional branch states again.
|
||||
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
The timestep range within which the unconditional branch computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the attention outputs by. This function should take
|
||||
the attention module as input and return a float value. This is used to approximate the unconditional
|
||||
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
|
||||
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
|
||||
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
|
||||
the number of inference steps and underlying model behaviour as denoising progresses.
|
||||
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
tensor_format (`str`, defaults to `"BCFHW"`):
|
||||
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
|
||||
used to split individual latent frames in order for low and high frequency components to be computed.
|
||||
is_guidance_distilled (`bool`, defaults to `False`):
|
||||
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
||||
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
||||
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
||||
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
||||
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
||||
names that contain the batchwise-concatenated unconditional and conditional inputs.
|
||||
"""
|
||||
|
||||
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
|
||||
# after some testing. We default to 2 if these parameters are not provided.
|
||||
spatial_attention_block_skip_range: int = 2
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
|
||||
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
||||
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
||||
|
||||
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
||||
alpha_low_frequency: float = 1.1
|
||||
alpha_high_frequency: float = 1.1
|
||||
|
||||
# n as described in CFG-Cache explanation in the paper - dependant on the model
|
||||
unconditional_batch_skip_range: int = 5
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
|
||||
tensor_format: str = "BCFHW"
|
||||
is_guidance_distilled: bool = False
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"FasterCacheConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
||||
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
||||
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
|
||||
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
|
||||
f" alpha_low_frequency={self.alpha_low_frequency},\n"
|
||||
f" alpha_high_frequency={self.alpha_high_frequency},\n"
|
||||
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
|
||||
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
|
||||
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
||||
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
||||
f" tensor_format={self.tensor_format},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class FasterCacheDenoiserState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.low_frequency_delta: torch.Tensor = None
|
||||
self.high_frequency_delta: torch.Tensor = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.low_frequency_delta = None
|
||||
self.high_frequency_delta = None
|
||||
|
||||
|
||||
class FasterCacheBlockState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
|
||||
applied to will have an instance of this state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.batch_size: int = None
|
||||
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.batch_size = None
|
||||
self.cache = None
|
||||
|
||||
|
||||
class FasterCacheDenoiserHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unconditional_batch_skip_range: int,
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
||||
tensor_format: str,
|
||||
is_guidance_distilled: bool,
|
||||
uncond_cond_input_kwargs_identifiers: List[str],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.unconditional_batch_skip_range = unconditional_batch_skip_range
|
||||
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
|
||||
# We can't easily detect what args are to be split in unconditional and conditional branches. We
|
||||
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
|
||||
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
|
||||
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
|
||||
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
|
||||
self.tensor_format = tensor_format
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
self.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
self.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheDenoiserState()
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
||||
# followed by conditional inputs.
|
||||
_, cond = input.chunk(2, dim=0)
|
||||
return cond
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
|
||||
# requirements for skipping the unconditional branch are met as described in the paper.
|
||||
# We skip the unconditional branch only if the following conditions are met:
|
||||
# 1. We have completed at least one iteration of the denoiser
|
||||
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
|
||||
# where approximating the unconditional branch from the computation of the conditional branch is possible
|
||||
# without a significant loss in quality.
|
||||
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
|
||||
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
|
||||
is_within_timestep_range = (
|
||||
self.unconditional_batch_timestep_skip_range[0]
|
||||
< self.current_timestep_callback()
|
||||
< self.unconditional_batch_timestep_skip_range[1]
|
||||
)
|
||||
should_skip_uncond = (
|
||||
self.state.iteration > 0
|
||||
and is_within_timestep_range
|
||||
and self.state.iteration % self.unconditional_batch_skip_range != 0
|
||||
and not self.is_guidance_distilled
|
||||
)
|
||||
|
||||
if should_skip_uncond:
|
||||
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
|
||||
if is_any_kwarg_uncond:
|
||||
logger.debug("FasterCache - Skipping unconditional branch computation")
|
||||
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
|
||||
kwargs = {
|
||||
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_guidance_distilled:
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
if torch.is_tensor(output):
|
||||
hidden_states = output
|
||||
elif isinstance(output, (tuple, Transformer2DModelOutput)):
|
||||
hidden_states = output[0]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if should_skip_uncond:
|
||||
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
|
||||
if self.tensor_format == "BCFHW":
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
|
||||
|
||||
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
|
||||
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
|
||||
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
|
||||
uncond_freq = low_freq_uncond + high_freq_uncond
|
||||
|
||||
uncond_states = torch.fft.ifftshift(uncond_freq)
|
||||
uncond_states = torch.fft.ifft2(uncond_states).real
|
||||
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# Concatenate the approximated unconditional and predicted conditional branches
|
||||
uncond_states = uncond_states.to(hidden_states.dtype)
|
||||
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
|
||||
else:
|
||||
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
cond_states = cond_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.flatten(0, 1)
|
||||
cond_states = cond_states.flatten(0, 1)
|
||||
|
||||
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
|
||||
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
|
||||
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
|
||||
|
||||
self.state.iteration += 1
|
||||
if torch.is_tensor(output):
|
||||
output = hidden_states
|
||||
elif isinstance(output, tuple):
|
||||
output = (hidden_states, *output[1:])
|
||||
else:
|
||||
output.sample = hidden_states
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
class FasterCacheBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
is_guidance_distilled: bool,
|
||||
weight_callback: Callable[[torch.nn.Module], float],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.block_skip_range = block_skip_range
|
||||
self.timestep_skip_range = timestep_skip_range
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.weight_callback = weight_callback
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheBlockState()
|
||||
return module
|
||||
|
||||
def _compute_approximated_attention_output(
|
||||
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
|
||||
) -> torch.Tensor:
|
||||
if t_2_output.size(0) != batch_size:
|
||||
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_2_output.size(0) == 2 * batch_size
|
||||
t_2_output = t_2_output[batch_size:]
|
||||
if t_output.size(0) != batch_size:
|
||||
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_output.size(0) == 2 * batch_size
|
||||
t_output = t_output[batch_size:]
|
||||
return t_output + (t_output - t_2_output) * weight
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
batch_size = [
|
||||
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
|
||||
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
|
||||
][0]
|
||||
if self.state.batch_size is None:
|
||||
# Will be updated on first forward pass through the denoiser
|
||||
self.state.batch_size = batch_size
|
||||
|
||||
# If we have to skip due to the skip conditions, then let's skip as expected.
|
||||
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
|
||||
# is because the expected output shapes of attention layer will not match if we only return values from
|
||||
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
|
||||
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
|
||||
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
|
||||
is_within_timestep_range = (
|
||||
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
||||
)
|
||||
if not is_within_timestep_range:
|
||||
should_skip_attention = False
|
||||
else:
|
||||
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
|
||||
should_skip_attention = not should_compute_attention
|
||||
if should_skip_attention:
|
||||
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
|
||||
|
||||
if should_skip_attention:
|
||||
logger.debug("FasterCache - Skipping attention and using approximation")
|
||||
if torch.is_tensor(self.state.cache[-1]):
|
||||
t_2_output, t_output = self.state.cache
|
||||
weight = self.weight_callback(module)
|
||||
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
|
||||
else:
|
||||
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
|
||||
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
|
||||
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
|
||||
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
|
||||
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
|
||||
# allows us to compute the approximated attention output for each tensor in the cache.
|
||||
output = ()
|
||||
for t_2_output, t_output in zip(*self.state.cache):
|
||||
result = self._compute_approximated_attention_output(
|
||||
t_2_output, t_output, self.weight_callback(module), batch_size
|
||||
)
|
||||
output += (result,)
|
||||
else:
|
||||
logger.debug("FasterCache - Computing attention")
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
|
||||
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
|
||||
# both cases.
|
||||
if torch.is_tensor(output):
|
||||
cache_output = output
|
||||
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
|
||||
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
|
||||
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
|
||||
cache_output = cache_output.chunk(2, dim=0)[1]
|
||||
else:
|
||||
# Cache all return values and perform the same operation as above
|
||||
cache_output = ()
|
||||
for out in output:
|
||||
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
|
||||
out = out.chunk(2, dim=0)[1]
|
||||
cache_output += (out,)
|
||||
|
||||
if self.state.cache is None:
|
||||
self.state.cache = [cache_output, cache_output]
|
||||
else:
|
||||
self.state.cache = [self.state.cache[-1], cache_output]
|
||||
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
r"""
|
||||
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
||||
|
||||
Args:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
The diffusion pipeline to apply FasterCache to.
|
||||
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
|
||||
The configuration to use for FasterCache.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = FasterCacheConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(-1, 681),
|
||||
... low_frequency_weight_update_timestep_range=(99, 641),
|
||||
... high_frequency_weight_update_timestep_range=(-1, 301),
|
||||
... spatial_attention_block_identifiers=["transformer_blocks"],
|
||||
... attention_weight_callback=lambda _: 0.3,
|
||||
... tensor_format="BFCHW",
|
||||
... )
|
||||
>>> apply_faster_cache(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
|
||||
logger.warning(
|
||||
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
|
||||
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
|
||||
"https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
if config.attention_weight_callback is None:
|
||||
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
|
||||
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
|
||||
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
|
||||
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
|
||||
logger.warning(
|
||||
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
|
||||
)
|
||||
config.attention_weight_callback = lambda _: 0.5
|
||||
|
||||
if config.low_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.low_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.low_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_low_frequency if is_within_range else 1.0
|
||||
|
||||
config.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
|
||||
if config.high_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.high_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.high_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_high_frequency if is_within_range else 1.0
|
||||
|
||||
config.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
|
||||
if config.tensor_format not in supported_tensor_formats:
|
||||
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
|
||||
|
||||
_apply_faster_cache_on_denoiser(module, config)
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES):
|
||||
continue
|
||||
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
|
||||
_apply_faster_cache_on_attention_class(name, submodule, config)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
hook = FasterCacheDenoiserHook(
|
||||
config.unconditional_batch_skip_range,
|
||||
config.unconditional_batch_timestep_skip_range,
|
||||
config.tensor_format,
|
||||
config.is_guidance_distilled,
|
||||
config._unconditional_conditional_input_kwargs_identifiers,
|
||||
config.current_timestep_callback,
|
||||
config.low_frequency_weight_callback,
|
||||
config.high_frequency_weight_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
|
||||
is_spatial_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range, block_type = None, None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
block_type = "spatial"
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
block_type = "temporal"
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.debug(
|
||||
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
|
||||
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
|
||||
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
||||
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
|
||||
f"function to apply FasterCache to this layer."
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
|
||||
hook = FasterCacheBlockHook(
|
||||
block_skip_range,
|
||||
timestep_skip_range,
|
||||
config.is_guidance_distilled,
|
||||
config.attention_weight_callback,
|
||||
config.current_timestep_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
|
||||
|
||||
|
||||
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
|
||||
@torch.no_grad()
|
||||
def _split_low_high_freq(x):
|
||||
fft = torch.fft.fft2(x)
|
||||
fft_shifted = torch.fft.fftshift(fft)
|
||||
height, width = x.shape[-2:]
|
||||
radius = min(height, width) // 5
|
||||
|
||||
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
|
||||
center_x, center_y = width // 2, height // 2
|
||||
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
|
||||
|
||||
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
|
||||
high_freq_mask = ~low_freq_mask
|
||||
|
||||
low_freq_fft = fft_shifted * low_freq_mask
|
||||
high_freq_fft = fft_shifted * high_freq_mask
|
||||
|
||||
return low_freq_fft, high_freq_fft
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
@@ -56,7 +56,7 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
low_cpu_mem_usage=False,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -64,15 +64,50 @@ class ModuleGroup:
|
||||
self.onload_device = onload_device
|
||||
self.offload_leader = offload_leader
|
||||
self.onload_leader = onload_leader
|
||||
self.parameters = parameters
|
||||
self.buffers = buffers
|
||||
self.parameters = parameters or []
|
||||
self.buffers = buffers or []
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
return cpu_param_dict
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@contextmanager
|
||||
def _pinned_memory_tensors(self):
|
||||
pinned_dict = {}
|
||||
try:
|
||||
for param, tensor in self.cpu_param_dict.items():
|
||||
if not tensor.is_pinned():
|
||||
pinned_dict[param] = tensor.pin_memory()
|
||||
else:
|
||||
pinned_dict[param] = tensor
|
||||
|
||||
yield pinned_dict
|
||||
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
@@ -82,12 +117,30 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
@@ -98,15 +151,18 @@ class ModuleGroup:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
@@ -172,6 +228,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
||||
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
||||
# layers are executed during the forward pass.
|
||||
@@ -183,14 +246,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
||||
|
||||
if group_offloading_hook is not None:
|
||||
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# For the first forward pass, we have to load in a blocking manner
|
||||
group_offloading_hook.group.non_blocking = False
|
||||
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
||||
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
||||
self._layer_execution_tracker_module_names.add(name)
|
||||
@@ -220,6 +277,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the layer execution tracker hooks from the submodules
|
||||
base_module_registry = module._diffusers_hook
|
||||
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
|
||||
for i in range(num_executed):
|
||||
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
||||
@@ -227,8 +285,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
||||
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
||||
|
||||
# Apply lazy prefetching by setting required attributes
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
||||
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
||||
# see the benefits of prefetching.
|
||||
for hook in group_offloading_hooks:
|
||||
hook.group.non_blocking = True
|
||||
|
||||
# Set required attributes for prefetching
|
||||
if num_executed > 0:
|
||||
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
||||
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
||||
@@ -268,6 +331,7 @@ def apply_group_offloading(
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
@@ -314,6 +378,10 @@ def apply_group_offloading(
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -349,10 +417,12 @@ def apply_group_offloading(
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
_apply_group_offloading_leaf_level(
|
||||
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
|
||||
@@ -364,6 +434,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
@@ -384,13 +455,6 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
unmatched_modules = []
|
||||
@@ -411,7 +475,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +512,6 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -461,6 +524,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
@@ -483,13 +547,6 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -503,7 +560,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +605,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +624,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PyramidAttentionBroadcastConfig("
|
||||
f"PyramidAttentionBroadcastConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
|
||||
@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
|
||||
return module
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
module: torch.nn.Module,
|
||||
config: PyramidAttentionBroadcastConfig,
|
||||
):
|
||||
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
||||
|
||||
@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
||||
registry.register_hook(hook, "pyramid_attention_broadcast")
|
||||
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
|
||||
|
||||
@@ -70,6 +70,7 @@ if is_torch_available():
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"CogView4LoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
@@ -103,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
|
||||
@@ -804,9 +804,7 @@ class SD3IPAdapterMixin:
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
|
||||
@@ -339,93 +339,97 @@ def _load_lora_into_text_encoder(
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
|
||||
@@ -1348,3 +1348,56 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
|
||||
for i in range(num_blocks):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -42,6 +42,7 @@ from .lora_conversion_utils import (
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
@@ -298,19 +299,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(
|
||||
@@ -455,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
|
||||
@@ -476,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -559,31 +560,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
@@ -738,19 +734,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -843,11 +835,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
@@ -904,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
@@ -925,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1085,43 +1081,33 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
|
||||
if len(transformer_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
@@ -1210,10 +1196,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
@@ -1262,7 +1249,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
if text_encoder_2_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
@@ -1272,6 +1258,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
||||
@@ -1312,9 +1299,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -1328,12 +1320,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1539,18 +1531,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_lora_state_dict = {
|
||||
k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
|
||||
k: state_dict.get(k)
|
||||
for k in list(state_dict.keys())
|
||||
if k.startswith(f"{self.transformer_name}.") and "lora" in k
|
||||
}
|
||||
transformer_norm_state_dict = {
|
||||
k: state_dict.pop(k)
|
||||
for k in list(state_dict.keys())
|
||||
if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
|
||||
if k.startswith(f"{self.transformer_name}.")
|
||||
and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
|
||||
}
|
||||
|
||||
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
|
||||
transformer, transformer_lora_state_dict, transformer_norm_state_dict
|
||||
)
|
||||
has_param_with_expanded_shape = False
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
|
||||
transformer, transformer_lora_state_dict, transformer_norm_state_dict
|
||||
)
|
||||
|
||||
if has_param_with_expanded_shape:
|
||||
logger.info(
|
||||
@@ -1558,19 +1555,21 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
|
||||
"To get a comprehensive list of parameter names that were modified, enable debug logging."
|
||||
)
|
||||
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
|
||||
transformer=transformer, lora_state_dict=transformer_lora_state_dict
|
||||
)
|
||||
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
transformer_lora_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
|
||||
transformer=transformer, lora_state_dict=transformer_lora_state_dict
|
||||
)
|
||||
for k in transformer_lora_state_dict:
|
||||
state_dict.update({k: transformer_lora_state_dict[k]})
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
if len(transformer_norm_state_dict) > 0:
|
||||
transformer._transformer_norm_layers = self._load_norm_into_transformer(
|
||||
@@ -1579,18 +1578,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
discard_original_layers=False,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
@@ -1623,17 +1620,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_norm_into_transformer(
|
||||
@@ -1847,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
@@ -1868,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
|
||||
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
||||
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
|
||||
def unload_lora_weights(self, reset_to_overwritten_params=False):
|
||||
@@ -2172,17 +2170,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -2570,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
@@ -2588,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -2833,6 +2832,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -2873,9 +2873,14 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -2891,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3136,6 +3141,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3176,9 +3182,14 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3194,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3439,6 +3450,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3479,9 +3491,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3497,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3745,6 +3762,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3785,9 +3803,14 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3803,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4093,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
@@ -4112,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4123,6 +4150,346 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
if any(k.startswith("diffusion_model.") for k in state_dict):
|
||||
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def _maybe_expand_t2v_lora_for_i2v(
|
||||
cls,
|
||||
transformer: torch.nn.Module,
|
||||
state_dict,
|
||||
):
|
||||
if transformer.config.image_dim is None:
|
||||
return state_dict
|
||||
|
||||
if any(k.startswith("transformer.blocks.") for k in state_dict):
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
|
||||
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
|
||||
|
||||
if is_i2v_lora:
|
||||
return state_dict
|
||||
|
||||
for i in range(num_blocks):
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
|
||||
)
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
|
||||
state_dict = self._maybe_expand_t2v_lora_for_i2v(
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`WanTransformer3DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
@@ -4273,7 +4640,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
@@ -4285,7 +4652,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`WanTransformer3DModel`):
|
||||
transformer (`CogView4Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
@@ -4398,7 +4765,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -4417,7 +4788,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
|
||||
@@ -54,6 +54,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -192,11 +193,6 @@ class PeftAdapterMixin:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
try:
|
||||
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
|
||||
except ImportError:
|
||||
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -240,10 +236,7 @@ class PeftAdapterMixin:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
|
||||
if prefix is not None:
|
||||
keys = list(state_dict.keys())
|
||||
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
|
||||
if len(model_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
@@ -261,22 +254,16 @@ class PeftAdapterMixin:
|
||||
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
|
||||
# Bias layers in LoRA only have a single dimension
|
||||
if "lora_B" in key and val.ndim > 1:
|
||||
# Support to handle cases where layer patterns are treated as full layer names
|
||||
# was added later in PEFT. So, we handle it accordingly.
|
||||
# TODO: when we fix the minimal PEFT version for Diffusers,
|
||||
# we should remove `_maybe_adjust_config()`.
|
||||
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
|
||||
else:
|
||||
rank[key] = val.shape[1]
|
||||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
@@ -320,6 +307,9 @@ class PeftAdapterMixin:
|
||||
try:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
|
||||
if not self._hf_peft_config_loaded:
|
||||
self._hf_peft_config_loaded = True
|
||||
except Exception as e:
|
||||
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
|
||||
if hasattr(self, "peft_config"):
|
||||
@@ -366,6 +356,15 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
def save_lora_adapter(
|
||||
self,
|
||||
save_directory,
|
||||
|
||||
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
|
||||
@@ -37,8 +37,11 @@ from .single_file_utils import (
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
@@ -117,6 +120,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -240,12 +255,12 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
@@ -267,6 +282,7 @@ class FromOriginalModelMixin:
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
@@ -117,6 +117,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"sana": [
|
||||
"blocks.0.cross_attn.q_linear.weight",
|
||||
"blocks.0.cross_attn.q_linear.bias",
|
||||
"blocks.0.cross_attn.kv_linear.weight",
|
||||
"blocks.0.cross_attn.kv_linear.bias",
|
||||
],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -176,6 +184,10 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
|
||||
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
|
||||
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -664,6 +676,24 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
|
||||
model_type = "lumina2"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
|
||||
model_type = "sana"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
|
||||
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
|
||||
target_key = "model.diffusion_model.patch_embedding.weight"
|
||||
else:
|
||||
target_key = "patch_embedding.weight"
|
||||
|
||||
if checkpoint[target_key].shape[0] == 1536:
|
||||
model_type = "wan-t2v-1.3B"
|
||||
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "wan-i2v-14B"
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -2470,7 +2500,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict = {}
|
||||
converted_state_dict = {}
|
||||
|
||||
# Comfy checkpoints add this prefix
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -2479,22 +2509,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
@@ -2503,68 +2533,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
else:
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
else:
|
||||
converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
|
||||
converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.bias"
|
||||
)
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
|
||||
converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
|
||||
old_prefix + "mlp_y.w2.weight"
|
||||
)
|
||||
|
||||
# Output layers
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
@@ -2859,3 +2905,357 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict[diffusers_key] = checkpoint.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
|
||||
|
||||
# Positional and patch embeddings.
|
||||
checkpoint.pop("pos_embed")
|
||||
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Timestep embeddings.
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
|
||||
|
||||
# Caption Projection.
|
||||
checkpoint.pop("y_embedder.y_embedding")
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
|
||||
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
|
||||
|
||||
for i in range(num_layers):
|
||||
converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
|
||||
f"blocks.{i}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Self-Attention
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Cross-Attention
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.bias"
|
||||
)
|
||||
|
||||
linear_sample_k, linear_sample_v = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
|
||||
)
|
||||
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# MLP
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
|
||||
# Final layer
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"cross_attn": "attn2",
|
||||
"self_attn": "attn1",
|
||||
".o.": ".to_out.0.",
|
||||
".q.": ".to_q.",
|
||||
".k.": ".to_k.",
|
||||
".v.": ".to_v.",
|
||||
".k_img.": ".add_k_proj.",
|
||||
".v_img.": ".add_v_proj.",
|
||||
".norm_k_img.": ".norm_added_k.",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# For the I2V model
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
}
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
converted_state_dict[new_key] = checkpoint.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
# Create mappings for specific components
|
||||
middle_key_mapping = {
|
||||
# Encoder middle block
|
||||
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
|
||||
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
|
||||
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
|
||||
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
|
||||
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
|
||||
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
|
||||
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
|
||||
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
|
||||
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
|
||||
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
|
||||
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
|
||||
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
|
||||
# Decoder middle block
|
||||
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
|
||||
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
|
||||
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
|
||||
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
|
||||
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
|
||||
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
|
||||
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
|
||||
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
|
||||
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
|
||||
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
|
||||
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
|
||||
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for attention blocks
|
||||
attention_mapping = {
|
||||
# Encoder middle attention
|
||||
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
|
||||
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
|
||||
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
|
||||
# Decoder middle attention
|
||||
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
|
||||
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
|
||||
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
|
||||
}
|
||||
|
||||
# Create a mapping for the head components
|
||||
head_mapping = {
|
||||
# Encoder head
|
||||
"encoder.head.0.gamma": "encoder.norm_out.gamma",
|
||||
"encoder.head.2.bias": "encoder.conv_out.bias",
|
||||
"encoder.head.2.weight": "encoder.conv_out.weight",
|
||||
# Decoder head
|
||||
"decoder.head.0.gamma": "decoder.norm_out.gamma",
|
||||
"decoder.head.2.bias": "decoder.conv_out.bias",
|
||||
"decoder.head.2.weight": "decoder.conv_out.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for the quant components
|
||||
quant_mapping = {
|
||||
"conv1.weight": "quant_conv.weight",
|
||||
"conv1.bias": "quant_conv.bias",
|
||||
"conv2.weight": "post_quant_conv.weight",
|
||||
"conv2.bias": "post_quant_conv.bias",
|
||||
}
|
||||
|
||||
# Process each key in the state dict
|
||||
for key, value in checkpoint.items():
|
||||
# Handle middle block keys using the mapping
|
||||
if key in middle_key_mapping:
|
||||
new_key = middle_key_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle attention blocks using the mapping
|
||||
elif key in attention_mapping:
|
||||
new_key = attention_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle head keys using the mapping
|
||||
elif key in head_mapping:
|
||||
new_key = head_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle quant keys using the mapping
|
||||
elif key in quant_mapping:
|
||||
new_key = quant_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle encoder conv1
|
||||
elif key == "encoder.conv1.weight":
|
||||
converted_state_dict["encoder.conv_in.weight"] = value
|
||||
elif key == "encoder.conv1.bias":
|
||||
converted_state_dict["encoder.conv_in.bias"] = value
|
||||
# Handle decoder conv1
|
||||
elif key == "decoder.conv1.weight":
|
||||
converted_state_dict["decoder.conv_in.weight"] = value
|
||||
elif key == "decoder.conv1.bias":
|
||||
converted_state_dict["decoder.conv_in.bias"] = value
|
||||
# Handle encoder downsamples
|
||||
elif key.startswith("encoder.downsamples."):
|
||||
# Convert to down_blocks
|
||||
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
|
||||
|
||||
# Convert residual block naming but keep the original structure
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
elif ".residual.2.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
|
||||
elif ".residual.2.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
|
||||
elif ".residual.3.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
|
||||
elif ".residual.6.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
|
||||
elif ".residual.6.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
elif ".shortcut.weight" in new_key:
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle decoder upsamples
|
||||
elif key.startswith("decoder.upsamples."):
|
||||
# Convert to up_blocks
|
||||
parts = key.split(".")
|
||||
block_idx = int(parts[2])
|
||||
|
||||
# Group residual blocks
|
||||
if "residual" in key:
|
||||
if block_idx in [0, 1, 2]:
|
||||
new_block_idx = 0
|
||||
resnet_idx = block_idx
|
||||
elif block_idx in [4, 5, 6]:
|
||||
new_block_idx = 1
|
||||
resnet_idx = block_idx - 4
|
||||
elif block_idx in [8, 9, 10]:
|
||||
new_block_idx = 2
|
||||
resnet_idx = block_idx - 8
|
||||
elif block_idx in [12, 13, 14]:
|
||||
new_block_idx = 3
|
||||
resnet_idx = block_idx - 12
|
||||
else:
|
||||
# Keep as is for other blocks
|
||||
converted_state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Convert residual block naming
|
||||
if ".residual.0.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
|
||||
elif ".residual.2.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
|
||||
elif ".residual.2.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
|
||||
elif ".residual.3.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
|
||||
elif ".residual.6.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
|
||||
elif ".residual.6.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle shortcut connections
|
||||
elif ".shortcut." in key:
|
||||
if block_idx == 4:
|
||||
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
|
||||
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle upsamplers
|
||||
elif ".resample." in key or ".time_conv." in key:
|
||||
if block_idx == 3:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
|
||||
elif block_idx == 7:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
|
||||
elif block_idx == 11:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
converted_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep other keys unchanged
|
||||
converted_state_dict[key] = value
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -449,9 +449,9 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
# 7.5 Offload the model again
|
||||
if is_model_cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
self.enable_model_cpu_offload(device=device)
|
||||
elif is_sequential_cpu_offload:
|
||||
self.enable_sequential_cpu_offload()
|
||||
self.enable_sequential_cpu_offload(device=device)
|
||||
|
||||
# / Unsafe Code >
|
||||
|
||||
|
||||
@@ -284,8 +284,9 @@ class Attention(nn.Module):
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# Wanx applies qk norm across all heads
|
||||
self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
|
||||
# Wan applies qk norm across all heads
|
||||
# Wan also doesn't apply a q norm
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -740,10 +741,14 @@ class Attention(nn.Module):
|
||||
|
||||
if out_dim == 3:
|
||||
if attention_mask.shape[0] < batch_size * head_size:
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
|
||||
)
|
||||
elif out_dim == 4:
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
|
||||
)
|
||||
|
||||
return attention_mask
|
||||
|
||||
@@ -2334,7 +2339,9 @@ class FluxAttnProcessor2_0:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -3703,8 +3710,10 @@ class StableAudioAttnProcessor2_0:
|
||||
if kv_heads != attn.heads:
|
||||
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
||||
heads_per_kv_head = attn.heads // kv_heads
|
||||
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
||||
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
||||
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
|
||||
value = torch.repeat_interleave(
|
||||
value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
|
||||
)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
@@ -6011,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
|
||||
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
|
||||
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
|
||||
|
||||
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
@@ -361,7 +361,9 @@ class Decoder(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
||||
x = hidden_states.repeat_interleave(
|
||||
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||
)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
|
||||
if self.down_sample:
|
||||
identity = hidden_states[:, :, ::2]
|
||||
elif self.up_sample:
|
||||
identity = hidden_states.repeat_interleave(2, dim=2)
|
||||
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
|
||||
else:
|
||||
identity = hidden_states
|
||||
|
||||
|
||||
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDownsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
|
||||
|
||||
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
|
||||
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
is_causal=is_causal,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
|
||||
|
||||
residual = (
|
||||
hidden_states.unflatten(4, (-1, self.stride[2]))
|
||||
.unflatten(3, (-1, self.stride[1]))
|
||||
.unflatten(2, (-1, self.stride[0]))
|
||||
)
|
||||
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
|
||||
residual = residual.unflatten(1, (-1, self.group_size))
|
||||
residual = residual.mean(dim=2)
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states.unflatten(4, (-1, self.stride[2]))
|
||||
.unflatten(3, (-1, self.stride[1]))
|
||||
.unflatten(2, (-1, self.stride[0]))
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
is_causal=is_causal,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideo095DownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
spatio_temporal_scale (`bool`, defaults to `True`):
|
||||
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
||||
Whether or not to downsample across temporal dimension.
|
||||
is_causal (`bool`, defaults to `True`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
downsample_type: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList()
|
||||
|
||||
if downsample_type == "conv":
|
||||
self.downsamplers.append(
|
||||
LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
)
|
||||
)
|
||||
elif downsample_type == "spatial":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
elif downsample_type == "temporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
elif downsample_type == "spatiotemporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
r"""
|
||||
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
)
|
||||
|
||||
# down blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
|
||||
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
if not is_ltx_095:
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
else:
|
||||
output_channel = block_out_channels[i + 1]
|
||||
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if down_block_types[i] == "LTXVideoDownBlock3D":
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
elif down_block_types[i] == "LTXVideo095DownBlock3D":
|
||||
down_block = LTXVideo095DownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
downsample_type=downsample_type[i],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
self.timestep_scale_multiplier = None
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.timestep_scale_multiplier is not None:
|
||||
temb = temb * self.timestep_scale_multiplier
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
|
||||
|
||||
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
scaling_factor: float = 1.0,
|
||||
encoder_causal: bool = True,
|
||||
decoder_causal: bool = False,
|
||||
spatial_compression_ratio: int = None,
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
down_block_types=down_block_types,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_type=downsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
self.spatial_compression_ratio = (
|
||||
patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
if spatial_compression_ratio is None
|
||||
else spatial_compression_ratio
|
||||
)
|
||||
self.temporal_compression_ratio = (
|
||||
patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
if temporal_compression_ratio is None
|
||||
else temporal_compression_ratio
|
||||
)
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
|
||||
@@ -426,7 +426,9 @@ class FourierFeatures(nn.Module):
|
||||
w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
h = inputs.repeat_interleave(
|
||||
num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
|
||||
) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
@@ -655,7 +656,7 @@ class WanDecoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||
Introduced in [Wan 2.1].
|
||||
@@ -715,11 +716,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Store normalization parameters as tensors
|
||||
self.mean = torch.tensor(latents_mean)
|
||||
self.std = torch.tensor(latents_std)
|
||||
self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
|
||||
|
||||
self.z_dim = z_dim
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
@@ -751,7 +747,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale.type_as(x)
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
@@ -770,8 +765,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
|
||||
enc = self.quant_conv(out)
|
||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
enc = torch.cat([mu, logvar], dim=1)
|
||||
self.clear_cache()
|
||||
return enc
|
||||
@@ -798,10 +791,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.post_quant_conv(z)
|
||||
@@ -835,8 +826,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
scale = self.scale.type_as(z)
|
||||
decoded = self._decode(z, scale).sample
|
||||
decoded = self._decode(z).sample
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class CacheMixin:
|
||||
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
@@ -59,17 +60,31 @@ class CacheMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
|
||||
if self.is_cache_enabled:
|
||||
raise ValueError(
|
||||
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
||||
)
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, FasterCacheConfig):
|
||||
apply_faster_cache(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
@@ -77,7 +92,11 @@ class CacheMixin:
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
||||
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
|
||||
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
|
||||
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
|
||||
|
||||
# 3. Concat
|
||||
pos_embed_spatial = pos_embed_spatial[None, :, :]
|
||||
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
|
||||
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
|
||||
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
|
||||
) # [T, H*W, D // 4 * 3]
|
||||
|
||||
pos_embed_temporal = pos_embed_temporal[:, None, :]
|
||||
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
|
||||
@@ -334,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
|
||||
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
@@ -1152,10 +1154,13 @@ def get_1d_rotary_pos_embed(
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
is_npu = freqs.device.type == "npu"
|
||||
if is_npu:
|
||||
freqs = freqs.float()
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
|
||||
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
|
||||
):
|
||||
param = param.to(torch.float32)
|
||||
set_module_kwargs["dtype"] = torch.float32
|
||||
# For quantizers have save weights using torch.float8_e4m3fn
|
||||
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
|
||||
pass
|
||||
else:
|
||||
param = param.to(dtype)
|
||||
set_module_kwargs["dtype"] = dtype
|
||||
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
|
||||
elif is_quantized and (
|
||||
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
|
||||
):
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
|
||||
)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ from torch import Tensor, nn
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..hooks import apply_group_offloading, apply_layerwise_casting
|
||||
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
@@ -504,6 +503,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
"""
|
||||
from ..hooks import apply_layerwise_casting
|
||||
|
||||
user_provided_patterns = True
|
||||
if skip_modules_pattern is None:
|
||||
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage=False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates group offloading for the current model.
|
||||
@@ -569,6 +570,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
... )
|
||||
```
|
||||
"""
|
||||
from ..hooks import apply_group_offloading
|
||||
|
||||
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
|
||||
msg = (
|
||||
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
|
||||
@@ -584,7 +587,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
apply_group_offloading(
|
||||
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
|
||||
self,
|
||||
onload_device,
|
||||
offload_device,
|
||||
offload_type,
|
||||
num_blocks_per_group,
|
||||
non_blocking,
|
||||
use_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
@@ -870,7 +880,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
@@ -883,7 +893,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
|
||||
@@ -550,16 +550,6 @@ class RMSNorm(nn.Module):
|
||||
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
elif is_torch_version(">=", "2.4"):
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = nn.functional.rms_norm(
|
||||
hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
|
||||
)
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
|
||||
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -227,13 +227,17 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
# Prepare text embeddings for spatial block
|
||||
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
|
||||
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
|
||||
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
|
||||
)
|
||||
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
|
||||
num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
|
||||
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
|
||||
|
||||
# Prepare timesteps for spatial and temporal block
|
||||
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
|
||||
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
|
||||
timestep_spatial = timestep.repeat_interleave(
|
||||
num_frame, dim=0, output_size=timestep.shape[0] * num_frame
|
||||
).view(-1, timestep.shape[-1])
|
||||
timestep_temp = timestep.repeat_interleave(
|
||||
num_patches, dim=0, output_size=timestep.shape[0] * num_patches
|
||||
).view(-1, timestep.shape[-1])
|
||||
|
||||
# Spatial and temporal transformer blocks
|
||||
for i, (spatial_block, temp_block) in enumerate(
|
||||
@@ -269,7 +273,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
|
||||
|
||||
if i == 0 and num_frame > 1:
|
||||
hidden_states = hidden_states + self.temp_pos_embed
|
||||
hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
@@ -299,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
).permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
|
||||
|
||||
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.repeat_interleave(
|
||||
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
|
||||
).view(-1, embedded_timestep.shape[-1])
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
|
||||
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
self.config.num_attention_heads,
|
||||
dim=0,
|
||||
output_size=attention_mask.shape[0] * self.config.num_attention_heads,
|
||||
)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
@@ -15,18 +15,18 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnProcessor2_0,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
@@ -96,6 +96,95 @@ class SanaModulatedNorm(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.guidance_condition_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
|
||||
return self.linear(self.silu(conditioning)), conditioning
|
||||
|
||||
|
||||
class SanaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
|
||||
@@ -115,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 2.5,
|
||||
qk_norm: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -124,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
@@ -135,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=AttnProcessor2_0(),
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -195,7 +289,7 @@ class SanaTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
|
||||
|
||||
@@ -232,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for the query and key.
|
||||
timestep_scale (`float`, defaults to `1.0`):
|
||||
The scale to use for the timesteps.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -258,6 +356,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = None,
|
||||
timestep_scale: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -276,7 +378,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
if guidance_embeds:
|
||||
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
|
||||
else:
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
@@ -296,6 +401,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -372,7 +478,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
timestep: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -423,9 +530,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
@@ -12,20 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import logging
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -125,7 +127,8 @@ class CogView4AttnProcessor:
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
@@ -155,6 +158,15 @@ class CogView4AttnProcessor:
|
||||
)
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attention_mask = attention_mask.float().to(query.device)
|
||||
actual_text_seq_length = text_attention_mask.size(1)
|
||||
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
|
||||
new_attention_mask = new_attention_mask.unsqueeze(2)
|
||||
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
|
||||
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
@@ -202,6 +214,8 @@ class CogView4TransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
@@ -222,6 +236,8 @@ class CogView4TransformerBlock(nn.Module):
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
@@ -288,7 +304,7 @@ class CogView4RotaryPosEmbed(nn.Module):
|
||||
return (freqs.cos(), freqs.sin())
|
||||
|
||||
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
@@ -383,8 +399,26 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
original_size: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
@@ -404,11 +438,11 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
@@ -419,6 +453,10 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -27,13 +27,15 @@ from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
PixArtAlphaTextProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -173,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
|
||||
6, dim=1
|
||||
)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return hidden_states, gate_msa, tr_gate_msa
|
||||
|
||||
|
||||
class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
pooled_projection_dim: int,
|
||||
guidance_embeds: bool,
|
||||
image_condition_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.image_condition_type = image_condition_type
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
self.guidance_embedder = None
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
token_replace_emb = None
|
||||
if self.image_condition_type == "token_replace":
|
||||
token_replace_timestep = torch.zeros_like(timestep)
|
||||
token_replace_proj = self.time_proj(token_replace_timestep)
|
||||
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
|
||||
token_replace_emb = token_replace_emb + pooled_projections
|
||||
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
|
||||
conditioning = conditioning + guidance_emb
|
||||
|
||||
return conditioning, token_replace_emb
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -390,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
@@ -468,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
@@ -503,6 +644,181 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
|
||||
proj_output = self.proj_out(hidden_states)
|
||||
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
|
||||
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
|
||||
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
@@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
image_condition_type (`str`, *optional*, defaults to `None`):
|
||||
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
|
||||
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
|
||||
tokens in the latent stream and apply conditioning.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -570,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_image_condition_types = ["latent_concat", "token_replace"]
|
||||
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
|
||||
raise ValueError(
|
||||
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
@@ -582,33 +909,52 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
else:
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
||||
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
||||
)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -707,15 +1053,13 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
if self.config.guidance_embeds:
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
else:
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
@@ -746,6 +1090,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
@@ -756,17 +1102,31 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -113,20 +114,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
self.patch_size_t = patch_size_t
|
||||
self.theta = theta
|
||||
|
||||
def forward(
|
||||
def _prepare_video_coords(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# Always compute rope in fp32
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
@@ -138,6 +138,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
grid = grid.flatten(2, 4).transpose(1, 2)
|
||||
|
||||
return grid
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if video_coords is None:
|
||||
grid = self._prepare_video_coords(
|
||||
batch_size,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
grid = torch.stack(
|
||||
[
|
||||
video_coords[:, 0] / self.base_num_frames,
|
||||
video_coords[:, 1] / self.base_height,
|
||||
video_coords[:, 2] / self.base_width,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
freqs = self.theta ** torch.linspace(
|
||||
@@ -267,7 +299,7 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -367,10 +399,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -389,7 +422,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
|
||||
@@ -20,10 +20,11 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -288,7 +289,7 @@ class WanTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
@@ -329,6 +330,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -440,6 +442,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
# first device rather than the last device, which hidden_states ends up
|
||||
# on.
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@@ -638,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
|
||||
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
|
||||
)
|
||||
|
||||
# 2. pre-process
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user