Compare commits
116 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b56112db6e | |||
| f50de75b69 | |||
| 579bb5f418 | |||
| 6bfacf0418 | |||
| f685981ed0 | |||
| b924251dd8 | |||
| 1a04812439 | |||
| 4b27c4a494 | |||
| 5d49b3e83b | |||
| 71f34fc5a4 | |||
| c51b6bd837 | |||
| fb54499614 | |||
| 723dbdd363 | |||
| fbf61f465b | |||
| 841504bb1a | |||
| fc7a867ae5 | |||
| 5ded26cdc7 | |||
| 506f39af3a | |||
| 8ad68c1393 | |||
| 41afb6690c | |||
| 13e48492f0 | |||
| 94f2c48d58 | |||
| aabf8ce20b | |||
| f10775b1b5 | |||
| 6edb774b5e | |||
| 480510ada9 | |||
| d9023a671a | |||
| c4646a3931 | |||
| c97b709afa | |||
| b0ff822ed3 | |||
| 78c2fdc52e | |||
| 54dac3a87c | |||
| e5c6027ef8 | |||
| da857bebb6 | |||
| 52b460feb9 | |||
| d8c617ccb0 | |||
| fe2b397426 | |||
| be0b7f55cc | |||
| 4d5a96e40a | |||
| a7f07c1ef5 | |||
| df1d7b01f1 | |||
| 5a6edac087 | |||
| e8fc8b1f81 | |||
| d6f4774c1c | |||
| eb50defff2 | |||
| 2c59af7222 | |||
| 75d7e5cc45 | |||
| 617c208bb4 | |||
| 5d970a4aa9 | |||
| de6a88c2d7 | |||
| 7dc52ea769 | |||
| 739d6ec731 | |||
| 1ddf3f3a19 | |||
| 7aac77affa | |||
| 8907a70a36 | |||
| 5dbe4f5de6 | |||
| 1d37f42055 | |||
| 0213179ba8 | |||
| a7d53a5939 | |||
| 8a63aa5e4f | |||
| 844221ae4e | |||
| 9b2c0a7dbe | |||
| f424b1b062 | |||
| e9fda3924f | |||
| 2c1ed50fc5 | |||
| 15ad97f782 | |||
| 9f2d5c9ee9 | |||
| dc62e6931e | |||
| 56f740051d | |||
| a34d97cef0 | |||
| fc28791fc8 | |||
| ae14612673 | |||
| 0ab8fe49bf | |||
| 3be6706018 | |||
| cb1b8b21b8 | |||
| 27916822b2 | |||
| 3fe3bc0642 | |||
| 813d42cc96 | |||
| b4d7e9c632 | |||
| 2e83cbbb6d | |||
| 33d10af28f | |||
| 100142586f | |||
| 82188cef04 | |||
| cc19726f3d | |||
| be54a95b93 | |||
| 6b9a3334db | |||
| 8ead643bb7 | |||
| 124ac3e81f | |||
| 2f0f281b0d | |||
| ccc8321651 | |||
| 5e48cd27d4 | |||
| 5551506b29 | |||
| 20e4b6a628 | |||
| 4ea9f89b8e | |||
| 733b44ac82 | |||
| 8b4f8ba764 | |||
| 5428046437 | |||
| e7ffeae0a1 | |||
| d87ce2cefc | |||
| 36d0553af2 | |||
| 7e0db46f73 | |||
| e4b056fe65 | |||
| 4e3ddd5afa | |||
| 9add071592 | |||
| b88fef4785 | |||
| e7e6d85282 | |||
| 8eefed65bd | |||
| 26149c0ecd | |||
| 0703ce8800 | |||
| f5edaa7894 | |||
| 9a1810f0de | |||
| 1fddee211e | |||
| b38450d5d2 | |||
| 1357931d74 | |||
| a2d3d6af44 | |||
| 363d1ab7e2 |
@@ -38,6 +38,7 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install pandas peft
|
||||
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
@@ -414,10 +414,16 @@ jobs:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
additional_deps: ["peft"]
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
additional_deps: ["peft"]
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
additional_deps: []
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
@@ -435,6 +441,9 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U ${{ matrix.config.backend }}
|
||||
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
|
||||
python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
|
||||
fi
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -13,39 +13,5 @@ jobs:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
pre_commit_script_name: "Download and Compare files from the main branch"
|
||||
pre_commit_script: |
|
||||
echo "Downloading the files from the main branch"
|
||||
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
|
||||
echo "Compare the files and raise error if needed"
|
||||
|
||||
diff_failed=0
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
style_command: "make style && make quality"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -28,7 +28,51 @@ env:
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
@@ -133,6 +177,7 @@ jobs:
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
@@ -201,7 +246,7 @@ jobs:
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
@@ -220,6 +265,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -81,6 +81,8 @@
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/vae_encode
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
@@ -173,6 +175,8 @@
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
- local: quantization/quanto
|
||||
title: quanto
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
@@ -492,6 +496,8 @@
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
|
||||
@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
|
||||
FasterCache is a method that speeds up inference in diffusion transformers by:
|
||||
- Reusing attention states between successive inference steps, due to high similarity between them
|
||||
- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FasterCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 681),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
attention_weight_callback=lambda _: 0.3,
|
||||
unconditional_batch_skip_range=5,
|
||||
unconditional_batch_timestep_skip_range=(-1, 781),
|
||||
tensor_format="BFCHW",
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
|
||||
|
||||
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
|
||||
## Quantization
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||

|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
@@ -32,6 +33,7 @@ Available models:
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
@@ -196,6 +198,12 @@ export_to_video(video, "ship.mp4", fps=24)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXConditionPipeline
|
||||
|
||||
[[autodoc]] LTXConditionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
|
||||
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
from diffusers import LuminaPipeline
|
||||
import torch
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
```
|
||||
@@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
@@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained(
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
@@ -122,9 +122,9 @@ image = pipeline(prompt).images[0]
|
||||
image.save("lumina.png")
|
||||
```
|
||||
|
||||
## LuminaText2ImgPipeline
|
||||
## LuminaPipeline
|
||||
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
[[autodoc]] LuminaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
ckpt_path, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -60,7 +60,7 @@ image.save("lumina-single-file.png")
|
||||
GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
|
||||
|
||||
```python
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
@@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -80,8 +80,8 @@ image = pipe(
|
||||
image.save("lumina-gguf.png")
|
||||
```
|
||||
|
||||
## Lumina2Text2ImgPipeline
|
||||
## Lumina2Pipeline
|
||||
|
||||
[[autodoc]] Lumina2Text2ImgPipeline
|
||||
[[autodoc]] Lumina2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SANA-Sprint
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipeline = SanaSprintPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("sana.png")
|
||||
```
|
||||
|
||||
## Setting `max_timesteps`
|
||||
|
||||
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
|
||||
|
||||
## SanaSprintPipeline
|
||||
|
||||
[[autodoc]] SanaSprintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
|
||||
|
||||
@@ -14,22 +14,405 @@
|
||||
|
||||
# Wan
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
<!-- TODO(aryan): update abstract once paper is out -->
|
||||
|
||||
## Generating Videos with Wan 2.1
|
||||
|
||||
We will first need to install some addtional dependencies.
|
||||
|
||||
```shell
|
||||
pip install -u ftfy imageio-ffmpeg imageio
|
||||
```
|
||||
|
||||
### Text to Video Generation
|
||||
|
||||
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
|
||||
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
|
||||
|
||||
```python
|
||||
from diffusers import WanPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
You can improve the quality of the generated video by running the decoding step in full precision.
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- VAE in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
```python
|
||||
from diffusers import WanPipeline, AutoencoderKLWan
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
### Using a custom scheduler
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Image to Video Generation
|
||||
|
||||
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
|
||||
35GB of VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 480 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Video to Video Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.utils import load_video, export_to_video
|
||||
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
model_id, subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
pipe = WanVideoToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, torch_dtype=torch.bfloat16
|
||||
)
|
||||
flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(
|
||||
pipe.scheduler.config, flow_shift=flow_shift
|
||||
)
|
||||
# change to pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A robot standing on a mountain top. The sun is setting in the background"
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
video = load_video(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
|
||||
)
|
||||
output = pipe(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=512,
|
||||
guidance_scale=7.0,
|
||||
strength=0.7,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-v2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Memory Optimizations for Wan 2.1
|
||||
|
||||
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
|
||||
|
||||
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
|
||||
|
||||
### Group Offloading the Transformer and UMT5 Text Encoder
|
||||
|
||||
Find more information about group offloading [here](../optimization/memory.md)
|
||||
|
||||
#### Block Level Group Offloading
|
||||
|
||||
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
|
||||
|
||||
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4,
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
#### Block Level Group Offloading with CUDA Streams
|
||||
|
||||
We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
|
||||
|
||||
In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Applying Layerwise Casting to the Transformer
|
||||
|
||||
Find more information about layerwise casting [here](../optimization/memory.md)
|
||||
|
||||
In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
|
||||
|
||||
This example will require 20GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Using a Custom Scheduler
|
||||
|
||||
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
|
||||
|
||||
@@ -45,11 +428,10 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
|
||||
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
|
||||
```
|
||||
|
||||
### Using single file loading with Wan
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
## Using Single File Loading with Wan 2.1
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -61,6 +443,11 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
|
||||
```
|
||||
|
||||
## Recommendations for Inference
|
||||
- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
## GGUFQuantizationConfig
|
||||
|
||||
[[autodoc]] GGUFQuantizationConfig
|
||||
|
||||
## QuantoConfig
|
||||
|
||||
[[autodoc]] QuantoConfig
|
||||
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
@@ -3,3 +3,7 @@
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
|
||||
## Remote Encode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_encode
|
||||
|
||||
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
## Available Models
|
||||
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
|
||||
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
---
|
||||
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
## Changelog
|
||||
|
||||
- March 10 2025: Added VAE encode
|
||||
- March 2 2025: Initial release with VAE decoding
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into two sections:
|
||||
The documentation is organized into three sections:
|
||||
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Getting Started: VAE Encode with Hybrid Inference
|
||||
|
||||
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Let's encode an image, then decode it to demonstrate.
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
|
||||
</figure>
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
|
||||
|
||||
latent = remote_encode(
|
||||
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
decoded = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode, remote_encode
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
image=init_image,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
image=init_latent,
|
||||
strength=0.75,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("fantasy_landscape.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -161,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
|
||||
|
||||
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
|
||||
```shell
|
||||
export HF_HUB_OFFLINE=True
|
||||
export HF_HUB_OFFLINE=1
|
||||
```
|
||||
|
||||
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
@@ -179,14 +179,16 @@ Telemetry is only sent when loading models and pipelines from the Hub,
|
||||
and it is not collected if you're loading local files.
|
||||
|
||||
We understand that not everyone wants to share additional information,and we respect your privacy.
|
||||
You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
|
||||
On Linux/MacOS:
|
||||
|
||||
```bash
|
||||
export DISABLE_TELEMETRY=YES
|
||||
export HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```bash
|
||||
set DISABLE_TELEMETRY=YES
|
||||
set HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
@@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
|
||||
# We can utilize the enable_group_offload method for Diffusers model implementations
|
||||
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
|
||||
|
||||
# Uncomment the following to also allow recording the current streams.
|
||||
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
|
||||
|
||||
# For any other model implementations, the apply_group_offloading function can be used
|
||||
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
|
||||
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
|
||||
@@ -198,6 +201,19 @@ export_to_video(video, "output.mp4", fps=8)
|
||||
|
||||
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
|
||||
|
||||
<Tip>
|
||||
|
||||
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
|
||||
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
|
||||
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
|
||||
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
|
||||
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
|
||||
- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this.
|
||||
|
||||
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
|
||||
|
||||
</Tip>
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
@@ -235,6 +251,14 @@ In the above example, layerwise casting is enabled on the transformer component
|
||||
|
||||
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
|
||||
|
||||
<Tip>
|
||||
|
||||
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
|
||||
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
|
||||
- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Channels-last memory format
|
||||
|
||||
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
|
||||
|
||||
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Metal Performance Shaders (MPS)
|
||||
|
||||
> [!TIP]
|
||||
> Pipelines with a <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.
|
||||
|
||||
🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. You'll need to have:
|
||||
|
||||
- macOS computer with Apple silicon (M1/M2) hardware
|
||||
@@ -37,7 +40,7 @@ image
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
|
||||
The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -59,6 +62,10 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
|
||||
|
||||
## Troubleshoot
|
||||
|
||||
This section lists some common issues with using the `mps` backend and how to solve them.
|
||||
|
||||
### Attention slicing
|
||||
|
||||
M1/M2 performance is very sensitive to memory pressure. When this occurs, the system automatically swaps if it needs to which significantly degrades performance.
|
||||
|
||||
To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
|
||||
@@ -72,3 +79,7 @@ pipeline.enable_attention_slicing()
|
||||
```
|
||||
|
||||
Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% in computers without universal memory, but we've observed *better performance* in most Apple silicon computers unless you have 64GB of RAM or more.
|
||||
|
||||
### Batch inference
|
||||
|
||||
Generating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching.
|
||||
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
- [Quanto](./quanto.md)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
-->
|
||||
|
||||
# Quanto
|
||||
|
||||
[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
|
||||
|
||||
- All features are available in eager mode (works with non-traceable models)
|
||||
- Supports quantization aware training
|
||||
- Quantized models are compatible with `torch.compile`
|
||||
- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)
|
||||
|
||||
In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`
|
||||
|
||||
```shell
|
||||
pip install optimum-quanto accelerate
|
||||
```
|
||||
|
||||
Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Skipping Quantization on specific modules
|
||||
|
||||
It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
|
||||
## Using `from_single_file` with the Quanto Backend
|
||||
|
||||
`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## Saving Quantized models
|
||||
|
||||
Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
|
||||
|
||||
The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
|
||||
with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# save quantized model to reuse
|
||||
transformer.save_pretrained("<your quantized model save path>")
|
||||
|
||||
# you can reload your quantized model with
|
||||
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
|
||||
```
|
||||
|
||||
## Using `torch.compile` with Quanto
|
||||
|
||||
Currently the Quanto backend supports `torch.compile` for the following quantization types:
|
||||
|
||||
- `int8` weights
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="int8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch_dtype
|
||||
)
|
||||
pipe.to("cuda")
|
||||
images = pipe("A cat holding a sign that says hello").images[0]
|
||||
images.save("flux-quanto-compile.png")
|
||||
```
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
### Weights
|
||||
|
||||
- float8
|
||||
- int8
|
||||
- int4
|
||||
- int2
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@@ -95,6 +95,23 @@ Use the Space below to gauge a pipeline's memory requirements before you downloa
|
||||
></iframe>
|
||||
</div>
|
||||
|
||||
### Specifying Component-Specific Data Types
|
||||
|
||||
You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoPipeline
|
||||
import torch
|
||||
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
|
||||
)
|
||||
print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
|
||||
```
|
||||
|
||||
If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
|
||||
|
||||
### Local pipeline
|
||||
|
||||
To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
|
||||
|
||||
@@ -194,6 +194,59 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support
|
||||
|
||||
</Tip>
|
||||
|
||||
### Hotswapping LoRA adapters
|
||||
|
||||
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time.
|
||||
|
||||
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter.
|
||||
|
||||
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter, (`default_0` is the default adapter name), to be swapped. If you loaded the first adapter with a different name, use that name instead.
|
||||
|
||||
```python
|
||||
pipe = ...
|
||||
# load adapter 1 as normal
|
||||
pipeline.load_lora_weights(file_name_adapter_1)
|
||||
# generate some images with adapter 1
|
||||
...
|
||||
# now hot swap the 2nd adapter
|
||||
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
|
||||
# generate images with adapter 2
|
||||
```
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Hotswapping is not currently supported for LoRA adapters that target the text encoder.
|
||||
|
||||
</Tip>
|
||||
|
||||
For compiled models, it is often (though not always if the second adapter targets identical LoRA ranks and scales) necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation. Use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and `torch.compile` should be called _after_ loading the first adapter.
|
||||
|
||||
```python
|
||||
pipe = ...
|
||||
# call this extra method
|
||||
pipe.enable_lora_hotswap(target_rank=max_rank)
|
||||
# now load adapter 1
|
||||
pipe.load_lora_weights(file_name_adapter_1)
|
||||
# now compile the unet of the pipeline
|
||||
pipe.unet = torch.compile(pipeline.unet, ...)
|
||||
# generate some images with adapter 1
|
||||
...
|
||||
# now hot swap adapter 2
|
||||
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
|
||||
# generate images with adapter 2
|
||||
```
|
||||
|
||||
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
|
||||
|
||||
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
|
||||
|
||||
<Tip>
|
||||
|
||||
Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Kohya and TheLastBen
|
||||
|
||||
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
|
||||
|
||||
@@ -66,12 +66,6 @@ from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
## 원을 채우는 데이터셋
|
||||
|
||||
원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
|
||||
|
||||
우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
|
||||
|
||||
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
|
||||
|
||||
## 학습
|
||||
|
||||
@@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
accelerate>=0.31.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
peft>=0.11.1
|
||||
sentencepiece
|
||||
@@ -24,7 +24,7 @@ import re
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -228,10 +228,20 @@ def log_validation(
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
autocast_ctx = nullcontext()
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
|
||||
)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
@@ -378,7 +388,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="the concept to use to initialize the new inserted tokens when training with "
|
||||
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
|
||||
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
|
||||
"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
|
||||
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -657,15 +667,17 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_layers",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
|
||||
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upcast_before_saving",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
|
||||
"Defaults to precision dtype used for training to save memory"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
|
||||
return text_input_ids
|
||||
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
def _encode_prompt_with_t5(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
max_sequence_length=512,
|
||||
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
def _encode_prompt_with_clip(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -1238,136 +1266,35 @@ def encode_prompt(
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
pooled_prompt_embeds = _get_clip_prompt_embeds(
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
tokenizer=tokenizers[0],
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoders[0].device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
|
||||
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
prompt_embeds = _get_t5_prompt_embeds(
|
||||
prompt_embeds = _encode_prompt_with_t5(
|
||||
text_encoder=text_encoders[1],
|
||||
tokenizer=tokenizers[1],
|
||||
max_sequence_length=max_sequence_length,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device if device is not None else text_encoders[1].device,
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
|
||||
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
|
||||
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
|
||||
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
# create weights for timesteps
|
||||
num_timesteps = 1000
|
||||
|
||||
# generate the multiplier based on cosmap loss weighing
|
||||
# this is only used on linear timesteps for now
|
||||
|
||||
# cosine map weighing is higher in the middle and lower at the ends
|
||||
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
|
||||
# cosmap_weighing = 2 / (math.pi * bot)
|
||||
|
||||
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
|
||||
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
|
||||
# clip at 1e4 (1e6 is too high)
|
||||
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
|
||||
# bring to a mean of 1
|
||||
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
|
||||
|
||||
self.linear_timesteps = timesteps
|
||||
# self.linear_timesteps_weights = cosmap_weighing
|
||||
self.linear_timesteps_weights = sigma_sqrt_weighing
|
||||
|
||||
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
|
||||
pass
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
# Get the indices of the timesteps
|
||||
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
# Get the weights for the timesteps
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
|
||||
return weights
|
||||
|
||||
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
|
||||
sigmas = self.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = self.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
return sigma
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
|
||||
## Add noise according to flow matching.
|
||||
## zt = (1 - texp) * x + texp * z1
|
||||
|
||||
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# timestep needs to be in [0, 1], we store them in [0, 1000]
|
||||
# noisy_sample = (1 - timestep) * latent + timestep * noise
|
||||
t_01 = (timesteps / 1000).to(original_samples.device)
|
||||
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
|
||||
|
||||
# n_dim = original_samples.ndim
|
||||
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
|
||||
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
|
||||
return noisy_model_input
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device, linear=False):
|
||||
if linear:
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
|
||||
self.timesteps = timesteps
|
||||
return timesteps
|
||||
else:
|
||||
# distribute them closer to center. Inference distributes them as a bias toward first
|
||||
# Generate values from 0 to 1
|
||||
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
|
||||
|
||||
# Scale and reverse the values to go from 1000 to 0
|
||||
timesteps = (1 - t) * 1000
|
||||
|
||||
# Sort the timesteps in descending order
|
||||
timesteps, _ = torch.sort(timesteps, descending=True)
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
return timesteps
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.report_to == "wandb" and args.hub_token is not None:
|
||||
raise ValueError(
|
||||
@@ -1499,7 +1426,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
@@ -1619,7 +1546,6 @@ def main(args):
|
||||
target_modules=target_modules,
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
@@ -1727,7 +1653,6 @@ def main(args):
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||
@@ -1737,7 +1662,8 @@ def main(args):
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
if args.mixed_precision == "fp16":
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1747,7 +1673,8 @@ def main(args):
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "shared" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
if args.mixed_precision == "fp16":
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1828,6 +1755,7 @@ def main(args):
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
@@ -2021,6 +1949,7 @@ def main(args):
|
||||
lr_scheduler,
|
||||
)
|
||||
else:
|
||||
print("I SHOULD BE HERE")
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
@@ -2125,7 +2054,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
|
||||
text_encoder_one.train()
|
||||
if args.enable_t5_ti:
|
||||
@@ -2137,6 +2066,11 @@ def main(args):
|
||||
pivoted_tr = True
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
if not freeze_text_encoder:
|
||||
models_to_accumulate.extend([text_encoder_one])
|
||||
if args.enable_t5_ti:
|
||||
models_to_accumulate.extend([text_encoder_two])
|
||||
if pivoted_te:
|
||||
# stopping optimization of text_encoder params
|
||||
optimizer.param_groups[te_idx]["lr"] = 0.0
|
||||
@@ -2145,7 +2079,7 @@ def main(args):
|
||||
logger.info(f"PIVOT TRANSFORMER {epoch}")
|
||||
optimizer.param_groups[0]["lr"] = 0.0
|
||||
|
||||
with accelerator.accumulate(transformer):
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
prompts = batch["prompts"]
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
@@ -2189,7 +2123,7 @@ def main(args):
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
@@ -2228,7 +2162,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -2288,16 +2222,26 @@ def main(args):
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
if not freeze_text_encoder:
|
||||
if args.train_text_encoder:
|
||||
if args.train_text_encoder: # text encoder tuning
|
||||
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
|
||||
elif pure_textual_inversion:
|
||||
params_to_clip = itertools.chain(
|
||||
text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
if args.enable_t5_ti:
|
||||
params_to_clip = itertools.chain(
|
||||
text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(text_encoder_one.parameters())
|
||||
else:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
if args.enable_t5_ti:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(),
|
||||
text_encoder_one.parameters(),
|
||||
text_encoder_two.parameters(),
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(), text_encoder_one.parameters()
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(transformer.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
@@ -2339,6 +2283,10 @@ def main(args):
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
|
||||
)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -2351,14 +2299,16 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
# create pipeline
|
||||
if freeze_text_encoder:
|
||||
if freeze_text_encoder: # no text encoder one, two optimizations
|
||||
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
text_encoder=unwrap_model(text_encoder_one),
|
||||
text_encoder_2=unwrap_model(text_encoder_two),
|
||||
transformer=unwrap_model(transformer),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -2372,21 +2322,21 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
if freeze_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
elif args.train_text_encoder:
|
||||
del text_encoder_two
|
||||
free_memory()
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
transformer = unwrap_model(transformer)
|
||||
transformer = transformer.to(weight_dtype)
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
if args.train_text_encoder:
|
||||
@@ -2428,8 +2378,8 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
save_model_card(
|
||||
@@ -2452,6 +2402,7 @@ def main(args):
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
|
||||
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -71,6 +71,7 @@ from diffusers.utils import (
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images=None,
|
||||
images: list = None,
|
||||
base_model: str = None,
|
||||
train_text_encoder=False,
|
||||
train_text_encoder_ti=False,
|
||||
@@ -111,20 +112,17 @@ def save_model_card(
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
lora = "lora" if not use_dora else "dora"
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
- text: '{validation_prompt if validation_prompt else ' ' }'
|
||||
output:
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
if not images:
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
|
||||
widget_dict = []
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
widget_dict.append(
|
||||
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
|
||||
)
|
||||
else:
|
||||
widget_dict.append({"text": instance_prompt})
|
||||
embeddings_filename = f"{repo_folder}_emb"
|
||||
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
||||
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
||||
@@ -169,23 +167,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_en
|
||||
to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
||||
"""
|
||||
|
||||
yaml = f"""---
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- diffusers-training
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- {lora}
|
||||
- template:sd-lora
|
||||
{img_str}
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
---
|
||||
"""
|
||||
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# SDXL LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
@@ -234,8 +216,25 @@ Special VAE used for training: {vae_path}.
|
||||
|
||||
{license}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="openrail++",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
widget=widget_dict,
|
||||
)
|
||||
tags = [
|
||||
"text-to-image",
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
lora,
|
||||
"template:sd-lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
|
||||
def log_validation(
|
||||
@@ -773,7 +772,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
@@ -1875,7 +1874,7 @@ def main(args):
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
# Training CogView4 Control
|
||||
|
||||
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:
|
||||
|
||||
To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
|
||||
|
||||
```bash
|
||||
accelerate launch train_control_lora_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control-lora" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=1 \
|
||||
--rank=64 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=5000 \
|
||||
--validation_image="openpose.png" \
|
||||
--validation_prompt="A couple, 4k photo, highly detailed" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
|
||||
|
||||
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
|
||||
|
||||
The training script exposes additional CLI args that might be useful to experiment with:
|
||||
|
||||
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
|
||||
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
|
||||
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
|
||||
|
||||
### Training with DeepSpeed
|
||||
|
||||
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
And then while launching training, pass the config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=CONFIG_FILE.yaml ...
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
|
||||
|
||||
```bash
|
||||
pip install controlnet_aux
|
||||
```
|
||||
|
||||
And then we are ready:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("...") # change this.
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
joint_attention_kwargs={"scale": 0.9},
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Full fine-tuning
|
||||
|
||||
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=2 \
|
||||
--dataloader_num_workers=4 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--proportion_empty_prompts=0.2 \
|
||||
--learning_rate=5e-5 \
|
||||
--adam_weight_decay=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=1000 \
|
||||
--checkpointing_steps=1000 \
|
||||
--max_train_steps=10000 \
|
||||
--validation_steps=200 \
|
||||
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
|
||||
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Change the `validation_image` and `validation_prompt` as needed.
|
||||
|
||||
For inference, this time, we will run:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
|
||||
pipe = CogView4ControlPipeline.from_pretrained(
|
||||
"THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Things to note
|
||||
|
||||
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
|
||||
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
|
||||
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
|
||||
@@ -0,0 +1,6 @@
|
||||
transformers==4.47.0
|
||||
wandb
|
||||
torch
|
||||
torchvision
|
||||
accelerate==1.2.0
|
||||
peft>=0.14.0
|
||||
File diff suppressed because it is too large
Load Diff
+195
-21
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
@@ -23,12 +24,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
@@ -40,7 +41,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
|
||||
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
@@ -57,7 +58,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
@@ -84,7 +85,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
|
||||
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
|
||||
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
|
||||
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
```py
|
||||
@@ -93,6 +94,54 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
|
||||
|
||||
## Example usages
|
||||
|
||||
### Spatiotemporal Skip Guidance
|
||||
|
||||
**Junha Hyung\*, Kinam Kim\*, Susung Hong, Min-Jung Kim, Jaegul Choo**
|
||||
|
||||
**KAIST AI, University of Washington**
|
||||
|
||||
[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://arxiv.org/abs/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.
|
||||
|
||||
Following is the example video of STG applied to Mochi.
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3
|
||||
|
||||
More examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).
|
||||
|
||||
#### Usage example
|
||||
```python
|
||||
import torch
|
||||
from pipeline_stg_mochi import MochiSTGPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
#--------Option--------#
|
||||
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
stg_applied_layers_idx = [34]
|
||||
stg_scale = 1.0 # 0.0 for CFG
|
||||
#----------------------#
|
||||
|
||||
# Generate video frames
|
||||
frames = pipe(
|
||||
prompt,
|
||||
height=480,
|
||||
width=480,
|
||||
num_frames=81,
|
||||
stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
stg_scale=stg_scale,
|
||||
generator = torch.Generator().manual_seed(42),
|
||||
do_rescaling=do_rescaling,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
### Adaptive Mask Inpainting
|
||||
|
||||
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
|
||||
@@ -904,6 +953,7 @@ for i in range(args.num_images):
|
||||
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
|
||||
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
|
||||
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
|
||||
print("Image saved successfully!")
|
||||
```
|
||||
|
||||
### Imagic Stable Diffusion
|
||||
@@ -1219,28 +1269,39 @@ The aim is to overlay two images, then mask out the boundary between `image` and
|
||||
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
image_path = "./path-to-image.png"
|
||||
inner_image_path = "./path-to-inner-image.png"
|
||||
mask_path = "./path-to-mask.png"
|
||||
image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
|
||||
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
|
||||
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
def load_image(url, mode="RGB"):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
|
||||
else:
|
||||
raise FileNotFoundError(f"Could not retrieve image from {url}")
|
||||
|
||||
|
||||
init_image = load_image(image_url, mode="RGB")
|
||||
inner_image = load_image(inner_image_url, mode="RGBA")
|
||||
mask_image = load_image(mask_url, mode="RGB")
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
custom_pipeline="img2img_inpainting",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Your prompt here!"
|
||||
prompt = "a mecha robot sitting on a bench"
|
||||
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
|
||||
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||

|
||||
@@ -3202,14 +3263,19 @@ Here's a full example for `ReplaceEdit``:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from diffusers import DiffusionPipeline
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="pipeline_prompt2prompt"
|
||||
).to("cuda")
|
||||
|
||||
prompts = ["A turtle playing with a ball",
|
||||
"A monkey playing with a ball"]
|
||||
prompts = [
|
||||
"A turtle playing with a ball",
|
||||
"A monkey playing with a ball"
|
||||
]
|
||||
|
||||
cross_attention_kwargs = {
|
||||
"edit_type": "replace",
|
||||
@@ -3217,7 +3283,15 @@ cross_attention_kwargs = {
|
||||
"self_replace_steps": 0.4
|
||||
}
|
||||
|
||||
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
|
||||
outputs = pipe(
|
||||
prompt=prompts,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=50,
|
||||
cross_attention_kwargs=cross_attention_kwargs
|
||||
)
|
||||
|
||||
outputs.images[0].save("output_image_0.png")
|
||||
```
|
||||
|
||||
And abbreviated examples for the other edits:
|
||||
@@ -5259,3 +5333,103 @@ output = pipeline_for_inversion(
|
||||
pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
|
||||
pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
|
||||
```
|
||||
# FaithDiff Stable Diffusion XL Pipeline
|
||||
|
||||
[Project](https://jychen9811.github.io/FaithDiff_page/) / [GitHub](https://github.com/JyChen9811/FaithDiff/)
|
||||
|
||||
This the implementation of the FaithDiff pipeline for SDXL, adapted to use the HuggingFace Diffusers.
|
||||
|
||||
For more details see the project links above.
|
||||
|
||||
## Example Usage
|
||||
|
||||
This example upscale and restores a low-quality image. The input image has a resolution of 512x512 and will be upscaled at a scale of 2x, to a final resolution of 1024x1024. It is possible to upscale to a larger scale, but it is recommended that the input image be at least 1024x1024 in these cases. To upscale this image by 4x, for example, it would be recommended to re-input the result into a new 2x processing, thus performing progressive scaling.
|
||||
|
||||
````py
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
|
||||
from huggingface_hub import hf_hub_download
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
MAX_SEED = np.iinfo(np.int32).max
|
||||
|
||||
# Download weights for additional unet layers
|
||||
model_file = hf_hub_download(
|
||||
"jychen9811/FaithDiff",
|
||||
filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
|
||||
)
|
||||
|
||||
# Initialize the models and pipeline
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
|
||||
|
||||
model_id = "SG161222/RealVisXL_V4.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=dtype,
|
||||
vae=vae,
|
||||
unet=None, #<- Do not load with original model.
|
||||
custom_pipeline="pipeline_faithdiff_stable_diffusion_xl",
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
|
||||
# Here we need use pipeline internal unet model
|
||||
pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
|
||||
|
||||
# Load aditional layers to the model
|
||||
pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
|
||||
|
||||
# Enable vae tiling
|
||||
pipe.set_encoder_tile_settings()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
# Optimization
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Set selected scheduler
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
#input params
|
||||
prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
|
||||
upscale = 2 # scale here
|
||||
start_point = "lr" # or "noise"
|
||||
latent_tiled_overlap = 0.5
|
||||
latent_tiled_size = 1024
|
||||
|
||||
# Load image
|
||||
lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
|
||||
original_height = lq_image.height
|
||||
original_width = lq_image.width
|
||||
print(f"Current resolution: H:{original_height} x W:{original_width}")
|
||||
|
||||
width = original_width * int(upscale)
|
||||
height = original_height * int(upscale)
|
||||
print(f"Final resolution: H:{height} x W:{width}")
|
||||
|
||||
# Restoration
|
||||
image = lq_image.resize((width, height), Image.LANCZOS)
|
||||
input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
|
||||
gen_image = pipe(lr_img=input_image,
|
||||
prompt = prompt,
|
||||
num_inference_steps=20,
|
||||
guidance_scale=5,
|
||||
generator=generator,
|
||||
start_point=start_point,
|
||||
height = height_now,
|
||||
width=width_now,
|
||||
overlap=latent_tiled_overlap,
|
||||
target_size=(latent_tiled_size, latent_tiled_size)
|
||||
).images[0]
|
||||
|
||||
cropped_image = gen_image.crop((0, 0, width_init, height_init))
|
||||
cropped_image.save("data/result.png")
|
||||
````
|
||||
### Result
|
||||
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
|
||||
@@ -1773,7 +1773,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
@@ -1924,7 +1924,22 @@ class SDXLLongPromptWeightingPipeline(
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -1070,32 +1070,32 @@ class StableDiffusionXLTilingPipeline(
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
||||
embeddings_and_added_time.append(addition_embed_type_row)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,876 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline
|
||||
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXSTGPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style."
|
||||
... )
|
||||
>>> pipe.transformer.to(memory_format=torch.channels_last)
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [11] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set to 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> # Generate video frames with STG parameters
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(frames, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogVideoX uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogVideoXTransformer3DModel`]):
|
||||
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
)
|
||||
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae_scaling_factor_image * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
p = self.transformer.config.patch_size
|
||||
p_t = self.transformer.config.patch_size_t
|
||||
|
||||
base_size_width = self.transformer.config.sample_width // p
|
||||
base_size_height = self.transformer.config.sample_height // p
|
||||
|
||||
if p_t is None:
|
||||
# CogVideoX 1.0
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# CogVideoX 1.5
|
||||
base_num_frames = (num_frames + p_t - 1) // p_t
|
||||
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=None,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=base_num_frames,
|
||||
grid_type="slice",
|
||||
max_size=(base_size_height, base_size_width),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [11],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
||||
num_frames (`int`, defaults to `48`):
|
||||
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
||||
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
||||
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
||||
needs to be satisfied is that of divisibility mentioned above.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
||||
num_frames = num_frames or self.transformer.config.sample_frames
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents
|
||||
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
||||
patch_size_t = self.transformer.config.patch_size_t
|
||||
additional_frames = 0
|
||||
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
||||
additional_frames = patch_size_t - latent_frames % patch_size_t
|
||||
num_frames += additional_frames * self.vae_scale_factor_temporal
|
||||
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
else:
|
||||
latents, old_pred_original_sample = self.scheduler.step(
|
||||
noise_pred,
|
||||
old_pred_original_sample,
|
||||
t,
|
||||
timesteps[i - 1] if i > 0 else None,
|
||||
latents,
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CogVideoXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,794 @@
|
||||
# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
|
||||
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import HunyuanVideoTransformer3DModel
|
||||
>>> from examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline
|
||||
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
|
||||
>>> pipe.vae.enable_tiling()
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [2] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt="A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.",
|
||||
... height=320,
|
||||
... width=512,
|
||||
... num_frames=61,
|
||||
... num_inference_steps=30,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = {
|
||||
"template": (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
),
|
||||
"crop_start": 95,
|
||||
}
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def forward_without_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using HunyuanVideo.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`LlamaModel`]):
|
||||
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
tokenizer (`LlamaTokenizer`):
|
||||
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
transformer ([`HunyuanVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder_2 ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: LlamaModel,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
transformer: HunyuanVideoTransformer3DModel,
|
||||
vae: AutoencoderKLHunyuanVideo,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
text_encoder_2: CLIPTextModel,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_llama_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_template: Dict[str, Any],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
num_hidden_layers_to_skip: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = [prompt_template["template"].format(p) for p in prompt]
|
||||
|
||||
crop_start = prompt_template.get("crop_start", None)
|
||||
if crop_start is None:
|
||||
prompt_template_input = self.tokenizer(
|
||||
prompt_template["template"],
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
crop_start = prompt_template_input["input_ids"].shape[-1]
|
||||
# Remove <|eot_id|> token and placeholder {}
|
||||
crop_start -= 2
|
||||
|
||||
max_sequence_length += crop_start
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
max_length=max_sequence_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device=device)
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
if crop_start is not None and crop_start > 0:
|
||||
prompt_embeds = prompt_embeds[:, crop_start:]
|
||||
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 77,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder_2.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
||||
prompt,
|
||||
prompt_template,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
if prompt_2 is None and pooled_prompt_embeds is None:
|
||||
prompt_2 = prompt
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=77,
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_template=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if prompt_template is not None:
|
||||
if not isinstance(prompt_template, dict):
|
||||
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
|
||||
if "template" not in prompt_template:
|
||||
raise ValueError(
|
||||
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 32,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [2],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
|
||||
CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
|
||||
not applied.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_template=prompt_template,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
||||
if pooled_prompt_embeds is not None:
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_latent_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare guidance condition
|
||||
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_without_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred_perturb = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return HunyuanVideoPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,886 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline
|
||||
|
||||
>>> pipe = LTXSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,985 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
>>> from examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline
|
||||
|
||||
>>> pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png"
|
||||
>>> )
|
||||
>>> prompt = "A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
self.default_height = 512
|
||||
self.default_width = 704
|
||||
self.default_frames = 121
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
|
||||
)
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype), conditioning_mask
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
|
||||
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
|
||||
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
return latents, conditioning_mask
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
image = image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
timestep, _ = timestep.chunk(2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
timestep, _, _ = timestep.chunk(3)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = self._unpack_latents(
|
||||
noise_pred,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
noise_latents = latents[:, :, 1:]
|
||||
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
|
||||
|
||||
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,843 @@
|
||||
# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import Mochi1LoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
|
||||
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline
|
||||
|
||||
>>> pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> pipe.enable_vae_tiling()
|
||||
>>> prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [34] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling).frames[0]
|
||||
>>> export_to_video(frames, "mochi.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
|
||||
if not self.context_pre_only:
|
||||
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, temb
|
||||
)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
|
||||
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
|
||||
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
)
|
||||
norm_encoder_hidden_states = self.norm3_context(
|
||||
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
|
||||
)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
)
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
if linear_steps is None:
|
||||
linear_steps = num_steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
|
||||
quadratic_steps = num_steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
|
||||
const = quadratic_coef * (linear_steps**2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return sigma_schedule
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom value")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
r"""
|
||||
The mochi pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/genmoai/models
|
||||
|
||||
Args:
|
||||
transformer ([`MochiTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLMochi`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLMochi,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: MochiTransformer3DModel,
|
||||
force_zeros_for_empty_prompt: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
# TODO: determine these scaling factors from model parameters
|
||||
self.vae_spatial_scale_factor = 8
|
||||
self.vae_temporal_scale_factor = 6
|
||||
self.patch_size = 2
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
|
||||
)
|
||||
self.default_height = 480
|
||||
self.default_width = 848
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
# The original Mochi implementation zeros out empty negative prompts
|
||||
# but this can lead to overflow when placing the entire pipeline under the autocast context
|
||||
# adding this here so that we can enable zeroing prompts if necessary
|
||||
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
|
||||
text_input_ids = torch.zeros_like(text_input_ids, device=device)
|
||||
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = height // self.vae_spatial_scale_factor
|
||||
width = width // self.vae_spatial_scale_factor
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
|
||||
latents = latents.to(dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 19,
|
||||
num_inference_steps: int = 64,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [34],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to `self.default_height`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to `self.default_width`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `19`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `4.5`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `256`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
|
||||
is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.default_height
|
||||
width = width or self.default_width
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 5. Prepare timestep
|
||||
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
threshold_noise = 0.025
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
||||
sigmas = np.array(sigmas)
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
|
||||
# to make sure we're using the correct non-reversed timestep value.
|
||||
self._current_timestep = 1000 - t
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# Mochi CFG + Sampling runs in FP32
|
||||
noise_pred = noise_pred.to(torch.float32)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return MochiPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,661 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import WanLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import AutoencoderKLWan
|
||||
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
>>> from examples.community.pipeline_stg_wan import WanSTGPipeline
|
||||
|
||||
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=720,
|
||||
... width=1280,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_without_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for idx, block in enumerate(self.transformer.blocks):
|
||||
block.forward = types.MethodType(forward_without_stg, block)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for idx, block in enumerate(self.transformer.blocks):
|
||||
if idx in stg_applied_layers_idx:
|
||||
block.forward = types.MethodType(forward_with_stg, block)
|
||||
noise_perturb = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = (
|
||||
noise_uncond
|
||||
+ guidance_scale * (noise_pred - noise_uncond)
|
||||
+ self._stg_scale * (noise_pred - noise_perturb)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -152,9 +152,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
@@ -929,17 +927,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -964,8 +967,14 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -166,9 +166,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -17,6 +17,7 @@ import argparse
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -52,6 +53,7 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.testing_utils import backend_empty_cache
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
|
||||
|
||||
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
controlnet=controlnet,
|
||||
controlnet=None,
|
||||
safety_checker=None,
|
||||
transformer=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
|
||||
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipeline.encode_prompt(
|
||||
validation_prompts,
|
||||
prompt_2=None,
|
||||
prompt_3=None,
|
||||
)
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
backend_empty_cache(accelerator.device.type)
|
||||
|
||||
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
text_encoder=None,
|
||||
text_encoder_2=None,
|
||||
text_encoder_3=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.enable_model_cpu_offload(device=accelerator.device.type)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
image_logs = []
|
||||
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
for i, validation_image in enumerate(validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
validation_prompt = validation_prompts[i]
|
||||
|
||||
images = []
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with inference_ctx:
|
||||
image = pipeline(
|
||||
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
|
||||
prompt_embeds=prompt_embeds[i].unsqueeze(0),
|
||||
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
|
||||
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
|
||||
control_image=validation_image,
|
||||
num_inference_steps=20,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
|
||||
images.append(image)
|
||||
@@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
|
||||
dataset = load_dataset(
|
||||
args.train_data_dir,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
|
||||
@@ -1283,8 +1324,8 @@ def main(args):
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
prompt_embeds = batch["prompt_embeds"]
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
|
||||
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
|
||||
|
||||
# controlnet(s) inference
|
||||
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -157,9 +157,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -958,7 +965,12 @@ def encode_prompt(
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
device = device if device is not None else text_encoders[1].device
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
@@ -1590,7 +1602,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1716,9 +1728,9 @@ def main(args):
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
|
||||
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
|
||||
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
|
||||
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
|
||||
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
|
||||
@@ -177,16 +177,25 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
|
||||
)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
@@ -203,8 +212,7 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -994,7 +1009,11 @@ def encode_prompt(
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
@@ -1619,7 +1638,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
@@ -1710,7 +1729,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1828,9 +1847,9 @@ def main(args):
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
text_encoder=unwrap_model(text_encoder_one),
|
||||
text_encoder_2=unwrap_model(text_encoder_two),
|
||||
transformer=unwrap_model(transformer),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
|
||||
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
@@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset):
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
train_resize = transforms.Resize(size, interpolation=interpolation)
|
||||
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
|
||||
@@ -49,6 +49,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -418,7 +419,7 @@ def convert_to_np(image, resolution):
|
||||
|
||||
|
||||
def download_image(url):
|
||||
image = PIL.Image.open(requests.get(url, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# AnyTextPipeline
|
||||
|
||||
Project page: https://aigcdesigngroup.github.io/homepage_anytext
|
||||
|
||||
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
|
||||
|
||||
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
|
||||
|
||||
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
|
||||
|
||||
[](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
|
||||
|
||||
```py
|
||||
# This example requires the `anytext_controlnet.py` file:
|
||||
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
|
||||
# %cd diffusers/examples/research_projects/anytext
|
||||
# Let's choose a font file shared by an HF staff:
|
||||
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from anytext_controlnet import AnyTextControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
||||
variant="fp16",)
|
||||
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
|
||||
controlnet=anytext_controlnet, torch_dtype=torch.float16,
|
||||
trust_remote_code=False, # One needs to give permission to run this pipeline's code
|
||||
).to("cuda")
|
||||
|
||||
# generate image
|
||||
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
||||
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
||||
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
|
||||
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,463 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
|
||||
# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
|
||||
# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
|
||||
#
|
||||
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.controlnets.controlnet import (
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class AnyTextControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
glyph_channels=1,
|
||||
position_channels=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.glyph_block = nn.Sequential(
|
||||
nn.Conv2d(glyph_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_block = nn.Sequential(
|
||||
nn.Conv2d(position_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 64, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
|
||||
|
||||
def forward(self, glyphs, positions, text_info):
|
||||
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
|
||||
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
|
||||
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
|
||||
|
||||
return guided_hint
|
||||
|
||||
|
||||
class AnyTextControlNetModel(ControlNetModel):
|
||||
"""
|
||||
A AnyTextControlNetModel model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 1,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
conditioning_channels,
|
||||
flip_sin_to_cos,
|
||||
freq_shift,
|
||||
down_block_types,
|
||||
mid_block_type,
|
||||
only_cross_attention,
|
||||
block_out_channels,
|
||||
layers_per_block,
|
||||
downsample_padding,
|
||||
mid_block_scale_factor,
|
||||
act_fn,
|
||||
norm_num_groups,
|
||||
norm_eps,
|
||||
cross_attention_dim,
|
||||
transformer_layers_per_block,
|
||||
encoder_hid_dim,
|
||||
encoder_hid_dim_type,
|
||||
attention_head_dim,
|
||||
num_attention_heads,
|
||||
use_linear_projection,
|
||||
class_embed_type,
|
||||
addition_embed_type,
|
||||
addition_time_embed_dim,
|
||||
num_class_embeds,
|
||||
upcast_attention,
|
||||
resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels,
|
||||
global_pool_conditions,
|
||||
addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
glyph_channels=conditioning_channels,
|
||||
position_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`~PromptDiffusionControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
#controlnet_cond (`torch.Tensor`):
|
||||
# The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
# elif channel_order == "bgr":
|
||||
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.models.controlnet.zero_module
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
+209
@@ -0,0 +1,209 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .RecSVTR import Block
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Im2Im(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Im2Seq(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# assert H == 1
|
||||
x = x.reshape(B, C, H * W)
|
||||
x = x.permute((0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
class EncoderWithRNN(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(EncoderWithRNN, self).__init__()
|
||||
hidden_size = kwargs.get("hidden_size", 256)
|
||||
self.out_channels = hidden_size * 2
|
||||
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self, in_channels, encoder_type="rnn", **kwargs):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
self.encoder_reshape = Im2Seq(in_channels)
|
||||
self.out_channels = self.encoder_reshape.out_channels
|
||||
self.encoder_type = encoder_type
|
||||
if encoder_type == "reshape":
|
||||
self.only_reshape = True
|
||||
else:
|
||||
support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
|
||||
assert encoder_type in support_encoder_dict, "{} must in {}".format(
|
||||
encoder_type, support_encoder_dict.keys()
|
||||
)
|
||||
|
||||
self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
|
||||
self.out_channels = self.encoder.out_channels
|
||||
self.only_reshape = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.encoder_type != "svtr":
|
||||
x = self.encoder_reshape(x)
|
||||
if not self.only_reshape:
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.encoder_reshape(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = Swish()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class EncoderWithSVTR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
dims=64, # XS
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=False,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
mlp_ratio=2.0,
|
||||
drop_rate=0.1,
|
||||
attn_drop_rate=0.1,
|
||||
drop_path=0.0,
|
||||
qk_scale=None,
|
||||
):
|
||||
super(EncoderWithSVTR, self).__init__()
|
||||
self.depth = depth
|
||||
self.use_guide = use_guide
|
||||
self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
|
||||
self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
|
||||
|
||||
self.svtr_block = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=hidden_dims,
|
||||
num_heads=num_heads,
|
||||
mixer="Global",
|
||||
HW=None,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer="swish",
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-05,
|
||||
prenorm=False,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
|
||||
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
|
||||
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
|
||||
self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
|
||||
|
||||
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
|
||||
self.out_channels = dims
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# weight initialization
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# for use guide
|
||||
if self.use_guide:
|
||||
z = x.clone()
|
||||
z.stop_gradient = True
|
||||
else:
|
||||
z = x
|
||||
# for short cut
|
||||
h = z
|
||||
# reduce dim
|
||||
z = self.conv1(z)
|
||||
z = self.conv2(z)
|
||||
# SVTR global block
|
||||
B, C, H, W = z.shape
|
||||
z = z.flatten(2).permute(0, 2, 1)
|
||||
|
||||
for blk in self.svtr_block:
|
||||
z = blk(z)
|
||||
|
||||
z = self.norm(z)
|
||||
# last stage
|
||||
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
|
||||
z = self.conv3(z)
|
||||
z = torch.cat((h, z), dim=1)
|
||||
z = self.conv1x1(self.conv4(z))
|
||||
|
||||
return z
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
svtrRNN = EncoderWithSVTR(56)
|
||||
print(svtrRNN)
|
||||
@@ -0,0 +1,45 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
|
||||
):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
bias=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = {}
|
||||
result["ctc"] = predicts
|
||||
result["ctc_neck"] = x
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,49 @@
|
||||
from torch import nn
|
||||
|
||||
from .RecCTCHead import CTCHead
|
||||
from .RecMv1_enhance import MobileNetV1Enhance
|
||||
from .RNN import Im2Im, Im2Seq, SequenceEncoder
|
||||
|
||||
|
||||
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
|
||||
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
|
||||
head_dict = {"CTCHead": CTCHead}
|
||||
|
||||
|
||||
class RecModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert "in_channels" in config, "in_channels must in model config"
|
||||
backbone_type = config["backbone"].pop("type")
|
||||
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
||||
self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
|
||||
|
||||
neck_type = config["neck"].pop("type")
|
||||
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
||||
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
|
||||
|
||||
head_type = config["head"].pop("type")
|
||||
assert head_type in head_dict, f"head.type must in {head_dict}"
|
||||
self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
|
||||
|
||||
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
||||
|
||||
def load_3rd_state_dict(self, _3rd_name, _state):
|
||||
self.backbone.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.neck.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.head.load_3rd_state_dict(_3rd_name, _state)
|
||||
|
||||
def forward(self, x):
|
||||
import torch
|
||||
|
||||
x = x.to(torch.float32)
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head.ctc_encoder(x)
|
||||
return x
|
||||
@@ -0,0 +1,197 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .common import Activation
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
|
||||
):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.act = act
|
||||
self._conv = nn.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self._batch_norm = nn.BatchNorm2d(
|
||||
num_filters,
|
||||
)
|
||||
if self.act is not None:
|
||||
self._act = Activation(act_type=act, inplace=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
if self.act is not None:
|
||||
y = self._act(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
|
||||
):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale),
|
||||
)
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Module):
|
||||
def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
|
||||
)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False,
|
||||
)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True,
|
||||
)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=last_conv_stride,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
if last_pool_type == "avg":
|
||||
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
|
||||
def hardsigmoid(x):
|
||||
return F.relu6(x + 3.0, inplace=True) / 6.0
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
x = torch.mul(inputs, outputs)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,570 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional
|
||||
from torch.nn.init import ones_, trunc_normal_, zeros_
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0.0, training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = torch.tensor(1 - drop_prob)
|
||||
shape = (x.size()[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
|
||||
random_tensor = torch.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
if isinstance(act_layer, str):
|
||||
self.act = Swish()
|
||||
else:
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvMixer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
HW=(8, 25),
|
||||
local_k=(3, 3),
|
||||
):
|
||||
super().__init__()
|
||||
self.HW = HW
|
||||
self.dim = dim
|
||||
self.local_mixer = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
local_k,
|
||||
1,
|
||||
(local_k[0] // 2, local_k[1] // 2),
|
||||
groups=num_heads,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.HW[0]
|
||||
w = self.HW[1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
|
||||
x = self.local_mixer(x)
|
||||
x = x.flatten(2).transpose([0, 2, 1])
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
mixer="Global",
|
||||
HW=(8, 25),
|
||||
local_k=(7, 11),
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.HW = HW
|
||||
if HW is not None:
|
||||
H = HW[0]
|
||||
W = HW[1]
|
||||
self.N = H * W
|
||||
self.C = dim
|
||||
if mixer == "Local" and HW is not None:
|
||||
hk = local_k[0]
|
||||
wk = local_k[1]
|
||||
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
|
||||
for h in range(0, H):
|
||||
for w in range(0, W):
|
||||
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
|
||||
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
|
||||
mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
|
||||
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
|
||||
self.mask = mask[None, None, :]
|
||||
# self.mask = mask.unsqueeze([0, 1])
|
||||
self.mixer = mixer
|
||||
|
||||
def forward(self, x):
|
||||
if self.HW is not None:
|
||||
N = self.N
|
||||
C = self.C
|
||||
else:
|
||||
_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = q.matmul(k.permute((0, 1, 3, 2)))
|
||||
if self.mixer == "Local":
|
||||
attn += self.mask
|
||||
attn = functional.softmax(attn, dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mixer="Global",
|
||||
local_mixer=(7, 11),
|
||||
HW=(8, 25),
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
prenorm=True,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
if mixer == "Global" or mixer == "Local":
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
mixer=mixer,
|
||||
HW=HW,
|
||||
local_k=local_mixer,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
elif mixer == "Conv":
|
||||
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
|
||||
else:
|
||||
raise TypeError("The mixer must be one of [Global, Local, Conv]")
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.prenorm = prenorm
|
||||
|
||||
def forward(self, x):
|
||||
if self.prenorm:
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
|
||||
super().__init__()
|
||||
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
|
||||
self.img_size = img_size
|
||||
self.num_patches = num_patches
|
||||
self.embed_dim = embed_dim
|
||||
self.norm = None
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class SubSample(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
|
||||
super().__init__()
|
||||
self.types = types
|
||||
if types == "Pool":
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.proj = nn.Linear(in_channels, out_channels)
|
||||
else:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
self.norm = eval(sub_norm)(out_channels)
|
||||
if act is not None:
|
||||
self.act = act()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.types == "Pool":
|
||||
x1 = self.avgpool(x)
|
||||
x2 = self.maxpool(x)
|
||||
x = (x1 + x2) * 0.5
|
||||
out = self.proj(x.flatten(2).permute((0, 2, 1)))
|
||||
else:
|
||||
x = self.conv(x)
|
||||
out = x.flatten(2).permute((0, 2, 1))
|
||||
out = self.norm(out)
|
||||
if self.act is not None:
|
||||
out = self.act(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SVTRNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=[48, 100],
|
||||
in_channels=3,
|
||||
embed_dim=[64, 128, 256],
|
||||
depth=[3, 6, 3],
|
||||
num_heads=[2, 4, 8],
|
||||
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
|
||||
local_mixer=[[7, 11], [7, 11], [7, 11]],
|
||||
patch_merging="Conv", # Conv, Pool, None
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
last_drop=0.1,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer="nn.LayerNorm",
|
||||
sub_norm="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
out_channels=192,
|
||||
out_char_num=25,
|
||||
block_unit="Block",
|
||||
act="nn.GELU",
|
||||
last_stage=True,
|
||||
sub_num=2,
|
||||
prenorm=True,
|
||||
use_lenhead=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.embed_dim = embed_dim
|
||||
self.out_channels = out_channels
|
||||
self.prenorm = prenorm
|
||||
patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
|
||||
# self.pos_embed = self.create_parameter(
|
||||
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
|
||||
|
||||
# self.add_parameter("pos_embed", self.pos_embed)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
Block_unit = eval(block_unit)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, sum(depth))
|
||||
self.blocks1 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[0],
|
||||
num_heads=num_heads[0],
|
||||
mixer=mixer[0 : depth[0]][i],
|
||||
HW=self.HW,
|
||||
local_mixer=local_mixer[0],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[0 : depth[0]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[0])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample1 = SubSample(
|
||||
embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 2, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.patch_merging = patch_merging
|
||||
self.blocks2 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[1],
|
||||
num_heads=num_heads[1],
|
||||
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[1],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[1])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample2 = SubSample(
|
||||
embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 4, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.blocks3 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[2],
|
||||
num_heads=num_heads[2],
|
||||
mixer=mixer[depth[0] + depth[1] :][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[2],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] + depth[1] :][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[2])
|
||||
]
|
||||
)
|
||||
self.last_stage = last_stage
|
||||
if last_stage:
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
|
||||
self.last_conv = nn.Conv2d(
|
||||
in_channels=embed_dim[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=last_drop)
|
||||
if not prenorm:
|
||||
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
|
||||
self.use_lenhead = use_lenhead
|
||||
if use_lenhead:
|
||||
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
|
||||
self.hardswish_len = nn.Hardswish()
|
||||
self.dropout_len = nn.Dropout(p=last_drop)
|
||||
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks1:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
|
||||
for blk in self.blocks2:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
|
||||
for blk in self.blocks3:
|
||||
x = blk(x)
|
||||
if not self.prenorm:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.use_lenhead:
|
||||
len_x = self.len_conv(x.mean(1))
|
||||
len_x = self.dropout_len(self.hardswish_len(len_x))
|
||||
if self.last_stage:
|
||||
if self.patch_merging is not None:
|
||||
h = self.HW[0] // 4
|
||||
else:
|
||||
h = self.HW[0]
|
||||
x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
if self.use_lenhead:
|
||||
return x, len_x
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = torch.rand(1, 3, 48, 100)
|
||||
svtr = SVTRNet()
|
||||
|
||||
out = svtr(a)
|
||||
print(svtr)
|
||||
print(out.size())
|
||||
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Hswish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hswish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
# out = max(0, min(1, slop*x+offset))
|
||||
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
|
||||
class Hsigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hsigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
||||
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(GELU, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
if self.inplace:
|
||||
x.mul_(torch.sigmoid(x))
|
||||
return x
|
||||
else:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Activation(nn.Module):
|
||||
def __init__(self, act_type, inplace=True):
|
||||
super(Activation, self).__init__()
|
||||
act_type = act_type.lower()
|
||||
if act_type == "relu":
|
||||
self.act = nn.ReLU(inplace=inplace)
|
||||
elif act_type == "relu6":
|
||||
self.act = nn.ReLU6(inplace=inplace)
|
||||
elif act_type == "sigmoid":
|
||||
raise NotImplementedError
|
||||
elif act_type == "hard_sigmoid":
|
||||
self.act = Hsigmoid(inplace)
|
||||
elif act_type == "hard_swish":
|
||||
self.act = Hswish(inplace=inplace)
|
||||
elif act_type == "leakyrelu":
|
||||
self.act = nn.LeakyReLU(inplace=inplace)
|
||||
elif act_type == "gelu":
|
||||
self.act = GELU(inplace=inplace)
|
||||
elif act_type == "swish":
|
||||
self.act = Swish(inplace=inplace)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.act(inputs)
|
||||
@@ -0,0 +1,95 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
|
||||
@@ -627,6 +627,7 @@ def main(args):
|
||||
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
|
||||
perceptual_loss = lpips.LPIPS(net="vgg").eval()
|
||||
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
|
||||
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
|
||||
|
||||
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
|
||||
def unwrap_model(model):
|
||||
@@ -951,13 +952,20 @@ def main(args):
|
||||
logits_fake = discriminator(reconstructions)
|
||||
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
|
||||
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
|
||||
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
logs = {
|
||||
"disc_loss": disc_loss.detach().mean().item(),
|
||||
"disc_loss": d_loss.detach().mean().item(),
|
||||
"logits_real": logits_real.detach().mean().item(),
|
||||
"logits_fake": logits_fake.detach().mean().item(),
|
||||
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
accelerator.backward(d_loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = discriminator.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
disc_optimizer.step()
|
||||
disc_lr_scheduler.step()
|
||||
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -381,9 +381,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -54,6 +54,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, cast_training_params
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -475,7 +476,7 @@ def convert_to_np(image, resolution):
|
||||
|
||||
|
||||
def download_image(url):
|
||||
image = PIL.Image.open(requests.get(url, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@@ -164,9 +164,7 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
+2
-1
@@ -59,6 +59,7 @@ from diffusers.schedulers import (
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from diffusers.utils import is_accelerate_available, logging
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -1435,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
|
||||
|
||||
if config_url is not None:
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
|
||||
else:
|
||||
with open(original_config_file, "r") as f:
|
||||
original_config_file = f.read()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Generating images using Flux and PyTorch/XLA
|
||||
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
|
||||
|
||||
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
|
||||
## Create TPU
|
||||
|
||||
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
python3 -c "import torch; import torch_xla;"
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
Clone the diffusers repo and install dependencies
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install transformers accelerate sentencepiece structlog
|
||||
pushd ../../..
|
||||
pip install .
|
||||
popd
|
||||
cd examples/research_projects/pytorch_xla/inference/flux/
|
||||
```
|
||||
|
||||
## Run the inference job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token in order to download Flux weights.
|
||||
**Gated Model**
|
||||
|
||||
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -50,51 +51,116 @@ python flux_inference.py
|
||||
|
||||
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
|
||||
|
||||
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):
|
||||
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
|
||||
|
||||
```bash
|
||||
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
|
||||
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
|
||||
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
|
||||
2025-01-10 00:51:34 [info ] starting compilation run...
|
||||
2025-01-10 00:51:35 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
|
||||
2025-01-10 00:52:53 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:58 [info ] starting inference run...
|
||||
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
|
||||
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
|
||||
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
|
||||
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
|
||||
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
|
||||
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
|
||||
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
|
||||
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
|
||||
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
|
||||
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
|
||||
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
|
||||
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
|
||||
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
|
||||
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
|
||||
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
|
||||
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
|
||||
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
|
||||
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
|
||||
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
|
||||
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
|
||||
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
|
||||
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
|
||||
2025-03-14 21:18:34 [info ] starting compilation run...
|
||||
2025-03-14 21:18:37 [info ] starting compilation run...
|
||||
2025-03-14 21:18:38 [info ] starting compilation run...
|
||||
2025-03-14 21:18:39 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:42 [info ] starting compilation run...
|
||||
2025-03-14 21:18:43 [info ] starting compilation run...
|
||||
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
|
||||
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:40 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
|
||||
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
|
||||
2025-03-14 21:36:50 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
|
||||
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
|
||||
2025-03-14 21:36:56 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
|
||||
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
|
||||
2025-03-14 21:37:08 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
|
||||
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
|
||||
2025-03-14 21:37:22 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
|
||||
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
|
||||
images = (images * 255).round().astype("uint8")
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
|
||||
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
|
||||
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
|
||||
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
|
||||
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
|
||||
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
|
||||
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
|
||||
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
```
|
||||
@@ -9,6 +9,7 @@ import torch_xla.debug.metrics as met
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla.experimental.custom_kernel import FlashAttention
|
||||
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
@@ -36,6 +37,19 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
|
||||
).to(device0)
|
||||
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
|
||||
FlashAttention.DEFAULT_BLOCK_SIZES = {
|
||||
"block_q": 1536,
|
||||
"block_k_major": 1536,
|
||||
"block_k": 1536,
|
||||
"block_b": 1536,
|
||||
"block_q_major_dkv": 1536,
|
||||
"block_k_major_dkv": 1536,
|
||||
"block_q_dkv": 1536,
|
||||
"block_k_dkv": 1536,
|
||||
"block_q_dq": 1536,
|
||||
"block_k_dq": 1536,
|
||||
"block_k_major_dq": 1536,
|
||||
}
|
||||
|
||||
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
|
||||
width = args.width
|
||||
@@ -69,14 +83,14 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
xm.set_rng_state(seed=unique_seed, device=device0)
|
||||
times = []
|
||||
logger.info("starting inference run...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
for _ in range(args.itters):
|
||||
ts = perf_counter()
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
|
||||
if args.profile:
|
||||
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
|
||||
@@ -92,7 +106,7 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
if index == 0:
|
||||
logger.info(f"inference time: {inference_time}")
|
||||
times.append(inference_time)
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
|
||||
image.save(f"/tmp/inference_out-{index}.png")
|
||||
if index == 0:
|
||||
metrics_report = met.metrics_report()
|
||||
|
||||
@@ -141,9 +141,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -53,8 +53,18 @@ args = parser.parse_args()
|
||||
# this is specific to `AdaLayerNormContinuous`:
|
||||
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
"""
|
||||
Swap the scale and shift components in the weight tensor.
|
||||
|
||||
Args:
|
||||
weight (torch.Tensor): The original weight tensor.
|
||||
dim (int): The dimension along which to split.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The modified weight tensor with scale and shift swapped.
|
||||
"""
|
||||
shift, scale = weight.chunk(2, dim=dim)
|
||||
new_weight = torch.cat([scale, shift], dim=dim)
|
||||
return new_weight
|
||||
|
||||
|
||||
@@ -200,6 +210,7 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
|
||||
@@ -25,9 +25,15 @@ import argparse
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
|
||||
from transformers import GlmModel, PreTrainedTokenizerFast
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
CogView4Transformer2DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
|
||||
|
||||
@@ -112,6 +118,12 @@ parser.add_argument(
|
||||
default=128,
|
||||
help="Maximum size for positional embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use control model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
Returns:
|
||||
dict: The converted state dictionary compatible with Diffusers.
|
||||
"""
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
||||
mega = ckpt["model"]
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Patch Embedding
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
|
||||
hidden_size, 128 if args.control else 64
|
||||
)
|
||||
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
|
||||
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
|
||||
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
|
||||
@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
|
||||
# AdaLayerNorm
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
|
||||
)
|
||||
|
||||
# QKV
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
|
||||
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
|
||||
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
|
||||
|
||||
@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
# Attention Output
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.weight"
|
||||
].T
|
||||
]
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.bias"
|
||||
]
|
||||
@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
||||
Returns:
|
||||
dict: The converted VAE state dictionary compatible with Diffusers.
|
||||
"""
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
|
||||
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
||||
|
||||
|
||||
@@ -286,7 +294,7 @@ def main(args):
|
||||
)
|
||||
transformer = CogView4Transformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
in_channels=32 if args.control else 16,
|
||||
num_layers=args.num_layers,
|
||||
attention_head_dim=args.attention_head_dim,
|
||||
num_attention_heads=args.num_heads,
|
||||
@@ -317,6 +325,7 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
@@ -331,7 +340,7 @@ def main(args):
|
||||
# Load the text encoder and tokenizer
|
||||
text_encoder_id = "THUDM/glm-4-9b-hf"
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
|
||||
text_encoder = GlmForCausalLM.from_pretrained(
|
||||
text_encoder = GlmModel.from_pretrained(
|
||||
text_encoder_id,
|
||||
cache_dir=args.text_encoder_cache_dir,
|
||||
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
|
||||
@@ -345,13 +354,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Create the pipeline
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
if args.control:
|
||||
pipe = CogView4ControlPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# Save the converted pipeline
|
||||
pipe.save_pretrained(
|
||||
|
||||
@@ -11,6 +11,7 @@ from diffusion import sampling
|
||||
from torch import nn
|
||||
|
||||
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
MODELS_MAP = {
|
||||
@@ -74,7 +75,7 @@ class DiffusionUncond(nn.Module):
|
||||
|
||||
def download(model_name):
|
||||
url = MODELS_MAP[model_name]["url"]
|
||||
r = requests.get(url, stream=True)
|
||||
r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
|
||||
local_filename = f"./{model_name}.ckpt"
|
||||
with open(local_filename, "wb") as fp:
|
||||
|
||||
@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": None,
|
||||
},
|
||||
"HYVideo-T/2-I2V": {
|
||||
"HYVideo-T/2-I2V-33ch": {
|
||||
"in_channels": 16 * 2 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "latent_concat",
|
||||
},
|
||||
"HYVideo-T/2-I2V-16ch": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "token_replace",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
elif version == "0.9.5":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
return config
|
||||
|
||||
|
||||
@@ -223,7 +294,7 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -277,14 +348,17 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
if args.version == "0.9.5":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -115,7 +115,7 @@ def main(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
||||
|
||||
pipeline = LuminaText2ImgPipeline(
|
||||
pipeline = LuminaPipeline(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path)
|
||||
|
||||
@@ -16,7 +16,9 @@ from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
@@ -25,6 +27,10 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
|
||||
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
|
||||
"Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
@@ -72,15 +78,42 @@ def main(args):
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
# Handle different time embedding structure based on model type
|
||||
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# For Sana Sprint, the time embedding structure is different
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Guidance embedder for Sana Sprint
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
|
||||
else:
|
||||
# Original Sana time embedding structure
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
@@ -96,14 +129,22 @@ def main(args):
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
|
||||
layer_num = 28
|
||||
elif args.model_type == "SanaMS_4800M_P1_D60":
|
||||
layer_num = 60
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
|
||||
qk_norm = (
|
||||
"rms_norm_across_heads"
|
||||
if args.model_type
|
||||
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
|
||||
else None
|
||||
)
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -117,6 +158,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
@@ -154,6 +203,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
@@ -169,24 +226,37 @@ def main(args):
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer = SanaTransformer2DModel(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
|
||||
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
|
||||
num_layers=model_kwargs[args.model_type]["num_layers"],
|
||||
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
caption_channels=2304,
|
||||
mlp_ratio=2.5,
|
||||
attention_bias=False,
|
||||
sample_size=args.image_size // 32,
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
transformer_kwargs = {
|
||||
"in_channels": 32,
|
||||
"out_channels": 32,
|
||||
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
|
||||
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
|
||||
"num_layers": model_kwargs[args.model_type]["num_layers"],
|
||||
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 2.5,
|
||||
"attention_bias": False,
|
||||
"sample_size": args.image_size // 32,
|
||||
"patch_size": 1,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"interpolation_scale": interpolation_scale[args.image_size],
|
||||
}
|
||||
|
||||
# Add qk_norm parameter for Sana Sprint
|
||||
if args.model_type in [
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
]:
|
||||
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
transformer_kwargs["guidance_embeds"] = True
|
||||
|
||||
transformer = SanaTransformer2DModel(**transformer_kwargs)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(transformer, converted_state_dict)
|
||||
@@ -196,6 +266,8 @@ def main(args):
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
@@ -210,47 +282,74 @@ def main(args):
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole SanaPipeline",
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "google/gemma-2-2b-it"
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
|
||||
if args.scheduler_type != "scm":
|
||||
print(
|
||||
colored(
|
||||
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
# SCM Scheduler for Sana Sprint
|
||||
scheduler_config = {
|
||||
"prediction_type": "trigflow",
|
||||
"sigma_data": 0.5,
|
||||
}
|
||||
scheduler = SCMScheduler(**scheduler_config)
|
||||
pipe = SanaSprintPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
@@ -259,12 +358,6 @@ DTYPE_MAPPING = {
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -281,10 +374,24 @@ if __name__ == "__main__":
|
||||
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
"--model_type",
|
||||
default="SanaMS_1600M_P1_D20",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaMS_1600M_P1_D20",
|
||||
"SanaMS_600M_P1_D28",
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "scm"],
|
||||
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
|
||||
@@ -309,10 +416,41 @@ if __name__ == "__main__":
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaMS1.5_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
"SanaMS1.5_4800M_P1_D60": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 60,
|
||||
},
|
||||
"SanaSprint_600M_P1_D28": {
|
||||
"num_attention_heads": 36,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 16,
|
||||
"cross_attention_head_dim": 72,
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaSprint_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -13,6 +13,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
renew_vae_attention_paths,
|
||||
renew_vae_resnet_paths,
|
||||
)
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
@@ -122,7 +123,8 @@ def vae_pt_to_vae_diffuser(
|
||||
):
|
||||
# Only support V1
|
||||
r = requests.get(
|
||||
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
|
||||
timeout=DIFFUSERS_REQUEST_TIMEOUT,
|
||||
)
|
||||
io_obj = io.BytesIO(r.content)
|
||||
|
||||
|
||||
@@ -128,6 +128,10 @@ _deps = [
|
||||
"GitPython<3.1.19",
|
||||
"scipy",
|
||||
"onnx",
|
||||
"optimum_quanto>=0.2.6",
|
||||
"gguf>=0.10.0",
|
||||
"torchao>=0.7.0",
|
||||
"bitsandbytes>=0.43.3",
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
@@ -235,6 +239,11 @@ extras["test"] = deps_list(
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
|
||||
extras["gguf"] = deps_list("gguf", "accelerate")
|
||||
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
|
||||
extras["torchao"] = deps_list("torchao", "accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
else:
|
||||
|
||||
+109
-3
@@ -6,14 +6,19 @@ from .utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_optimum_quanto_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -32,7 +37,7 @@ _import_structure = {
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
|
||||
"quantizers.quantization_config": [],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
@@ -54,6 +59,54 @@ _import_structure = {
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_bitsandbytes_objects
|
||||
|
||||
_import_structure["utils.dummy_bitsandbytes_objects"] = [
|
||||
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_gguf_objects
|
||||
|
||||
_import_structure["utils.dummy_gguf_objects"] = [
|
||||
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torchao_objects
|
||||
|
||||
_import_structure["utils.dummy_torchao_objects"] = [
|
||||
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_optimum_quanto_objects
|
||||
|
||||
_import_structure["utils.dummy_optimum_quanto_objects"] = [
|
||||
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -78,8 +131,10 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -218,6 +273,7 @@ else:
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"SCMScheduler",
|
||||
"ScoreSdeVeScheduler",
|
||||
"TCDScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -292,6 +348,7 @@ else:
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
@@ -348,9 +405,12 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldIntrinsicsPipeline",
|
||||
@@ -366,6 +426,7 @@ else:
|
||||
"ReduxImageEncoder",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -448,6 +509,7 @@ else:
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVideoToVideoPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
@@ -599,7 +661,38 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_bitsandbytes_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig
|
||||
|
||||
try:
|
||||
if not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_gguf_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import GGUFQuantizationConfig
|
||||
|
||||
try:
|
||||
if not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torchao_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_optimum_quanto_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import QuantoConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
@@ -615,7 +708,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
HookRegistry,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -748,6 +847,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
SCMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
TCDScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -803,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
@@ -859,9 +960,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
@@ -877,6 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
@@ -958,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVideoToVideoPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
|
||||
@@ -35,6 +35,7 @@ from huggingface_hub.utils import (
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from requests import HTTPError
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import __version__
|
||||
from .utils import (
|
||||
@@ -185,7 +186,9 @@ class ConfigMixin:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||
def from_config(
|
||||
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
|
||||
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary.
|
||||
|
||||
|
||||
@@ -35,6 +35,10 @@ deps = {
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"scipy": "scipy",
|
||||
"onnx": "onnx",
|
||||
"optimum_quanto": "optimum_quanto>=0.2.6",
|
||||
"gguf": "gguf>=0.10.0",
|
||||
"torchao": "torchao>=0.7.0",
|
||||
"bitsandbytes": "bitsandbytes>=0.43.3",
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
|
||||
@@ -2,6 +2,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..models.modeling_outputs import Transformer2DModelOutput
|
||||
from ..utils import logging
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
||||
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
||||
"^blocks.*attn",
|
||||
"^transformer_blocks.*attn",
|
||||
"^single_transformer_blocks.*attn",
|
||||
)
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
|
||||
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
|
||||
"hidden_states",
|
||||
"encoder_hidden_states",
|
||||
"timestep",
|
||||
"attention_mask",
|
||||
"encoder_attention_mask",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FasterCacheConfig:
|
||||
r"""
|
||||
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
|
||||
|
||||
Attributes:
|
||||
spatial_attention_block_skip_range (`int`, defaults to `2`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
||||
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
||||
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
||||
process.
|
||||
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
||||
The timestep range within which the temporal attention computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0).
|
||||
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
||||
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
alpha_low_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
|
||||
the conditional branch outputs.
|
||||
alpha_high_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
|
||||
from the conditional branch outputs.
|
||||
unconditional_batch_skip_range (`int`, defaults to `5`):
|
||||
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
||||
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
|
||||
computing the new unconditional branch states again.
|
||||
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
The timestep range within which the unconditional branch computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the attention outputs by. This function should take
|
||||
the attention module as input and return a float value. This is used to approximate the unconditional
|
||||
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
|
||||
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
|
||||
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
|
||||
the number of inference steps and underlying model behaviour as denoising progresses.
|
||||
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
tensor_format (`str`, defaults to `"BCFHW"`):
|
||||
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
|
||||
used to split individual latent frames in order for low and high frequency components to be computed.
|
||||
is_guidance_distilled (`bool`, defaults to `False`):
|
||||
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
||||
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
||||
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
||||
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
||||
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
||||
names that contain the batchwise-concatenated unconditional and conditional inputs.
|
||||
"""
|
||||
|
||||
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
|
||||
# after some testing. We default to 2 if these parameters are not provided.
|
||||
spatial_attention_block_skip_range: int = 2
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
|
||||
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
||||
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
||||
|
||||
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
||||
alpha_low_frequency: float = 1.1
|
||||
alpha_high_frequency: float = 1.1
|
||||
|
||||
# n as described in CFG-Cache explanation in the paper - dependant on the model
|
||||
unconditional_batch_skip_range: int = 5
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
|
||||
tensor_format: str = "BCFHW"
|
||||
is_guidance_distilled: bool = False
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"FasterCacheConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
||||
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
||||
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
|
||||
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
|
||||
f" alpha_low_frequency={self.alpha_low_frequency},\n"
|
||||
f" alpha_high_frequency={self.alpha_high_frequency},\n"
|
||||
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
|
||||
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
|
||||
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
||||
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
||||
f" tensor_format={self.tensor_format},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class FasterCacheDenoiserState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.low_frequency_delta: torch.Tensor = None
|
||||
self.high_frequency_delta: torch.Tensor = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.low_frequency_delta = None
|
||||
self.high_frequency_delta = None
|
||||
|
||||
|
||||
class FasterCacheBlockState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
|
||||
applied to will have an instance of this state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.batch_size: int = None
|
||||
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.batch_size = None
|
||||
self.cache = None
|
||||
|
||||
|
||||
class FasterCacheDenoiserHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unconditional_batch_skip_range: int,
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
||||
tensor_format: str,
|
||||
is_guidance_distilled: bool,
|
||||
uncond_cond_input_kwargs_identifiers: List[str],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.unconditional_batch_skip_range = unconditional_batch_skip_range
|
||||
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
|
||||
# We can't easily detect what args are to be split in unconditional and conditional branches. We
|
||||
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
|
||||
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
|
||||
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
|
||||
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
|
||||
self.tensor_format = tensor_format
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
self.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
self.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheDenoiserState()
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
||||
# followed by conditional inputs.
|
||||
_, cond = input.chunk(2, dim=0)
|
||||
return cond
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
|
||||
# requirements for skipping the unconditional branch are met as described in the paper.
|
||||
# We skip the unconditional branch only if the following conditions are met:
|
||||
# 1. We have completed at least one iteration of the denoiser
|
||||
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
|
||||
# where approximating the unconditional branch from the computation of the conditional branch is possible
|
||||
# without a significant loss in quality.
|
||||
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
|
||||
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
|
||||
is_within_timestep_range = (
|
||||
self.unconditional_batch_timestep_skip_range[0]
|
||||
< self.current_timestep_callback()
|
||||
< self.unconditional_batch_timestep_skip_range[1]
|
||||
)
|
||||
should_skip_uncond = (
|
||||
self.state.iteration > 0
|
||||
and is_within_timestep_range
|
||||
and self.state.iteration % self.unconditional_batch_skip_range != 0
|
||||
and not self.is_guidance_distilled
|
||||
)
|
||||
|
||||
if should_skip_uncond:
|
||||
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
|
||||
if is_any_kwarg_uncond:
|
||||
logger.debug("FasterCache - Skipping unconditional branch computation")
|
||||
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
|
||||
kwargs = {
|
||||
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_guidance_distilled:
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
if torch.is_tensor(output):
|
||||
hidden_states = output
|
||||
elif isinstance(output, (tuple, Transformer2DModelOutput)):
|
||||
hidden_states = output[0]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if should_skip_uncond:
|
||||
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
|
||||
if self.tensor_format == "BCFHW":
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
|
||||
|
||||
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
|
||||
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
|
||||
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
|
||||
uncond_freq = low_freq_uncond + high_freq_uncond
|
||||
|
||||
uncond_states = torch.fft.ifftshift(uncond_freq)
|
||||
uncond_states = torch.fft.ifft2(uncond_states).real
|
||||
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# Concatenate the approximated unconditional and predicted conditional branches
|
||||
uncond_states = uncond_states.to(hidden_states.dtype)
|
||||
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
|
||||
else:
|
||||
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
cond_states = cond_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.flatten(0, 1)
|
||||
cond_states = cond_states.flatten(0, 1)
|
||||
|
||||
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
|
||||
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
|
||||
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
|
||||
|
||||
self.state.iteration += 1
|
||||
if torch.is_tensor(output):
|
||||
output = hidden_states
|
||||
elif isinstance(output, tuple):
|
||||
output = (hidden_states, *output[1:])
|
||||
else:
|
||||
output.sample = hidden_states
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
class FasterCacheBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
is_guidance_distilled: bool,
|
||||
weight_callback: Callable[[torch.nn.Module], float],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.block_skip_range = block_skip_range
|
||||
self.timestep_skip_range = timestep_skip_range
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.weight_callback = weight_callback
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheBlockState()
|
||||
return module
|
||||
|
||||
def _compute_approximated_attention_output(
|
||||
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
|
||||
) -> torch.Tensor:
|
||||
if t_2_output.size(0) != batch_size:
|
||||
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_2_output.size(0) == 2 * batch_size
|
||||
t_2_output = t_2_output[batch_size:]
|
||||
if t_output.size(0) != batch_size:
|
||||
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_output.size(0) == 2 * batch_size
|
||||
t_output = t_output[batch_size:]
|
||||
return t_output + (t_output - t_2_output) * weight
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
batch_size = [
|
||||
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
|
||||
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
|
||||
][0]
|
||||
if self.state.batch_size is None:
|
||||
# Will be updated on first forward pass through the denoiser
|
||||
self.state.batch_size = batch_size
|
||||
|
||||
# If we have to skip due to the skip conditions, then let's skip as expected.
|
||||
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
|
||||
# is because the expected output shapes of attention layer will not match if we only return values from
|
||||
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
|
||||
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
|
||||
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
|
||||
is_within_timestep_range = (
|
||||
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
||||
)
|
||||
if not is_within_timestep_range:
|
||||
should_skip_attention = False
|
||||
else:
|
||||
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
|
||||
should_skip_attention = not should_compute_attention
|
||||
if should_skip_attention:
|
||||
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
|
||||
|
||||
if should_skip_attention:
|
||||
logger.debug("FasterCache - Skipping attention and using approximation")
|
||||
if torch.is_tensor(self.state.cache[-1]):
|
||||
t_2_output, t_output = self.state.cache
|
||||
weight = self.weight_callback(module)
|
||||
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
|
||||
else:
|
||||
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
|
||||
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
|
||||
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
|
||||
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
|
||||
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
|
||||
# allows us to compute the approximated attention output for each tensor in the cache.
|
||||
output = ()
|
||||
for t_2_output, t_output in zip(*self.state.cache):
|
||||
result = self._compute_approximated_attention_output(
|
||||
t_2_output, t_output, self.weight_callback(module), batch_size
|
||||
)
|
||||
output += (result,)
|
||||
else:
|
||||
logger.debug("FasterCache - Computing attention")
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
|
||||
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
|
||||
# both cases.
|
||||
if torch.is_tensor(output):
|
||||
cache_output = output
|
||||
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
|
||||
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
|
||||
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
|
||||
cache_output = cache_output.chunk(2, dim=0)[1]
|
||||
else:
|
||||
# Cache all return values and perform the same operation as above
|
||||
cache_output = ()
|
||||
for out in output:
|
||||
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
|
||||
out = out.chunk(2, dim=0)[1]
|
||||
cache_output += (out,)
|
||||
|
||||
if self.state.cache is None:
|
||||
self.state.cache = [cache_output, cache_output]
|
||||
else:
|
||||
self.state.cache = [self.state.cache[-1], cache_output]
|
||||
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
r"""
|
||||
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
||||
|
||||
Args:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
The diffusion pipeline to apply FasterCache to.
|
||||
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
|
||||
The configuration to use for FasterCache.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = FasterCacheConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(-1, 681),
|
||||
... low_frequency_weight_update_timestep_range=(99, 641),
|
||||
... high_frequency_weight_update_timestep_range=(-1, 301),
|
||||
... spatial_attention_block_identifiers=["transformer_blocks"],
|
||||
... attention_weight_callback=lambda _: 0.3,
|
||||
... tensor_format="BFCHW",
|
||||
... )
|
||||
>>> apply_faster_cache(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
|
||||
logger.warning(
|
||||
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
|
||||
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
|
||||
"https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
if config.attention_weight_callback is None:
|
||||
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
|
||||
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
|
||||
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
|
||||
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
|
||||
logger.warning(
|
||||
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
|
||||
)
|
||||
config.attention_weight_callback = lambda _: 0.5
|
||||
|
||||
if config.low_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.low_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.low_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_low_frequency if is_within_range else 1.0
|
||||
|
||||
config.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
|
||||
if config.high_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.high_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.high_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_high_frequency if is_within_range else 1.0
|
||||
|
||||
config.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
|
||||
if config.tensor_format not in supported_tensor_formats:
|
||||
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
|
||||
|
||||
_apply_faster_cache_on_denoiser(module, config)
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES):
|
||||
continue
|
||||
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
|
||||
_apply_faster_cache_on_attention_class(name, submodule, config)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
hook = FasterCacheDenoiserHook(
|
||||
config.unconditional_batch_skip_range,
|
||||
config.unconditional_batch_timestep_skip_range,
|
||||
config.tensor_format,
|
||||
config.is_guidance_distilled,
|
||||
config._unconditional_conditional_input_kwargs_identifiers,
|
||||
config.current_timestep_callback,
|
||||
config.low_frequency_weight_callback,
|
||||
config.high_frequency_weight_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
|
||||
is_spatial_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range, block_type = None, None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
block_type = "spatial"
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
block_type = "temporal"
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.debug(
|
||||
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
|
||||
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
|
||||
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
||||
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
|
||||
f"function to apply FasterCache to this layer."
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
|
||||
hook = FasterCacheBlockHook(
|
||||
block_skip_range,
|
||||
timestep_skip_range,
|
||||
config.is_guidance_distilled,
|
||||
config.attention_weight_callback,
|
||||
config.current_timestep_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
|
||||
|
||||
|
||||
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
|
||||
@torch.no_grad()
|
||||
def _split_low_high_freq(x):
|
||||
fft = torch.fft.fft2(x)
|
||||
fft_shifted = torch.fft.fftshift(fft)
|
||||
height, width = x.shape[-2:]
|
||||
radius = min(height, width) // 5
|
||||
|
||||
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
|
||||
center_x, center_y = width // 2, height // 2
|
||||
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
|
||||
|
||||
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
|
||||
high_freq_mask = ~low_freq_mask
|
||||
|
||||
low_freq_fft = fft_shifted * low_freq_mask
|
||||
high_freq_fft = fft_shifted * high_freq_mask
|
||||
|
||||
return low_freq_fft, high_freq_fft
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
@@ -56,7 +56,8 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage=False,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -64,49 +65,121 @@ class ModuleGroup:
|
||||
self.onload_device = onload_device
|
||||
self.offload_leader = offload_leader
|
||||
self.onload_leader = onload_leader
|
||||
self.parameters = parameters
|
||||
self.buffers = buffers
|
||||
self.parameters = parameters or []
|
||||
self.buffers = buffers or []
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.record_stream = record_stream
|
||||
self.onload_self = onload_self
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
if self.stream is None and self.record_stream:
|
||||
raise ValueError("`record_stream` cannot be True when `stream` is None.")
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
return cpu_param_dict
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@contextmanager
|
||||
def _pinned_memory_tensors(self):
|
||||
pinned_dict = {}
|
||||
try:
|
||||
for param, tensor in self.cpu_param_dict.items():
|
||||
if not tensor.is_pinned():
|
||||
pinned_dict[param] = tensor.pin_memory()
|
||||
else:
|
||||
pinned_dict[param] = tensor
|
||||
|
||||
yield pinned_dict
|
||||
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
current_stream = torch.cuda.current_stream() if self.record_stream else None
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
param.data.record_stream(current_stream)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
param.data.record_stream(current_stream)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
if not self.record_stream:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
@@ -172,6 +245,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
||||
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
||||
# layers are executed during the forward pass.
|
||||
@@ -183,14 +263,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
||||
|
||||
if group_offloading_hook is not None:
|
||||
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# For the first forward pass, we have to load in a blocking manner
|
||||
group_offloading_hook.group.non_blocking = False
|
||||
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
||||
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
||||
self._layer_execution_tracker_module_names.add(name)
|
||||
@@ -220,6 +294,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the layer execution tracker hooks from the submodules
|
||||
base_module_registry = module._diffusers_hook
|
||||
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
|
||||
for i in range(num_executed):
|
||||
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
||||
@@ -227,8 +302,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
||||
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
||||
|
||||
# Apply lazy prefetching by setting required attributes
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
||||
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
||||
# see the benefits of prefetching.
|
||||
for hook in group_offloading_hooks:
|
||||
hook.group.non_blocking = True
|
||||
|
||||
# Set required attributes for prefetching
|
||||
if num_executed > 0:
|
||||
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
||||
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
||||
@@ -268,6 +348,8 @@ def apply_group_offloading(
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
record_stream: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
@@ -314,6 +396,14 @@ def apply_group_offloading(
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -349,10 +439,25 @@ def apply_group_offloading(
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module=module,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
_apply_group_offloading_leaf_level(
|
||||
module=module,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
|
||||
@@ -364,6 +469,8 @@ def _apply_group_offloading_block_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
@@ -382,15 +489,16 @@ def _apply_group_offloading_block_level(
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
unmatched_modules = []
|
||||
@@ -411,7 +519,8 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +557,7 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
record_stream=False,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -461,6 +570,8 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
@@ -481,15 +592,16 @@ def _apply_group_offloading_leaf_level(
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -503,7 +615,8 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +661,8 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +681,8 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
record_stream=False,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PyramidAttentionBroadcastConfig("
|
||||
f"PyramidAttentionBroadcastConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
|
||||
@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
|
||||
return module
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
module: torch.nn.Module,
|
||||
config: PyramidAttentionBroadcastConfig,
|
||||
):
|
||||
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
||||
|
||||
@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
||||
registry.register_hook(hook, "pyramid_attention_broadcast")
|
||||
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
|
||||
|
||||
@@ -70,6 +70,7 @@ if is_torch_available():
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"CogView4LoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
@@ -103,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
|
||||
@@ -804,9 +804,7 @@ class SD3IPAdapterMixin:
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
|
||||
@@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -339,93 +340,101 @@ def _load_lora_into_text_encoder(
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
@@ -904,3 +913,23 @@ class LoraBaseMixin:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
|
||||
@@ -13,15 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_peft_version, logging
|
||||
from ..utils import is_peft_version, logging, state_dict_all_zero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
||||
# 1. get all state_dict_keys
|
||||
all_keys = list(state_dict.keys())
|
||||
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
|
||||
return diffusers_name
|
||||
|
||||
|
||||
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
||||
|
||||
|
||||
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
||||
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
# All credits go to `kohya-ss`.
|
||||
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
|
||||
# scale weight by alpha and dim
|
||||
rank = down_weight.shape[0]
|
||||
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
|
||||
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
|
||||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
|
||||
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
|
||||
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
|
||||
# scale weight by alpha and dim
|
||||
alpha = sds_sd.pop(sds_key + ".alpha")
|
||||
default_alpha = torch.tensor(
|
||||
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
|
||||
)
|
||||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
# calculate scale_down and scale_up
|
||||
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
||||
)
|
||||
|
||||
# TODO: alphas.
|
||||
def assign_remaining_weights(assignments, source):
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
|
||||
for target_fmt, source_fmt, transform in assignments:
|
||||
target_key = target_fmt.format(lora_key=lora_key)
|
||||
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
|
||||
value = source.pop(source_key)
|
||||
if transform:
|
||||
value = transform(value)
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
# Notice the swap in processing for "final_layer".
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"norm_out.linear.{lora_key}.weight",
|
||||
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
|
||||
swap_scale_shift,
|
||||
),
|
||||
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
|
||||
],
|
||||
sds_sd,
|
||||
)
|
||||
|
||||
remaining_keys = list(sds_sd.keys())
|
||||
te_state_dict = {}
|
||||
if remaining_keys:
|
||||
if not all(k.startswith("lora_te") for k in remaining_keys):
|
||||
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
|
||||
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
||||
for key in remaining_keys:
|
||||
if not key.endswith("lora_down.weight"):
|
||||
@@ -680,10 +784,98 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
if has_peft_state_dict:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
||||
return state_dict
|
||||
|
||||
# Another weird one.
|
||||
has_mixture = any(
|
||||
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
|
||||
)
|
||||
|
||||
# ComfyUI.
|
||||
if not has_mixture:
|
||||
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
|
||||
|
||||
has_position_embedding = any("position_embedding" in k for k in state_dict)
|
||||
if has_position_embedding:
|
||||
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
|
||||
if zero_status_pe:
|
||||
logger.info(
|
||||
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
|
||||
"So, we will purge them out of the curret state dict to make loading possible."
|
||||
)
|
||||
|
||||
else:
|
||||
logger.info(
|
||||
"The state_dict has position_embedding LoRA params and we currently do not support them. "
|
||||
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
|
||||
|
||||
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
|
||||
if has_t5xxl:
|
||||
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
|
||||
if zero_status_t5:
|
||||
logger.info(
|
||||
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
|
||||
"So, we will purge them out of the curret state dict to make loading possible."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
|
||||
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
||||
if has_diffb:
|
||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||
if zero_status_diff_b:
|
||||
logger.info(
|
||||
"The `diff_b` LoRA params are all zeros which make them ineffective. "
|
||||
"So, we will purge them out of the curret state dict to make loading possible."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"`diff_b` keys found in the state dict which are currently unsupported. "
|
||||
"So, we will filter out those keys. Open an issue if this is a problem - "
|
||||
"https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
|
||||
|
||||
has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
|
||||
if has_norm_diff:
|
||||
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
|
||||
if zero_status_diff:
|
||||
logger.info(
|
||||
"The `diff` LoRA params are all zeros which make them ineffective. "
|
||||
"So, we will purge them out of the curret state dict to make loading possible."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Normalization diff keys found in the state dict which are currently unsupported. "
|
||||
"So, we will filter out those keys. Open an issue if this is a problem - "
|
||||
"https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
|
||||
|
||||
limit_substrings = ["lora_down", "lora_up"]
|
||||
if any("alpha" in k for k in state_dict):
|
||||
limit_substrings.append("alpha")
|
||||
|
||||
state_dict = {
|
||||
_custom_replace(k, limit_substrings): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith(("lora_unet_", "lora_te_"))
|
||||
}
|
||||
|
||||
if any("text_projection" in k for k in state_dict):
|
||||
logger.info(
|
||||
"`text_projection` keys found in the `state_dict` which are unexpected. "
|
||||
"So, we will filter out those keys. Open an issue if this is a problem - "
|
||||
"https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
|
||||
|
||||
if has_mixture:
|
||||
return _convert_mixture_state_dict_to_diffusers(state_dict)
|
||||
|
||||
@@ -798,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _custom_replace(key: str, substrings: List[str]) -> str:
|
||||
# Replaces the "."s with "_"s upto the `substrings`.
|
||||
# Example:
|
||||
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
|
||||
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
|
||||
|
||||
match = re.search(pattern, key)
|
||||
if match:
|
||||
start_sub = match.start()
|
||||
if start_sub > 0 and key[start_sub - 1] == ".":
|
||||
boundary = start_sub - 1
|
||||
else:
|
||||
boundary = start_sub
|
||||
left = key[:boundary].replace(".", "_")
|
||||
right = key[boundary:]
|
||||
return left + right
|
||||
else:
|
||||
return key.replace(".", "_")
|
||||
|
||||
|
||||
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict_keys = list(original_state_dict.keys())
|
||||
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
inner_dim = 3072
|
||||
mlp_ratio = 4.0
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict[
|
||||
@@ -1348,3 +1555,56 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
|
||||
for i in range(num_blocks):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
+1162
-189
File diff suppressed because it is too large
Load Diff
+152
-64
@@ -16,7 +16,7 @@ import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -54,26 +54,15 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
def _maybe_adjust_config(config):
|
||||
"""
|
||||
We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
|
||||
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
|
||||
method removes the ambiguity by following what is described here:
|
||||
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
|
||||
"""
|
||||
# Track keys that have been explicitly removed to prevent re-adding them.
|
||||
deleted_keys = set()
|
||||
|
||||
def _maybe_raise_error_for_ambiguity(config):
|
||||
rank_pattern = config["rank_pattern"].copy()
|
||||
target_modules = config["target_modules"]
|
||||
original_r = config["r"]
|
||||
|
||||
for key in list(rank_pattern.keys()):
|
||||
key_rank = rank_pattern[key]
|
||||
|
||||
# try to detect ambiguity
|
||||
# `target_modules` can also be a str, in which case this loop would loop
|
||||
# over the chars of the str. The technically correct way to match LoRA keys
|
||||
@@ -81,35 +70,12 @@ def _maybe_adjust_config(config):
|
||||
# But this cuts it for now.
|
||||
exact_matches = [mod for mod in target_modules if mod == key]
|
||||
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
|
||||
ambiguous_key = key
|
||||
|
||||
if exact_matches and substring_matches:
|
||||
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
|
||||
config["r"] = key_rank
|
||||
# remove the ambiguous key from `rank_pattern` and record it as deleted
|
||||
del config["rank_pattern"][key]
|
||||
deleted_keys.add(key)
|
||||
# For substring matches, add them with the original rank only if they haven't been assigned already
|
||||
for mod in substring_matches:
|
||||
if mod not in config["rank_pattern"] and mod not in deleted_keys:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# Update the rest of the target modules with the original rank if not already set and not deleted
|
||||
for mod in target_modules:
|
||||
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# Handle alphas to deal with cases like:
|
||||
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
|
||||
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
|
||||
if has_different_ranks:
|
||||
config["lora_alpha"] = config["r"]
|
||||
alpha_pattern = {}
|
||||
for module_name, rank in config["rank_pattern"].items():
|
||||
alpha_pattern[module_name] = rank
|
||||
config["alpha_pattern"] = alpha_pattern
|
||||
|
||||
return config
|
||||
if is_peft_version("<", "0.14.1"):
|
||||
raise ValueError(
|
||||
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
|
||||
)
|
||||
|
||||
|
||||
class PeftAdapterMixin:
|
||||
@@ -127,6 +93,8 @@ class PeftAdapterMixin:
|
||||
"""
|
||||
|
||||
_hf_peft_config_loaded = False
|
||||
# kwargs for prepare_model_for_compiled_hotswap, if required
|
||||
_prepare_lora_hotswap_kwargs: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
@@ -144,7 +112,9 @@ class PeftAdapterMixin:
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
def load_lora_adapter(
|
||||
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
|
||||
):
|
||||
r"""
|
||||
Loads a LoRA adapter into the underlying model.
|
||||
|
||||
@@ -188,15 +158,33 @@ class PeftAdapterMixin:
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
try:
|
||||
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
|
||||
except ImportError:
|
||||
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -240,16 +228,18 @@ class PeftAdapterMixin:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
|
||||
if prefix is not None:
|
||||
keys = list(state_dict.keys())
|
||||
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
|
||||
if len(model_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
|
||||
)
|
||||
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
|
||||
raise ValueError(
|
||||
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
|
||||
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
|
||||
)
|
||||
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
@@ -261,22 +251,18 @@ class PeftAdapterMixin:
|
||||
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
|
||||
# Bias layers in LoRA only have a single dimension
|
||||
if "lora_B" in key and val.ndim > 1:
|
||||
# Support to handle cases where layer patterns are treated as full layer names
|
||||
# was added later in PEFT. So, we handle it accordingly.
|
||||
# TODO: when we fix the minimal PEFT version for Diffusers,
|
||||
# we should remove `_maybe_adjust_config()`.
|
||||
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
|
||||
else:
|
||||
rank[key] = val.shape[1]
|
||||
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
|
||||
# We may run into some ambiguous configuration values when a model has module
|
||||
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
|
||||
# for example) and they have different LoRA ranks.
|
||||
rank[f"^{key}"] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
@@ -315,11 +301,71 @@ class PeftAdapterMixin:
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
|
||||
if is_peft_version(">", "0.14.0"):
|
||||
from peft.utils.hotswap import (
|
||||
check_hotswap_configs_compatible,
|
||||
hotswap_adapter_from_state_dict,
|
||||
prepare_model_for_compiled_hotswap,
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
|
||||
"from source."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
if hotswap:
|
||||
|
||||
def map_state_dict_for_hotswap(sd):
|
||||
# For hotswapping, we need the adapter name to be present in the state dict keys
|
||||
new_sd = {}
|
||||
for k, v in sd.items():
|
||||
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
|
||||
k = k[: -len(".weight")] + f".{adapter_name}.weight"
|
||||
elif k.endswith("lora_B.bias"): # lora_bias=True option
|
||||
k = k[: -len(".bias")] + f".{adapter_name}.bias"
|
||||
new_sd[k] = v
|
||||
return new_sd
|
||||
|
||||
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
|
||||
# we should also delete the `peft_config` associated to the `adapter_name`.
|
||||
try:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
if hotswap:
|
||||
state_dict = map_state_dict_for_hotswap(state_dict)
|
||||
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
|
||||
try:
|
||||
hotswap_adapter_from_state_dict(
|
||||
model=self,
|
||||
state_dict=state_dict,
|
||||
adapter_name=adapter_name,
|
||||
config=lora_config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
|
||||
raise
|
||||
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
|
||||
# it to None
|
||||
incompatible_keys = None
|
||||
else:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
if self._prepare_lora_hotswap_kwargs is not None:
|
||||
# For hotswapping of compiled models or adapters with different ranks.
|
||||
# If the user called enable_lora_hotswap, we need to ensure it is called:
|
||||
# - after the first adapter was loaded
|
||||
# - before the model is compiled and the 2nd adapter is being hotswapped in
|
||||
# Therefore, it needs to be called here
|
||||
prepare_model_for_compiled_hotswap(
|
||||
self, config=lora_config, **self._prepare_lora_hotswap_kwargs
|
||||
)
|
||||
# We only want to call prepare_model_for_compiled_hotswap once
|
||||
self._prepare_lora_hotswap_kwargs = None
|
||||
|
||||
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
|
||||
if not self._hf_peft_config_loaded:
|
||||
self._hf_peft_config_loaded = True
|
||||
except Exception as e:
|
||||
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
|
||||
if hasattr(self, "peft_config"):
|
||||
@@ -366,6 +412,15 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
def save_lora_adapter(
|
||||
self,
|
||||
save_directory,
|
||||
@@ -770,3 +825,36 @@ class PeftAdapterMixin:
|
||||
# Pop also the corresponding adapter from the config
|
||||
if hasattr(self, "peft_config"):
|
||||
self.peft_config.pop(adapter_name, None)
|
||||
|
||||
def enable_lora_hotswap(
|
||||
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
|
||||
) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`, *optional*, defaults to `128`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
if getattr(self, "peft_config", {}):
|
||||
if check_compiled == "error":
|
||||
raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
elif check_compiled == "warn":
|
||||
logger.warning(
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
elif check_compiled != "ignore":
|
||||
raise ValueError(
|
||||
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
|
||||
)
|
||||
|
||||
self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
|
||||
|
||||
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
|
||||
@@ -37,6 +37,7 @@ from .single_file_utils import (
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
@@ -119,6 +120,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
@@ -250,12 +255,12 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
@@ -277,6 +282,7 @@ class FromOriginalModelMixin:
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
@@ -44,6 +44,7 @@ from ..utils import (
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
|
||||
|
||||
@@ -117,6 +118,12 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"sana": [
|
||||
"blocks.0.cross_attn.q_linear.weight",
|
||||
"blocks.0.cross_attn.q_linear.bias",
|
||||
"blocks.0.cross_attn.kv_linear.weight",
|
||||
"blocks.0.cross_attn.kv_linear.bias",
|
||||
],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
}
|
||||
@@ -178,6 +185,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
|
||||
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
|
||||
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
@@ -436,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
|
||||
"Please provide a valid local file path."
|
||||
)
|
||||
|
||||
original_config_file = BytesIO(requests.get(original_config_file).content)
|
||||
original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
|
||||
@@ -669,6 +677,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
|
||||
model_type = "lumina2"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
|
||||
model_type = "sana"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
|
||||
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
|
||||
target_key = "model.diffusion_model.patch_embedding.weight"
|
||||
@@ -2396,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
@@ -2897,6 +2907,111 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
|
||||
|
||||
# Positional and patch embeddings.
|
||||
checkpoint.pop("pos_embed")
|
||||
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Timestep embeddings.
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
|
||||
|
||||
# Caption Projection.
|
||||
checkpoint.pop("y_embedder.y_embedding")
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
|
||||
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
|
||||
|
||||
for i in range(num_layers):
|
||||
converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
|
||||
f"blocks.{i}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Self-Attention
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Cross-Attention
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.bias"
|
||||
)
|
||||
|
||||
linear_sample_k, linear_sample_v = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
|
||||
)
|
||||
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# MLP
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
|
||||
# Final layer
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user