Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32476f9717 | |||
| a69dd6fdf0 | |||
| e539cd32c4 | |||
| 611d37549f | |||
| 739d6ec731 | |||
| 12eeb252d5 | |||
| 1ddf3f3a19 | |||
| 7aac77affa | |||
| 8907a70a36 | |||
| 5dbe4f5de6 | |||
| 1d37f42055 | |||
| 0213179ba8 | |||
| a7d53a5939 | |||
| 8a63aa5e4f | |||
| 844221ae4e | |||
| 9b2c0a7dbe | |||
| f424b1b062 | |||
| e9fda3924f | |||
| 2c1ed50fc5 | |||
| 15ad97f782 | |||
| 9f2d5c9ee9 | |||
| dc62e6931e | |||
| 56f740051d | |||
| a34d97cef0 | |||
| fc28791fc8 | |||
| ae14612673 | |||
| 0ab8fe49bf | |||
| 3be6706018 | |||
| cb1b8b21b8 | |||
| 27916822b2 | |||
| 3fe3bc0642 | |||
| 813d42cc96 | |||
| b4d7e9c632 | |||
| 2e83cbbb6d | |||
| 33d10af28f |
@@ -28,7 +28,51 @@ env:
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
@@ -133,6 +177,7 @@ jobs:
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
@@ -201,7 +246,7 @@ jobs:
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
@@ -220,6 +265,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -496,6 +496,8 @@
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
|
||||
@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
|
||||
FasterCache is a method that speeds up inference in diffusion transformers by:
|
||||
- Reusing attention states between successive inference steps, due to high similarity between them
|
||||
- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FasterCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 681),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
attention_weight_callback=lambda _: 0.3,
|
||||
unconditional_batch_skip_range=5,
|
||||
unconditional_batch_timestep_skip_range=(-1, 781),
|
||||
tensor_format="BFCHW",
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
|
||||
## Quantization
|
||||
|
||||
|
||||
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXConditionPipeline
|
||||
|
||||
[[autodoc]] LTXConditionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaSprintPipeline
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipeline = SanaSprintPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("sana.png")
|
||||
```
|
||||
|
||||
## Setting `max_timesteps`
|
||||
|
||||
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
|
||||
|
||||
## SanaSprintPipeline
|
||||
|
||||
[[autodoc]] SanaSprintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
|
||||
@@ -198,6 +198,18 @@ export_to_video(video, "output.mp4", fps=8)
|
||||
|
||||
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
|
||||
|
||||
<Tip>
|
||||
|
||||
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
|
||||
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
|
||||
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
|
||||
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
|
||||
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
|
||||
|
||||
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
|
||||
|
||||
</Tip>
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
@@ -235,6 +247,14 @@ In the above example, layerwise casting is enabled on the transformer component
|
||||
|
||||
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
|
||||
|
||||
<Tip>
|
||||
|
||||
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
|
||||
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
|
||||
- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Channels-last memory format
|
||||
|
||||
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
|
||||
|
||||
@@ -66,12 +66,6 @@ from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
## 원을 채우는 데이터셋
|
||||
|
||||
원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
|
||||
|
||||
우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
|
||||
|
||||
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
|
||||
|
||||
## 학습
|
||||
|
||||
@@ -24,12 +24,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
@@ -41,7 +41,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
|
||||
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
@@ -58,7 +58,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
@@ -954,6 +954,7 @@ for i in range(args.num_images):
|
||||
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
|
||||
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
|
||||
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
|
||||
print("Image saved successfully!")
|
||||
```
|
||||
|
||||
### Imagic Stable Diffusion
|
||||
@@ -1269,28 +1270,39 @@ The aim is to overlay two images, then mask out the boundary between `image` and
|
||||
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
image_path = "./path-to-image.png"
|
||||
inner_image_path = "./path-to-inner-image.png"
|
||||
mask_path = "./path-to-mask.png"
|
||||
image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
|
||||
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
|
||||
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
def load_image(url, mode="RGB"):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
|
||||
else:
|
||||
raise FileNotFoundError(f"Could not retrieve image from {url}")
|
||||
|
||||
|
||||
init_image = load_image(image_url, mode="RGB")
|
||||
inner_image = load_image(inner_image_url, mode="RGBA")
|
||||
mask_image = load_image(mask_url, mode="RGB")
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
custom_pipeline="img2img_inpainting",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Your prompt here!"
|
||||
prompt = "a mecha robot sitting on a bench"
|
||||
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
|
||||
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||

|
||||
@@ -3252,14 +3264,19 @@ Here's a full example for `ReplaceEdit``:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from diffusers import DiffusionPipeline
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="pipeline_prompt2prompt"
|
||||
).to("cuda")
|
||||
|
||||
prompts = ["A turtle playing with a ball",
|
||||
"A monkey playing with a ball"]
|
||||
prompts = [
|
||||
"A turtle playing with a ball",
|
||||
"A monkey playing with a ball"
|
||||
]
|
||||
|
||||
cross_attention_kwargs = {
|
||||
"edit_type": "replace",
|
||||
@@ -3267,7 +3284,15 @@ cross_attention_kwargs = {
|
||||
"self_replace_steps": 0.4
|
||||
}
|
||||
|
||||
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
|
||||
outputs = pipe(
|
||||
prompt=prompts,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=50,
|
||||
cross_attention_kwargs=cross_attention_kwargs
|
||||
)
|
||||
|
||||
outputs.images[0].save("output_image_0.png")
|
||||
```
|
||||
|
||||
And abbreviated examples for the other edits:
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
# AnyTextPipeline Pipeline
|
||||
# AnyTextPipeline
|
||||
|
||||
Project page: https://aigcdesigngroup.github.io/homepage_anytext
|
||||
|
||||
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
|
||||
|
||||
Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
|
||||
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
|
||||
|
||||
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
|
||||
|
||||
[](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
|
||||
|
||||
```py
|
||||
# This example requires the `anytext_controlnet.py` file:
|
||||
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
|
||||
# %cd diffusers/examples/research_projects/anytext
|
||||
# Let's choose a font file shared by an HF staff:
|
||||
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from anytext_controlnet import AnyTextControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# I chose a font file shared by an HF staff:
|
||||
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
||||
variant="fp16",)
|
||||
@@ -26,6 +33,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial
|
||||
# generate image
|
||||
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
||||
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
||||
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
|
||||
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
||||
).images[0]
|
||||
image
|
||||
|
||||
@@ -146,14 +146,17 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # This example requires the `anytext_controlnet.py` file:
|
||||
>>> # !git clone --depth 1 https://github.com/huggingface/diffusers.git
|
||||
>>> # %cd diffusers/examples/research_projects/anytext
|
||||
>>> # Let's choose a font file shared by an HF staff:
|
||||
>>> # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> from anytext_controlnet import AnyTextControlNetModel
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> # I chose a font file shared by an HF staff:
|
||||
>>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
>>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
||||
... variant="fp16",)
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
|
||||
@@ -165,6 +168,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> # generate image
|
||||
>>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
||||
>>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
||||
>>> # There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
|
||||
>>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
||||
... ).images[0]
|
||||
>>> image
|
||||
@@ -257,11 +261,11 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
|
||||
idx = tokenized_text[i] == self.placeholder_token.to(device)
|
||||
if sum(idx) > 0:
|
||||
if i >= len(self.text_embs_all):
|
||||
print("truncation for log images...")
|
||||
logger.warning("truncation for log images...")
|
||||
break
|
||||
text_emb = torch.cat(self.text_embs_all[i], dim=0)
|
||||
if sum(idx) != len(text_emb):
|
||||
print("truncation for long caption...")
|
||||
logger.warning("truncation for long caption...")
|
||||
text_emb = text_emb.to(embedded_text.device)
|
||||
embedded_text[i][idx] = text_emb[: sum(idx)]
|
||||
return embedded_text
|
||||
@@ -1058,6 +1062,8 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
|
||||
raise ValueError(f"Can't read ori_image image from {ori_image}!")
|
||||
elif isinstance(ori_image, torch.Tensor):
|
||||
ori_image = ori_image.cpu().numpy()
|
||||
elif isinstance(ori_image, PIL.Image.Image):
|
||||
ori_image = np.array(ori_image.convert("RGB"))
|
||||
else:
|
||||
if not isinstance(ori_image, np.ndarray):
|
||||
raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
|
||||
|
||||
@@ -627,6 +627,7 @@ def main(args):
|
||||
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
|
||||
perceptual_loss = lpips.LPIPS(net="vgg").eval()
|
||||
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
|
||||
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
|
||||
|
||||
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
|
||||
def unwrap_model(model):
|
||||
@@ -951,13 +952,20 @@ def main(args):
|
||||
logits_fake = discriminator(reconstructions)
|
||||
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
|
||||
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
|
||||
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
logs = {
|
||||
"disc_loss": disc_loss.detach().mean().item(),
|
||||
"disc_loss": d_loss.detach().mean().item(),
|
||||
"logits_real": logits_real.detach().mean().item(),
|
||||
"logits_fake": logits_fake.detach().mean().item(),
|
||||
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
accelerator.backward(d_loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = discriminator.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
disc_optimizer.step()
|
||||
disc_lr_scheduler.step()
|
||||
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Generating images using Flux and PyTorch/XLA
|
||||
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
|
||||
|
||||
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
|
||||
## Create TPU
|
||||
|
||||
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
python3 -c "import torch; import torch_xla;"
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
Clone the diffusers repo and install dependencies
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install transformers accelerate sentencepiece structlog
|
||||
pushd ../../..
|
||||
pip install .
|
||||
popd
|
||||
cd examples/research_projects/pytorch_xla/inference/flux/
|
||||
```
|
||||
|
||||
## Run the inference job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token in order to download Flux weights.
|
||||
**Gated Model**
|
||||
|
||||
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
|
||||
@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": None,
|
||||
},
|
||||
"HYVideo-T/2-I2V": {
|
||||
"HYVideo-T/2-I2V-33ch": {
|
||||
"in_channels": 16 * 2 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "latent_concat",
|
||||
},
|
||||
"HYVideo-T/2-I2V-16ch": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "token_replace",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
elif version == "0.9.5":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
return config
|
||||
|
||||
|
||||
@@ -223,7 +294,7 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -277,14 +348,17 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
if args.version == "0.9.5":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -16,7 +16,9 @@ from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
@@ -25,6 +27,10 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
|
||||
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
|
||||
"Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
@@ -72,15 +78,42 @@ def main(args):
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
# Handle different time embedding structure based on model type
|
||||
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# For Sana Sprint, the time embedding structure is different
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Guidance embedder for Sana Sprint
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"cfg_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
|
||||
else:
|
||||
# Original Sana time embedding structure
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
@@ -96,14 +129,22 @@ def main(args):
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
|
||||
layer_num = 28
|
||||
elif args.model_type == "SanaMS_4800M_P1_D60":
|
||||
layer_num = 60
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
|
||||
qk_norm = (
|
||||
"rms_norm_across_heads"
|
||||
if args.model_type
|
||||
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
|
||||
else None
|
||||
)
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -117,6 +158,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
@@ -154,6 +203,14 @@ def main(args):
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
@@ -169,24 +226,37 @@ def main(args):
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer = SanaTransformer2DModel(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
|
||||
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
|
||||
num_layers=model_kwargs[args.model_type]["num_layers"],
|
||||
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
caption_channels=2304,
|
||||
mlp_ratio=2.5,
|
||||
attention_bias=False,
|
||||
sample_size=args.image_size // 32,
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
transformer_kwargs = {
|
||||
"in_channels": 32,
|
||||
"out_channels": 32,
|
||||
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
|
||||
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
|
||||
"num_layers": model_kwargs[args.model_type]["num_layers"],
|
||||
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 2.5,
|
||||
"attention_bias": False,
|
||||
"sample_size": args.image_size // 32,
|
||||
"patch_size": 1,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"interpolation_scale": interpolation_scale[args.image_size],
|
||||
}
|
||||
|
||||
# Add qk_norm parameter for Sana Sprint
|
||||
if args.model_type in [
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
]:
|
||||
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
transformer_kwargs["guidance_embeds"] = True
|
||||
|
||||
transformer = SanaTransformer2DModel(**transformer_kwargs)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(transformer, converted_state_dict)
|
||||
@@ -196,6 +266,8 @@ def main(args):
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
@@ -210,47 +282,74 @@ def main(args):
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole SanaPipeline",
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "google/gemma-2-2b-it"
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
|
||||
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
|
||||
if args.scheduler_type != "scm":
|
||||
print(
|
||||
colored(
|
||||
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
# SCM Scheduler for Sana Sprint
|
||||
scheduler_config = {
|
||||
"prediction_type": "trigflow",
|
||||
"sigma_data": 0.5,
|
||||
}
|
||||
scheduler = SCMScheduler(**scheduler_config)
|
||||
pipe = SanaSprintPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
@@ -259,12 +358,6 @@ DTYPE_MAPPING = {
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -281,10 +374,24 @@ if __name__ == "__main__":
|
||||
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
"--model_type",
|
||||
default="SanaMS_1600M_P1_D20",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaMS_1600M_P1_D20",
|
||||
"SanaMS_600M_P1_D28",
|
||||
"SanaMS1.5_1600M_P1_D20",
|
||||
"SanaMS1.5_4800M_P1_D60",
|
||||
"SanaSprint_1600M_P1_D20",
|
||||
"SanaSprint_600M_P1_D28",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "scm"],
|
||||
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
|
||||
@@ -309,10 +416,41 @@ if __name__ == "__main__":
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaMS1.5_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
"SanaMS1.5_4800M_P1_D60": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 60,
|
||||
},
|
||||
"SanaSprint_600M_P1_D28": {
|
||||
"num_attention_heads": 36,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 16,
|
||||
"cross_attention_head_dim": 72,
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
"SanaSprint_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -131,8 +131,10 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -271,6 +273,7 @@ else:
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"SCMScheduler",
|
||||
"ScoreSdeVeScheduler",
|
||||
"TCDScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -402,6 +405,7 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
@@ -422,6 +426,7 @@ else:
|
||||
"ReduxImageEncoder",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -702,7 +707,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
HookRegistry,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -835,6 +846,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
SCMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
TCDScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -947,6 +959,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
@@ -967,6 +980,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
|
||||
@@ -2,6 +2,7 @@ from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..models.modeling_outputs import Transformer2DModelOutput
|
||||
from ..utils import logging
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
||||
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
||||
"^blocks.*attn",
|
||||
"^transformer_blocks.*attn",
|
||||
"^single_transformer_blocks.*attn",
|
||||
)
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
|
||||
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
|
||||
"hidden_states",
|
||||
"encoder_hidden_states",
|
||||
"timestep",
|
||||
"attention_mask",
|
||||
"encoder_attention_mask",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FasterCacheConfig:
|
||||
r"""
|
||||
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
|
||||
|
||||
Attributes:
|
||||
spatial_attention_block_skip_range (`int`, defaults to `2`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
|
||||
states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
||||
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
||||
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
||||
process.
|
||||
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
||||
The timestep range within which the temporal attention computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0).
|
||||
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
||||
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
alpha_low_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
|
||||
the conditional branch outputs.
|
||||
alpha_high_frequency (`float`, defaults to `1.1`):
|
||||
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
|
||||
from the conditional branch outputs.
|
||||
unconditional_batch_skip_range (`int`, defaults to `5`):
|
||||
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
||||
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
|
||||
computing the new unconditional branch states again.
|
||||
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
The timestep range within which the unconditional branch computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the attention outputs by. This function should take
|
||||
the attention module as input and return a float value. This is used to approximate the unconditional
|
||||
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
|
||||
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
|
||||
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
|
||||
the number of inference steps and underlying model behaviour as denoising progresses.
|
||||
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
||||
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
|
||||
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
||||
tensor_format (`str`, defaults to `"BCFHW"`):
|
||||
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
|
||||
used to split individual latent frames in order for low and high frequency components to be computed.
|
||||
is_guidance_distilled (`bool`, defaults to `False`):
|
||||
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
||||
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
||||
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
||||
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
||||
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
||||
names that contain the batchwise-concatenated unconditional and conditional inputs.
|
||||
"""
|
||||
|
||||
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
|
||||
# after some testing. We default to 2 if these parameters are not provided.
|
||||
spatial_attention_block_skip_range: int = 2
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
|
||||
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
||||
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
||||
|
||||
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
||||
alpha_low_frequency: float = 1.1
|
||||
alpha_high_frequency: float = 1.1
|
||||
|
||||
# n as described in CFG-Cache explanation in the paper - dependant on the model
|
||||
unconditional_batch_skip_range: int = 5
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
|
||||
tensor_format: str = "BCFHW"
|
||||
is_guidance_distilled: bool = False
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"FasterCacheConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
||||
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
||||
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
|
||||
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
|
||||
f" alpha_low_frequency={self.alpha_low_frequency},\n"
|
||||
f" alpha_high_frequency={self.alpha_high_frequency},\n"
|
||||
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
|
||||
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
|
||||
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
||||
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
||||
f" tensor_format={self.tensor_format},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class FasterCacheDenoiserState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.low_frequency_delta: torch.Tensor = None
|
||||
self.high_frequency_delta: torch.Tensor = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.low_frequency_delta = None
|
||||
self.high_frequency_delta = None
|
||||
|
||||
|
||||
class FasterCacheBlockState:
|
||||
r"""
|
||||
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
|
||||
applied to will have an instance of this state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.batch_size: int = None
|
||||
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.batch_size = None
|
||||
self.cache = None
|
||||
|
||||
|
||||
class FasterCacheDenoiserHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unconditional_batch_skip_range: int,
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
||||
tensor_format: str,
|
||||
is_guidance_distilled: bool,
|
||||
uncond_cond_input_kwargs_identifiers: List[str],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.unconditional_batch_skip_range = unconditional_batch_skip_range
|
||||
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
|
||||
# We can't easily detect what args are to be split in unconditional and conditional branches. We
|
||||
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
|
||||
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
|
||||
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
|
||||
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
|
||||
self.tensor_format = tensor_format
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
self.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
self.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheDenoiserState()
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
||||
# followed by conditional inputs.
|
||||
_, cond = input.chunk(2, dim=0)
|
||||
return cond
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
|
||||
# requirements for skipping the unconditional branch are met as described in the paper.
|
||||
# We skip the unconditional branch only if the following conditions are met:
|
||||
# 1. We have completed at least one iteration of the denoiser
|
||||
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
|
||||
# where approximating the unconditional branch from the computation of the conditional branch is possible
|
||||
# without a significant loss in quality.
|
||||
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
|
||||
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
|
||||
is_within_timestep_range = (
|
||||
self.unconditional_batch_timestep_skip_range[0]
|
||||
< self.current_timestep_callback()
|
||||
< self.unconditional_batch_timestep_skip_range[1]
|
||||
)
|
||||
should_skip_uncond = (
|
||||
self.state.iteration > 0
|
||||
and is_within_timestep_range
|
||||
and self.state.iteration % self.unconditional_batch_skip_range != 0
|
||||
and not self.is_guidance_distilled
|
||||
)
|
||||
|
||||
if should_skip_uncond:
|
||||
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
|
||||
if is_any_kwarg_uncond:
|
||||
logger.debug("FasterCache - Skipping unconditional branch computation")
|
||||
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
|
||||
kwargs = {
|
||||
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_guidance_distilled:
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
if torch.is_tensor(output):
|
||||
hidden_states = output
|
||||
elif isinstance(output, (tuple, Transformer2DModelOutput)):
|
||||
hidden_states = output[0]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if should_skip_uncond:
|
||||
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
|
||||
module
|
||||
)
|
||||
|
||||
if self.tensor_format == "BCFHW":
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
|
||||
|
||||
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
|
||||
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
|
||||
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
|
||||
uncond_freq = low_freq_uncond + high_freq_uncond
|
||||
|
||||
uncond_states = torch.fft.ifftshift(uncond_freq)
|
||||
uncond_states = torch.fft.ifft2(uncond_states).real
|
||||
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# Concatenate the approximated unconditional and predicted conditional branches
|
||||
uncond_states = uncond_states.to(hidden_states.dtype)
|
||||
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
|
||||
else:
|
||||
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
|
||||
if self.tensor_format == "BCFHW":
|
||||
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
||||
cond_states = cond_states.permute(0, 2, 1, 3, 4)
|
||||
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
||||
uncond_states = uncond_states.flatten(0, 1)
|
||||
cond_states = cond_states.flatten(0, 1)
|
||||
|
||||
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
|
||||
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
|
||||
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
|
||||
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
|
||||
|
||||
self.state.iteration += 1
|
||||
if torch.is_tensor(output):
|
||||
output = hidden_states
|
||||
elif isinstance(output, tuple):
|
||||
output = (hidden_states, *output[1:])
|
||||
else:
|
||||
output.sample = hidden_states
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
class FasterCacheBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
is_guidance_distilled: bool,
|
||||
weight_callback: Callable[[torch.nn.Module], float],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.block_skip_range = block_skip_range
|
||||
self.timestep_skip_range = timestep_skip_range
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
self.weight_callback = weight_callback
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = FasterCacheBlockState()
|
||||
return module
|
||||
|
||||
def _compute_approximated_attention_output(
|
||||
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
|
||||
) -> torch.Tensor:
|
||||
if t_2_output.size(0) != batch_size:
|
||||
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_2_output.size(0) == 2 * batch_size
|
||||
t_2_output = t_2_output[batch_size:]
|
||||
if t_output.size(0) != batch_size:
|
||||
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
||||
# take the conditional branch outputs.
|
||||
assert t_output.size(0) == 2 * batch_size
|
||||
t_output = t_output[batch_size:]
|
||||
return t_output + (t_output - t_2_output) * weight
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
batch_size = [
|
||||
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
|
||||
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
|
||||
][0]
|
||||
if self.state.batch_size is None:
|
||||
# Will be updated on first forward pass through the denoiser
|
||||
self.state.batch_size = batch_size
|
||||
|
||||
# If we have to skip due to the skip conditions, then let's skip as expected.
|
||||
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
|
||||
# is because the expected output shapes of attention layer will not match if we only return values from
|
||||
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
|
||||
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
|
||||
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
|
||||
is_within_timestep_range = (
|
||||
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
||||
)
|
||||
if not is_within_timestep_range:
|
||||
should_skip_attention = False
|
||||
else:
|
||||
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
|
||||
should_skip_attention = not should_compute_attention
|
||||
if should_skip_attention:
|
||||
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
|
||||
|
||||
if should_skip_attention:
|
||||
logger.debug("FasterCache - Skipping attention and using approximation")
|
||||
if torch.is_tensor(self.state.cache[-1]):
|
||||
t_2_output, t_output = self.state.cache
|
||||
weight = self.weight_callback(module)
|
||||
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
|
||||
else:
|
||||
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
|
||||
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
|
||||
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
|
||||
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
|
||||
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
|
||||
# allows us to compute the approximated attention output for each tensor in the cache.
|
||||
output = ()
|
||||
for t_2_output, t_output in zip(*self.state.cache):
|
||||
result = self._compute_approximated_attention_output(
|
||||
t_2_output, t_output, self.weight_callback(module), batch_size
|
||||
)
|
||||
output += (result,)
|
||||
else:
|
||||
logger.debug("FasterCache - Computing attention")
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
|
||||
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
|
||||
# both cases.
|
||||
if torch.is_tensor(output):
|
||||
cache_output = output
|
||||
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
|
||||
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
|
||||
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
|
||||
cache_output = cache_output.chunk(2, dim=0)[1]
|
||||
else:
|
||||
# Cache all return values and perform the same operation as above
|
||||
cache_output = ()
|
||||
for out in output:
|
||||
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
|
||||
out = out.chunk(2, dim=0)[1]
|
||||
cache_output += (out,)
|
||||
|
||||
if self.state.cache is None:
|
||||
self.state.cache = [cache_output, cache_output]
|
||||
else:
|
||||
self.state.cache = [self.state.cache[-1], cache_output]
|
||||
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
r"""
|
||||
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
||||
|
||||
Args:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
The diffusion pipeline to apply FasterCache to.
|
||||
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
|
||||
The configuration to use for FasterCache.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = FasterCacheConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(-1, 681),
|
||||
... low_frequency_weight_update_timestep_range=(99, 641),
|
||||
... high_frequency_weight_update_timestep_range=(-1, 301),
|
||||
... spatial_attention_block_identifiers=["transformer_blocks"],
|
||||
... attention_weight_callback=lambda _: 0.3,
|
||||
... tensor_format="BFCHW",
|
||||
... )
|
||||
>>> apply_faster_cache(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
|
||||
logger.warning(
|
||||
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
|
||||
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
|
||||
"https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
if config.attention_weight_callback is None:
|
||||
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
|
||||
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
|
||||
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
|
||||
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
|
||||
logger.warning(
|
||||
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
|
||||
)
|
||||
config.attention_weight_callback = lambda _: 0.5
|
||||
|
||||
if config.low_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.low_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.low_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_low_frequency if is_within_range else 1.0
|
||||
|
||||
config.low_frequency_weight_callback = low_frequency_weight_callback
|
||||
|
||||
if config.high_frequency_weight_callback is None:
|
||||
logger.debug(
|
||||
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
||||
)
|
||||
|
||||
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
|
||||
is_within_range = (
|
||||
config.high_frequency_weight_update_timestep_range[0]
|
||||
< config.current_timestep_callback()
|
||||
< config.high_frequency_weight_update_timestep_range[1]
|
||||
)
|
||||
return config.alpha_high_frequency if is_within_range else 1.0
|
||||
|
||||
config.high_frequency_weight_callback = high_frequency_weight_callback
|
||||
|
||||
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
|
||||
if config.tensor_format not in supported_tensor_formats:
|
||||
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
|
||||
|
||||
_apply_faster_cache_on_denoiser(module, config)
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES):
|
||||
continue
|
||||
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
|
||||
_apply_faster_cache_on_attention_class(name, submodule, config)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
||||
hook = FasterCacheDenoiserHook(
|
||||
config.unconditional_batch_skip_range,
|
||||
config.unconditional_batch_timestep_skip_range,
|
||||
config.tensor_format,
|
||||
config.is_guidance_distilled,
|
||||
config._unconditional_conditional_input_kwargs_identifiers,
|
||||
config.current_timestep_callback,
|
||||
config.low_frequency_weight_callback,
|
||||
config.high_frequency_weight_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
||||
|
||||
|
||||
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
|
||||
is_spatial_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range, block_type = None, None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
block_type = "spatial"
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
block_type = "temporal"
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.debug(
|
||||
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
|
||||
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
|
||||
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
||||
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
|
||||
f"function to apply FasterCache to this layer."
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
|
||||
hook = FasterCacheBlockHook(
|
||||
block_skip_range,
|
||||
timestep_skip_range,
|
||||
config.is_guidance_distilled,
|
||||
config.attention_weight_callback,
|
||||
config.current_timestep_callback,
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
|
||||
|
||||
|
||||
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
|
||||
@torch.no_grad()
|
||||
def _split_low_high_freq(x):
|
||||
fft = torch.fft.fft2(x)
|
||||
fft_shifted = torch.fft.fftshift(fft)
|
||||
height, width = x.shape[-2:]
|
||||
radius = min(height, width) // 5
|
||||
|
||||
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
|
||||
center_x, center_y = width // 2, height // 2
|
||||
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
|
||||
|
||||
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
|
||||
high_freq_mask = ~low_freq_mask
|
||||
|
||||
low_freq_fft = fft_shifted * low_freq_mask
|
||||
high_freq_fft = fft_shifted * high_freq_mask
|
||||
|
||||
return low_freq_fft, high_freq_fft
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
@@ -56,7 +56,7 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
low_cpu_mem_usage=False,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -64,15 +64,50 @@ class ModuleGroup:
|
||||
self.onload_device = onload_device
|
||||
self.offload_leader = offload_leader
|
||||
self.onload_leader = onload_leader
|
||||
self.parameters = parameters
|
||||
self.buffers = buffers
|
||||
self.parameters = parameters or []
|
||||
self.buffers = buffers or []
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
return cpu_param_dict
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@contextmanager
|
||||
def _pinned_memory_tensors(self):
|
||||
pinned_dict = {}
|
||||
try:
|
||||
for param, tensor in self.cpu_param_dict.items():
|
||||
if not tensor.is_pinned():
|
||||
pinned_dict[param] = tensor.pin_memory()
|
||||
else:
|
||||
pinned_dict[param] = tensor
|
||||
|
||||
yield pinned_dict
|
||||
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
@@ -82,12 +117,30 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
@@ -98,15 +151,18 @@ class ModuleGroup:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
@@ -172,6 +228,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
||||
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
||||
# layers are executed during the forward pass.
|
||||
@@ -183,14 +246,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
||||
|
||||
if group_offloading_hook is not None:
|
||||
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# For the first forward pass, we have to load in a blocking manner
|
||||
group_offloading_hook.group.non_blocking = False
|
||||
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
||||
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
||||
self._layer_execution_tracker_module_names.add(name)
|
||||
@@ -220,6 +277,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the layer execution tracker hooks from the submodules
|
||||
base_module_registry = module._diffusers_hook
|
||||
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
|
||||
for i in range(num_executed):
|
||||
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
||||
@@ -227,8 +285,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
||||
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
||||
|
||||
# Apply lazy prefetching by setting required attributes
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
||||
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
||||
# see the benefits of prefetching.
|
||||
for hook in group_offloading_hooks:
|
||||
hook.group.non_blocking = True
|
||||
|
||||
# Set required attributes for prefetching
|
||||
if num_executed > 0:
|
||||
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
||||
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
||||
@@ -268,6 +331,7 @@ def apply_group_offloading(
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
@@ -314,6 +378,10 @@ def apply_group_offloading(
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -349,10 +417,12 @@ def apply_group_offloading(
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
_apply_group_offloading_leaf_level(
|
||||
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
|
||||
@@ -364,6 +434,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
@@ -384,13 +455,6 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
unmatched_modules = []
|
||||
@@ -411,7 +475,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +512,6 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -461,6 +524,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
@@ -483,13 +547,6 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -503,7 +560,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +605,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +624,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PyramidAttentionBroadcastConfig("
|
||||
f"PyramidAttentionBroadcastConfig(\n"
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
|
||||
@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
|
||||
return module
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
module: torch.nn.Module,
|
||||
config: PyramidAttentionBroadcastConfig,
|
||||
):
|
||||
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
||||
|
||||
@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
||||
registry.register_hook(hook, "pyramid_attention_broadcast")
|
||||
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
|
||||
|
||||
@@ -4249,7 +4249,33 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@classmethod
|
||||
def _maybe_expand_t2v_lora_for_i2v(
|
||||
cls,
|
||||
transformer: torch.nn.Module,
|
||||
state_dict,
|
||||
):
|
||||
if transformer.config.image_dim is None:
|
||||
return state_dict
|
||||
|
||||
if any(k.startswith("transformer.blocks.") for k in state_dict):
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
|
||||
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
|
||||
|
||||
if is_i2v_lora:
|
||||
return state_dict
|
||||
|
||||
for i in range(num_blocks):
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
|
||||
)
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
@@ -4287,7 +4313,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
|
||||
state_dict = self._maybe_expand_t2v_lora_for_i2v(
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
|
||||
@@ -255,12 +255,12 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
@@ -282,6 +282,7 @@ class FromOriginalModelMixin:
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
@@ -449,9 +449,9 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
# 7.5 Offload the model again
|
||||
if is_model_cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
self.enable_model_cpu_offload(device=device)
|
||||
elif is_sequential_cpu_offload:
|
||||
self.enable_sequential_cpu_offload()
|
||||
self.enable_sequential_cpu_offload(device=device)
|
||||
|
||||
# / Unsafe Code >
|
||||
|
||||
|
||||
@@ -6020,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
|
||||
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
|
||||
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
|
||||
|
||||
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDownsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
|
||||
|
||||
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
|
||||
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
is_causal=is_causal,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
|
||||
|
||||
residual = (
|
||||
hidden_states.unflatten(4, (-1, self.stride[2]))
|
||||
.unflatten(3, (-1, self.stride[1]))
|
||||
.unflatten(2, (-1, self.stride[0]))
|
||||
)
|
||||
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
|
||||
residual = residual.unflatten(1, (-1, self.group_size))
|
||||
residual = residual.mean(dim=2)
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states.unflatten(4, (-1, self.stride[2]))
|
||||
.unflatten(3, (-1, self.stride[1]))
|
||||
.unflatten(2, (-1, self.stride[0]))
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
is_causal=is_causal,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideo095DownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
spatio_temporal_scale (`bool`, defaults to `True`):
|
||||
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
||||
Whether or not to downsample across temporal dimension.
|
||||
is_causal (`bool`, defaults to `True`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
downsample_type: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList()
|
||||
|
||||
if downsample_type == "conv":
|
||||
self.downsamplers.append(
|
||||
LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
)
|
||||
)
|
||||
elif downsample_type == "spatial":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
elif downsample_type == "temporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
elif downsample_type == "spatiotemporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
|
||||
)
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
r"""
|
||||
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
)
|
||||
|
||||
# down blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
|
||||
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
if not is_ltx_095:
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
else:
|
||||
output_channel = block_out_channels[i + 1]
|
||||
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if down_block_types[i] == "LTXVideoDownBlock3D":
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
elif down_block_types[i] == "LTXVideo095DownBlock3D":
|
||||
down_block = LTXVideo095DownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
downsample_type=downsample_type[i],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
self.timestep_scale_multiplier = None
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.timestep_scale_multiplier is not None:
|
||||
temb = temb * self.timestep_scale_multiplier
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
|
||||
|
||||
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
scaling_factor: float = 1.0,
|
||||
encoder_causal: bool = True,
|
||||
decoder_causal: bool = False,
|
||||
spatial_compression_ratio: int = None,
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
down_block_types=down_block_types,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_type=downsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
self.spatial_compression_ratio = (
|
||||
patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
if spatial_compression_ratio is None
|
||||
else spatial_compression_ratio
|
||||
)
|
||||
self.temporal_compression_ratio = (
|
||||
patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
if temporal_compression_ratio is None
|
||||
else temporal_compression_ratio
|
||||
)
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
|
||||
@@ -24,6 +24,7 @@ class CacheMixin:
|
||||
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
@@ -59,17 +60,31 @@ class CacheMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
|
||||
if self.is_cache_enabled:
|
||||
raise ValueError(
|
||||
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
||||
)
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, FasterCacheConfig):
|
||||
apply_faster_cache(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
@@ -77,7 +92,11 @@ class CacheMixin:
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
||||
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
|
||||
@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
|
||||
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
@@ -37,7 +37,6 @@ from torch import Tensor, nn
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..hooks import apply_group_offloading, apply_layerwise_casting
|
||||
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
@@ -504,6 +503,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
"""
|
||||
from ..hooks import apply_layerwise_casting
|
||||
|
||||
user_provided_patterns = True
|
||||
if skip_modules_pattern is None:
|
||||
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage=False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates group offloading for the current model.
|
||||
@@ -569,6 +570,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
... )
|
||||
```
|
||||
"""
|
||||
from ..hooks import apply_group_offloading
|
||||
|
||||
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
|
||||
msg = (
|
||||
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
|
||||
@@ -584,7 +587,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
apply_group_offloading(
|
||||
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
|
||||
self,
|
||||
onload_device,
|
||||
offload_device,
|
||||
offload_type,
|
||||
num_blocks_per_group,
|
||||
non_blocking,
|
||||
use_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
@@ -870,7 +880,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
@@ -883,7 +893,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
|
||||
@@ -550,16 +550,6 @@ class RMSNorm(nn.Module):
|
||||
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
elif is_torch_version(">=", "2.4"):
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = nn.functional.rms_norm(
|
||||
hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
|
||||
)
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
|
||||
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -23,10 +24,9 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnProcessor2_0,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
@@ -96,6 +96,95 @@ class SanaModulatedNorm(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.guidance_condition_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
|
||||
return self.linear(self.silu(conditioning)), conditioning
|
||||
|
||||
|
||||
class SanaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
|
||||
@@ -115,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 2.5,
|
||||
qk_norm: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -124,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
@@ -135,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=AttnProcessor2_0(),
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -232,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for the query and key.
|
||||
timestep_scale (`float`, defaults to `1.0`):
|
||||
The scale to use for the timesteps.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -258,6 +356,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = None,
|
||||
timestep_scale: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -276,7 +378,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
if guidance_embeds:
|
||||
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
|
||||
else:
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
@@ -296,6 +401,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -372,7 +478,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
timestep: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -423,9 +530,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
@@ -27,13 +27,15 @@ from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
PixArtAlphaTextProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -173,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
|
||||
6, dim=1
|
||||
)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return hidden_states, gate_msa, tr_gate_msa
|
||||
|
||||
|
||||
class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
pooled_projection_dim: int,
|
||||
guidance_embeds: bool,
|
||||
image_condition_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.image_condition_type = image_condition_type
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
self.guidance_embedder = None
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
token_replace_emb = None
|
||||
if self.image_condition_type == "token_replace":
|
||||
token_replace_timestep = torch.zeros_like(timestep)
|
||||
token_replace_proj = self.time_proj(token_replace_timestep)
|
||||
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
|
||||
token_replace_emb = token_replace_emb + pooled_projections
|
||||
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
|
||||
conditioning = conditioning + guidance_emb
|
||||
|
||||
return conditioning, token_replace_emb
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -390,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
@@ -468,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
@@ -503,6 +644,181 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
|
||||
proj_output = self.proj_out(hidden_states)
|
||||
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
|
||||
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
|
||||
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
@@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
image_condition_type (`str`, *optional*, defaults to `None`):
|
||||
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
|
||||
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
|
||||
tokens in the latent stream and apply conditioning.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -570,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_image_condition_types = ["latent_concat", "token_replace"]
|
||||
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
|
||||
raise ValueError(
|
||||
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
@@ -582,33 +909,52 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
else:
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
||||
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
||||
)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -707,15 +1053,13 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
if self.config.guidance_embeds:
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
else:
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
@@ -746,6 +1090,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
@@ -756,17 +1102,31 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
self.patch_size_t = patch_size_t
|
||||
self.theta = theta
|
||||
|
||||
def forward(
|
||||
def _prepare_video_coords(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# Always compute rope in fp32
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
grid = grid.flatten(2, 4).transpose(1, 2)
|
||||
|
||||
return grid
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if video_coords is None:
|
||||
grid = self._prepare_video_coords(
|
||||
batch_size,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
grid = torch.stack(
|
||||
[
|
||||
video_coords[:, 0] / self.base_num_frames,
|
||||
video_coords[:, 1] / self.base_height,
|
||||
video_coords[:, 2] / self.base_width,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
freqs = self.theta ** torch.linspace(
|
||||
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
|
||||
@@ -264,7 +264,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
@@ -280,7 +280,7 @@ else:
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["pia"] = ["PIAPipeline"]
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
|
||||
_import_structure["sana"] = ["SanaPipeline"]
|
||||
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_audio"] = [
|
||||
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXImageToVideoPipeline, LTXPipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
@@ -651,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .sana import SanaPipeline
|
||||
from .sana import SanaPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
|
||||
|
||||
@@ -68,7 +68,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -100,10 +100,19 @@ def retrieve_timesteps(
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps and not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is not None and sigmas is None:
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
@@ -112,9 +121,8 @@ def retrieve_timesteps(
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
elif timesteps is None and sigmas is not None:
|
||||
if not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
@@ -515,8 +523,8 @@ class CogView4ControlPipeline(DiffusionPipeline):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -532,7 +540,6 @@ class CogView4ControlPipeline(DiffusionPipeline):
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `224`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -184,7 +184,14 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"control_image",
|
||||
"mask",
|
||||
"masked_image_latents",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -533,7 +533,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
|
||||
@@ -63,6 +63,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import FluxControlNetPipeline
|
||||
>>> from diffusers import FluxControlNetModel
|
||||
|
||||
>>> base_model = "black-forest-labs/FLUX.1-dev"
|
||||
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
|
||||
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
||||
|
||||
@@ -533,7 +533,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
|
||||
@@ -561,7 +561,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
@@ -614,7 +613,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, noise, image_latents, latent_image_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
|
||||
@@ -225,7 +225,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
@@ -634,7 +637,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
|
||||
@@ -222,11 +222,13 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
|
||||
)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
vae_latent_channels=latent_channels,
|
||||
vae_latent_channels=self.latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
@@ -653,7 +655,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
@@ -710,7 +715,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
masked_image_latents = (
|
||||
masked_image_latents - self.vae.config.shift_factor
|
||||
) * self.vae.config.scaling_factor
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
|
||||
@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
>>> from diffusers.utils import load_image, export_to_video
|
||||
|
||||
>>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
|
||||
... )
|
||||
|
||||
>>> output = pipe(image=image, prompt=prompt).frames[0]
|
||||
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V
|
||||
>>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
|
||||
|
||||
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
|
||||
>>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
|
||||
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
image_embed_interleave: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
||||
image,
|
||||
@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=max_sequence_length,
|
||||
image_embed_interleave=image_embed_interleave,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_template=None,
|
||||
true_cfg_scale=1.0,
|
||||
guidance_scale=1.0,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
||||
)
|
||||
|
||||
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
|
||||
logger.warning(
|
||||
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
|
||||
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
|
||||
"as it may lead to higher memory usage, slower inference and potentially worse results."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
image_condition_type: str = "latent_concat",
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
image = image.unsqueeze(2) # [B, C, 1, H, W]
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
|
||||
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
|
||||
|
||||
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
|
||||
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
|
||||
@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
t = torch.tensor([0.999]).to(device=device)
|
||||
latents = latents * t + image_latents * (1 - t)
|
||||
|
||||
if image_condition_type == "token_replace":
|
||||
image_latents = image_latents[:, :, :1]
|
||||
|
||||
return latents, image_latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
max_sequence_length: int = 256,
|
||||
image_embed_interleave: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_template,
|
||||
true_cfg_scale,
|
||||
guidance_scale,
|
||||
)
|
||||
|
||||
image_condition_type = self.transformer.config.image_condition_type
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
image_embed_interleave = (
|
||||
image_embed_interleave
|
||||
if image_embed_interleave is not None
|
||||
else (
|
||||
2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
|
||||
)
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
# 3. Prepare latent variables
|
||||
vae_dtype = self.vae.dtype
|
||||
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
|
||||
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
|
||||
|
||||
if image_condition_type == "latent_concat":
|
||||
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
|
||||
elif image_condition_type == "token_replace":
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
|
||||
latents, image_latents = self.prepare_latents(
|
||||
image_tensor,
|
||||
batch_size * num_videos_per_prompt,
|
||||
@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image_condition_type,
|
||||
)
|
||||
image_latents[:, :, 1:] = 0
|
||||
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
|
||||
mask[:, :, 1:] = 0
|
||||
if image_condition_type == "latent_concat":
|
||||
image_latents[:, :, 1:] = 0
|
||||
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
|
||||
mask[:, :, 1:] = 0
|
||||
|
||||
# 4. Encode input prompt
|
||||
transformer_dtype = self.transformer.dtype
|
||||
@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
image_embed_interleave=image_embed_interleave,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
||||
@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
# 6. Prepare guidance condition
|
||||
guidance = None
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = (
|
||||
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if image_condition_type == "latent_concat":
|
||||
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
|
||||
elif image_condition_type == "token_replace":
|
||||
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_attention_mask=negative_prompt_attention_mask,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
if image_condition_type == "latent_concat":
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
elif image_condition_type == "token_replace":
|
||||
latents = latents = self.scheduler.step(
|
||||
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
|
||||
)[0]
|
||||
latents = torch.cat([image_latents, latents], dim=2)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = video[:, :, 4:, :, :]
|
||||
if image_condition_type == "latent_concat":
|
||||
video = video[:, :, 4:, :, :]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents[:, :, 1:, :, :]
|
||||
if image_condition_type == "latent_concat":
|
||||
video = latents[:, :, 1:, :, :]
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
def _config_to_kwargs(args):
|
||||
common_kwargs = {
|
||||
"dtype": args.torch_dtype,
|
||||
}
|
||||
return common_kwargs
|
||||
|
||||
|
||||
class CoreAttention(torch.nn.Module):
|
||||
def __init__(self, config: ChatGLMConfig, layer_number):
|
||||
super(CoreAttention, self).__init__()
|
||||
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
|
||||
self.qkv_hidden_size,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
self.core_attention = CoreAttention(config, self.layer_number)
|
||||
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
||||
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
|
||||
config.ffn_hidden_size * 2,
|
||||
bias=self.add_bias,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
def swiglu(x):
|
||||
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
|
||||
self.activation_func = swiglu
|
||||
|
||||
# Project back to h.
|
||||
self.dense_4h_to_h = nn.Linear(
|
||||
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
|
||||
)
|
||||
self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# [s, b, 4hp]
|
||||
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
|
||||
|
||||
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = SelfAttention(config, layer_number, device=device)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
|
||||
|
||||
# MLP
|
||||
self.mlp = MLP(config, device=device)
|
||||
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
|
||||
if self.post_layer_norm:
|
||||
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -679,9 +661,7 @@ class Embedding(torch.nn.Module):
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = nn.Embedding(
|
||||
config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
|
||||
)
|
||||
self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
def forward(self, input_ids):
|
||||
@@ -784,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
)
|
||||
|
||||
self.rotary_pos_emb = RotaryEmbedding(
|
||||
rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
|
||||
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
||||
self.output_layer = init_method(
|
||||
nn.Linear,
|
||||
config.hidden_size,
|
||||
config.padded_vocab_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
**init_kwargs,
|
||||
)
|
||||
self.pre_seq_len = config.pre_seq_len
|
||||
|
||||
@@ -817,7 +817,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=current_timestep,
|
||||
enable_temporal_attentions=enable_temporal_attentions,
|
||||
|
||||
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
|
||||
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
|
||||
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_ltx import LTXPipeline
|
||||
from .pipeline_ltx_condition import LTXConditionPipeline
|
||||
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
|
||||
|
||||
else:
|
||||
|
||||
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
@@ -427,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
)
|
||||
|
||||
if device_type == "cuda":
|
||||
if device_type in ["cuda", "xpu"]:
|
||||
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
@@ -440,7 +440,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device_type == "cuda":
|
||||
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
@@ -686,7 +686,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||
custom_revision = kwargs.pop("custom_revision", None)
|
||||
provider = kwargs.pop("provider", None)
|
||||
@@ -703,7 +703,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
@@ -1456,8 +1456,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
if load_components_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
|
||||
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
|
||||
@@ -941,8 +941,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
if num_inference_steps == 1:
|
||||
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_sana"] = ["SanaPipeline"]
|
||||
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_sana import SanaPipeline
|
||||
from .pipeline_sana_sprint import SanaSprintPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -248,6 +248,64 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_sequence_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -296,6 +354,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
|
||||
@@ -320,43 +385,18 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
select_index = [0] + list(range(-max_length + 1, 0))
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0][:, select_index]
|
||||
prompt_embeds = prompt_embeds[:, select_index]
|
||||
prompt_attention_mask = prompt_attention_mask[:, select_index]
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -366,25 +406,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=False,
|
||||
)
|
||||
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -908,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
timestep = timestep * self.transformer.config.timestep_scale
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
|
||||
@@ -0,0 +1,889 @@
|
||||
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, SanaTransformer2DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
USE_PEFT_BACKEND,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
|
||||
from .pipeline_output import SanaPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import SanaSprintPipeline
|
||||
|
||||
>>> pipe = SanaSprintPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
|
||||
>>> image[0].save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
|
||||
# fmt: on
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
|
||||
if hasattr(self, "vae") and self.vae is not None
|
||||
else 32
|
||||
)
|
||||
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_sequence_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = max_sequence_length
|
||||
select_index = [0] + list(range(-max_length + 1, 0))
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, select_index]
|
||||
prompt_attention_mask = prompt_attention_mask[:, select_index]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
max_timesteps,
|
||||
intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
||||
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
||||
|
||||
if timesteps is not None and max_timesteps is not None:
|
||||
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
||||
|
||||
if timesteps is None and max_timesteps is None:
|
||||
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
||||
|
||||
if intermediate_timesteps is not None and num_inference_steps != 2:
|
||||
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 2,
|
||||
timesteps: List[int] = None,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: float = 1.3,
|
||||
guidance_scale: float = 4.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clean_caption: bool = False,
|
||||
use_resolution_binning: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: List[str] = [
|
||||
"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
|
||||
"Here are examples of how to transform or refine prompts:",
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
|
||||
"User Prompt: ",
|
||||
],
|
||||
) -> Union[SanaPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
max_timesteps (`float`, *optional*, defaults to 1.57080):
|
||||
The maximum timestep value used in the SCM scheduler.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs:
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
use_resolution_binning (`bool` defaults to `True`):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
||||
the requested resolution. Useful for generating non-square images.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `300`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
complex_human_instruction (`List[str]`, *optional*):
|
||||
Instructions for complex human attention:
|
||||
https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
if use_resolution_binning:
|
||||
if self.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
max_timesteps=max_timesteps,
|
||||
intermediate_timesteps=intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=None,
|
||||
max_timesteps=max_timesteps,
|
||||
intermediate_timesteps=intermediate_timesteps,
|
||||
)
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
latents = latents * self.scheduler.config.sigma_data
|
||||
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
|
||||
guidance = guidance * self.transformer.config.guidance_embeds_scale
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
timesteps = timesteps[:-1]
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
|
||||
latents_model_input = latents / self.scheduler.config.sigma_data
|
||||
|
||||
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
|
||||
|
||||
scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
|
||||
latent_model_input = latents_model_input * torch.sqrt(
|
||||
scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
|
||||
)
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
guidance=guidance,
|
||||
timestep=scm_timestep,
|
||||
return_dict=False,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
)[0]
|
||||
|
||||
noise_pred = (
|
||||
(1 - 2 * scm_timestep_expanded) * latent_model_input
|
||||
+ (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
|
||||
) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
|
||||
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents, denoised = self.scheduler.step(
|
||||
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
latents = denoised / self.scheduler.config.sigma_data
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
try:
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
warnings.warn(
|
||||
f"{e}. \n"
|
||||
f"Try to use VAE tiling for large images. For example: \n"
|
||||
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
|
||||
)
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return SanaPipelineOutput(images=image)
|
||||
@@ -108,31 +108,16 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
sample_mode: str = "sample",
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return (encoder_output.latents - latents_mean) * latents_std
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
@@ -412,13 +397,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [
|
||||
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
|
||||
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
||||
]
|
||||
latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
|
||||
@@ -61,7 +61,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
||||
raise ImportError(
|
||||
@@ -238,11 +238,15 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
|
||||
def update_device_map(self, device_map):
|
||||
if device_map is None:
|
||||
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
|
||||
if torch.xpu.is_available():
|
||||
current_device = f"xpu:{torch.xpu.current_device()}"
|
||||
else:
|
||||
current_device = f"cuda:{torch.cuda.current_device()}"
|
||||
device_map = {"": current_device}
|
||||
logger.info(
|
||||
"The device_map was not initialized. "
|
||||
"Setting device_map to {"
|
||||
": f`cuda:{torch.cuda.current_device()}`}. "
|
||||
": {current_device}}. "
|
||||
"If you want to use the model for inference, please set device_map ='auto' "
|
||||
)
|
||||
return device_map
|
||||
@@ -312,7 +316,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
logger.info(
|
||||
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
|
||||
)
|
||||
model.to(torch.cuda.current_device())
|
||||
if torch.xpu.is_available():
|
||||
model.to(torch.xpu.current_device())
|
||||
else:
|
||||
model.to(torch.cuda.current_device())
|
||||
|
||||
model = dequantize_and_replace(
|
||||
model, self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
@@ -343,7 +350,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
||||
raise ImportError(
|
||||
@@ -402,11 +409,15 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
|
||||
def update_device_map(self, device_map):
|
||||
if device_map is None:
|
||||
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
|
||||
if torch.xpu.is_available():
|
||||
current_device = f"xpu:{torch.xpu.current_device()}"
|
||||
else:
|
||||
current_device = f"cuda:{torch.cuda.current_device()}"
|
||||
device_map = {"": current_device}
|
||||
logger.info(
|
||||
"The device_map was not initialized. "
|
||||
"Setting device_map to {"
|
||||
": f`cuda:{torch.cuda.current_device()}`}. "
|
||||
": {current_device}}. "
|
||||
"If you want to use the model for inference, please set device_map ='auto' "
|
||||
)
|
||||
return device_map
|
||||
|
||||
@@ -68,6 +68,7 @@ else:
|
||||
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
||||
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
||||
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
|
||||
_import_structure["scheduling_scm"] = ["SCMScheduler"]
|
||||
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
|
||||
_import_structure["scheduling_tcd"] = ["TCDScheduler"]
|
||||
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
|
||||
@@ -168,13 +169,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
from .scheduling_sasolver import SASolverScheduler
|
||||
from .scheduling_scm import SCMScheduler
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_tcd import TCDScheduler
|
||||
from .scheduling_unclip import UnCLIPScheduler
|
||||
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
|
||||
from .scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
per_token_timesteps: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
per_token_timesteps (`torch.Tensor`, *optional*):
|
||||
The timesteps for each token in the sample.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
|
||||
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
if per_token_timesteps is not None:
|
||||
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
|
||||
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
sigmas = self.sigmas[:, None, None]
|
||||
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
||||
lower_sigmas = lower_mask * sigmas
|
||||
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
||||
dt = (per_token_sigmas - lower_sigmas)[..., None]
|
||||
else:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
dt = sigma_next - sigma
|
||||
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
prev_sample = sample + dt * model_output
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
if per_token_timesteps is None:
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers.scheduling_utils import SchedulerMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM
|
||||
class SCMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.Tensor
|
||||
pred_original_sample: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
`SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
||||
non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass
|
||||
documentation for the generic methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `trigflow`):
|
||||
Prediction type of the scheduler function. Currently only supports "trigflow".
|
||||
sigma_data (`float`, defaults to 0.5):
|
||||
The standard deviation of the noise added during multi-step inference.
|
||||
"""
|
||||
|
||||
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "trigflow",
|
||||
sigma_data: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Initialize the SCM scheduler.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `trigflow`):
|
||||
Prediction type of the scheduler function. Currently only supports "trigflow".
|
||||
sigma_data (`float`, defaults to 0.5):
|
||||
The standard deviation of the noise added during multi-step inference.
|
||||
"""
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
timesteps: torch.Tensor = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: float = 1.3,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
timesteps (`torch.Tensor`, *optional*):
|
||||
Custom timesteps to use for the denoising process.
|
||||
max_timesteps (`float`, defaults to 1.57080):
|
||||
The maximum timestep value used in the SCM scheduler.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
||||
"""
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
||||
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
||||
|
||||
if timesteps is not None and max_timesteps is not None:
|
||||
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
||||
|
||||
if timesteps is None and max_timesteps is None:
|
||||
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
||||
|
||||
if intermediate_timesteps is not None and num_inference_steps != 2:
|
||||
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
if timesteps is not None:
|
||||
if isinstance(timesteps, list):
|
||||
self.timesteps = torch.tensor(timesteps, device=device).float()
|
||||
elif isinstance(timesteps, torch.Tensor):
|
||||
self.timesteps = timesteps.to(device).float()
|
||||
else:
|
||||
raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
|
||||
elif intermediate_timesteps is not None:
|
||||
self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float()
|
||||
else:
|
||||
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
|
||||
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
|
||||
print(f"Set timesteps: {self.timesteps}")
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: float,
|
||||
sample: torch.FloatTensor,
|
||||
generator: torch.Generator = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SCMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
t = self.timesteps[self.step_index + 1]
|
||||
s = self.timesteps[self.step_index]
|
||||
|
||||
# 4. Different Parameterization:
|
||||
parameterization = self.config.prediction_type
|
||||
|
||||
if parameterization == "trigflow":
|
||||
pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameterization: {parameterization}")
|
||||
|
||||
# 5. Sample z ~ N(0, I), For MultiStep Inference
|
||||
# Noise is not used for one-step sampling.
|
||||
if len(self.timesteps) > 1:
|
||||
noise = (
|
||||
randn_tensor(model_output.shape, device=model_output.device, generator=generator)
|
||||
* self.config.sigma_data
|
||||
)
|
||||
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
|
||||
else:
|
||||
prev_sample = pred_x0
|
||||
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, pred_x0)
|
||||
|
||||
return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -2,6 +2,21 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FasterCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HookRegistry(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -32,6 +47,10 @@ class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
@@ -1834,6 +1853,21 @@ class SchedulerMixin(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SCMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXConditionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1517,6 +1532,21 @@ class SanaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SanaSprintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import struct
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -139,8 +139,31 @@ def _legacy_export_to_video(
|
||||
|
||||
|
||||
def export_to_video(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
|
||||
output_video_path: str = None,
|
||||
fps: int = 10,
|
||||
quality: float = 5.0,
|
||||
bitrate: Optional[int] = None,
|
||||
macro_block_size: Optional[int] = 16,
|
||||
) -> str:
|
||||
"""
|
||||
quality:
|
||||
Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to
|
||||
prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead.
|
||||
Specifying a fixed bitrate using `bitrate` disables this parameter.
|
||||
|
||||
bitrate:
|
||||
Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
|
||||
Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
|
||||
rather than specifiying a fixed bitrate with this parameter.
|
||||
|
||||
macro_block_size:
|
||||
Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
|
||||
imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs
|
||||
are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic
|
||||
feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some
|
||||
codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock.
|
||||
"""
|
||||
# TODO: Dhruv. Remove by Diffusers release 0.33.0
|
||||
# Added to prevent breaking existing code
|
||||
if not is_imageio_available():
|
||||
@@ -177,7 +200,9 @@ def export_to_video(
|
||||
elif isinstance(video_frames[0], PIL.Image.Image):
|
||||
video_frames = [np.array(frame) for frame in video_frames]
|
||||
|
||||
with imageio.get_writer(output_video_path, fps=fps) as writer:
|
||||
with imageio.get_writer(
|
||||
output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size
|
||||
) as writer:
|
||||
for frame in video_frames:
|
||||
writer.append_data(frame)
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ def prepare_encode(
|
||||
if shift_factor is not None:
|
||||
parameters["shift_factor"] = shift_factor
|
||||
if isinstance(image, torch.Tensor):
|
||||
data = safetensors.torch._tobytes(image, "tensor")
|
||||
data = safetensors.torch._tobytes(image.contiguous(), "tensor")
|
||||
parameters["shape"] = list(image.shape)
|
||||
parameters["dtype"] = str(image.dtype).split(".")[-1]
|
||||
else:
|
||||
|
||||
@@ -320,6 +320,21 @@ def require_torch_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
|
||||
without multiple hardware accelerators.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(
|
||||
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
@@ -354,6 +369,31 @@ def require_big_gpu_with_torch_cuda(test_case):
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_big_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
|
||||
Flux, SD3, Cog, etc.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
|
||||
if torch.xpu.is_available():
|
||||
device_properties = torch.xpu.get_device_properties(0)
|
||||
else:
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY,
|
||||
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
@@ -574,10 +614,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
return arry
|
||||
|
||||
|
||||
def load_pt(url: str):
|
||||
def load_pt(url: str, map_location: str):
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
arry = torch.load(BytesIO(response.content))
|
||||
arry = torch.load(BytesIO(response.content), map_location=map_location)
|
||||
return arry
|
||||
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@@ -165,7 +165,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
# Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
|
||||
if torch_device != "mps":
|
||||
generator = torch.Generator(device=generator_device).manual_seed(0)
|
||||
else:
|
||||
@@ -263,7 +263,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@@ -183,7 +183,7 @@ class AutoencoderOobleckIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
|
||||
if torch_device != "mps":
|
||||
return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
@@ -1227,7 +1227,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self):
|
||||
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
@@ -80,6 +80,7 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
@@ -144,6 +145,7 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
@@ -209,6 +211,75 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "latent_concat",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 2
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (8, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 2,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 10,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"num_refiner_layers": 1,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "token_replace",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@@ -31,6 +31,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
@@ -42,7 +43,9 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
class CogVideoXPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
@@ -31,9 +31,10 @@ from diffusers import (
|
||||
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -219,7 +220,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3ControlNetPipeline
|
||||
@@ -227,12 +228,12 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
|
||||
@@ -272,7 +273,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
@@ -304,7 +305,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
@@ -338,7 +339,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
@@ -7,17 +7,24 @@ import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FasterCacheConfig,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FluxPipeline,
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
@@ -27,7 +34,11 @@ from ..test_pipelines_common import (
|
||||
|
||||
|
||||
class FluxPipelineFastTests(
|
||||
unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
|
||||
unittest.TestCase,
|
||||
PipelineTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
@@ -38,6 +49,14 @@ class FluxPipelineFastTests(
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
faster_cache_config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 901),
|
||||
unconditional_batch_skip_range=2,
|
||||
attention_weight_callback=lambda _: 0.5,
|
||||
is_guidance_distilled=True,
|
||||
)
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = FluxTransformer2DModel(
|
||||
@@ -204,7 +223,7 @@ class FluxPipelineFastTests(
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPipeline
|
||||
@@ -292,7 +311,7 @@ class FluxPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPipeline
|
||||
@@ -304,12 +323,12 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
|
||||
@@ -8,15 +8,16 @@ import torch
|
||||
from diffusers import FluxPipeline, FluxPriorReduxPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxReduxSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPriorReduxPipeline
|
||||
@@ -27,12 +28,12 @@ class FluxReduxSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
init_image = load_image(
|
||||
@@ -59,7 +60,7 @@ class FluxReduxSlowTests(unittest.TestCase):
|
||||
self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
|
||||
)
|
||||
pipe_redux.to(torch_device)
|
||||
pipe_base.enable_model_cpu_offload()
|
||||
pipe_base.enable_model_cpu_offload(device=torch_device)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device)
|
||||
|
||||
@@ -83,6 +83,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
text_embed_dim=16,
|
||||
pooled_projection_dim=8,
|
||||
rope_axes_dim=(2, 4, 4),
|
||||
image_condition_type="latent_concat",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConf
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
FasterCacheConfig,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideoPipeline,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
@@ -30,13 +31,20 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
to_np,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -56,6 +64,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
faster_cache_config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 901),
|
||||
unconditional_batch_skip_range=2,
|
||||
attention_weight_callback=lambda _: 0.5,
|
||||
is_guidance_distilled=True,
|
||||
)
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = HunyuanVideoTransformer3DModel(
|
||||
|
||||
@@ -377,9 +377,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
|
||||
0
|
||||
]
|
||||
id_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt",
|
||||
map_location=torch_device,
|
||||
)[0]
|
||||
id_embeds = id_embeds.reshape((2, 1, 1, 512))
|
||||
inputs["ip_adapter_image_embeds"] = [id_embeds]
|
||||
inputs["ip_adapter_image"] = None
|
||||
|
||||
@@ -90,7 +90,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from transformers import AutoTokenizer, T5EncoderModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
FasterCacheConfig,
|
||||
LattePipeline,
|
||||
LatteTransformer3DModel,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
@@ -40,13 +41,20 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
to_np,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
|
||||
class LattePipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = LattePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -69,6 +77,15 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
|
||||
cross_attention_block_identifiers=["transformer_blocks"],
|
||||
)
|
||||
|
||||
faster_cache_config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
temporal_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 901),
|
||||
temporal_attention_timestep_skip_range=(-1, 901),
|
||||
unconditional_batch_skip_range=2,
|
||||
attention_weight_callback=lambda _: 0.5,
|
||||
)
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LatteTransformer3DModel(
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTXConditionPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTXConditionPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = LTXVideoTransformer3DModel(
|
||||
in_channels=8,
|
||||
out_channels=8,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=32,
|
||||
num_layers=1,
|
||||
caption_channels=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0, use_conditions=False):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
|
||||
if use_conditions:
|
||||
conditions = LTXVideoCondition(
|
||||
image=image,
|
||||
)
|
||||
else:
|
||||
conditions = None
|
||||
|
||||
inputs = {
|
||||
"conditions": conditions,
|
||||
"image": None if use_conditions else image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
# 8 * k + 1 is the recommendation
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs2 = self.get_dummy_inputs(device, use_conditions=True)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
video2 = pipe(**inputs2).frames
|
||||
generated_video2 = video2[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
|
||||
max_diff = np.abs(generated_video - generated_video2).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
@@ -33,13 +33,13 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
|
||||
pipeline_class = MochiPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -59,13 +59,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 2):
|
||||
torch.manual_seed(0)
|
||||
transformer = MochiTransformer3DModel(
|
||||
patch_size=2,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
num_layers=2,
|
||||
num_layers=num_layers,
|
||||
pooled_projection_dim=16,
|
||||
in_channels=12,
|
||||
out_channels=None,
|
||||
|
||||
@@ -99,7 +99,7 @@ class KolorsPAGPipelineFastTests(
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
|
||||
@@ -262,7 +262,7 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload(device=torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = SanaSprintPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SanaTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=4,
|
||||
num_cross_attention_heads=2,
|
||||
cross_attention_head_dim=4,
|
||||
cross_attention_dim=8,
|
||||
caption_channels=8,
|
||||
sample_size=32,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
guidance_embeds=True,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderDC(
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
attention_head_dim=2,
|
||||
encoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
decoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
encoder_block_out_channels=(8, 8),
|
||||
decoder_block_out_channels=(8, 8),
|
||||
encoder_qkv_multiscales=((), (5,)),
|
||||
decoder_qkv_multiscales=((), (5,)),
|
||||
encoder_layers_per_block=(1, 1),
|
||||
decoder_layers_per_block=[1, 1],
|
||||
downsample_block_type="conv",
|
||||
upsample_block_type="interpolate",
|
||||
decoder_norm_types="rms_norm",
|
||||
decoder_act_fns="silu",
|
||||
scaling_factor=0.41407,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = SCMScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = Gemma2Config(
|
||||
head_dim=16,
|
||||
hidden_size=8,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=64,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma2",
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=1,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=8,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
text_encoder = Gemma2Model(text_encoder_config)
|
||||
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
"complex_human_instruction": None,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.randn(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
# TODO(aryan): Create a dummy gemma model with smol vocab size
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
def test_float16_inference(self):
|
||||
# Requires higher tolerance as model seems very sensitive to dtype
|
||||
super().test_float16_inference(expected_max_diff=0.08)
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_accelerate_version_greater,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
@@ -1409,7 +1409,7 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
|
||||
# (sayakpaul): This test suite was run in the DGX with two GPUs (1, 2).
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
@require_accelerate_version_greater("0.27.0")
|
||||
class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -1497,7 +1497,7 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `to()` can be used and the pipeline can be called.
|
||||
pipe = sd_pipe_with_device_map.to("cuda")
|
||||
pipe = sd_pipe_with_device_map.to(torch_device)
|
||||
_ = pipe("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_model_cpu_offload(self):
|
||||
@@ -1509,7 +1509,7 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload()
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload(device=torch_device)
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_sequential_cpu_offload(self):
|
||||
@@ -1521,5 +1521,5 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload()
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload(device=torch_device)
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
@@ -10,7 +10,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transfo
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -232,7 +232,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -166,7 +166,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Img2ImgPipeline
|
||||
@@ -202,11 +202,10 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
def test_sd3_img2img_inference(self):
|
||||
torch.manual_seed(0)
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
|
||||
@@ -23,13 +23,16 @@ from diffusers import (
|
||||
ConsistencyDecoderVAE,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
FasterCacheConfig,
|
||||
KolorsPipeline,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
apply_faster_cache,
|
||||
)
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
@@ -45,6 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
require_accelerate_version_greater,
|
||||
require_accelerator,
|
||||
require_hf_hub_version_greater,
|
||||
@@ -1108,13 +1112,13 @@ class PipelineTesterMixin:
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_save_load_local(self, expected_max_difference=5e-4):
|
||||
components = self.get_dummy_components()
|
||||
@@ -1423,7 +1427,6 @@ class PipelineTesterMixin:
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
if not hasattr(self.pipeline_class, "_optional_components"):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
@@ -1438,6 +1441,7 @@ class PipelineTesterMixin:
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -1456,6 +1460,7 @@ class PipelineTesterMixin:
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
@@ -1550,12 +1555,14 @@ class PipelineTesterMixin:
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_without_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_sequential_cpu_offload(device=torch_device)
|
||||
assert pipe._execution_device.type == torch_device
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
@@ -1613,12 +1620,14 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_without_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
assert pipe._execution_device.type == torch_device
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
@@ -2545,6 +2554,167 @@ class PyramidAttentionBroadcastTesterMixin:
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
|
||||
class FasterCacheTesterMixin:
|
||||
faster_cache_config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 901),
|
||||
unconditional_batch_skip_range=2,
|
||||
attention_weight_callback=lambda _: 0.5,
|
||||
)
|
||||
|
||||
def test_faster_cache_basic_warning_or_errors_raised(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
logger = logging.get_logger("diffusers.hooks.faster_cache")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Check if warning is raise when no attention_weight_callback is provided
|
||||
pipe = self.pipeline_class(**components)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
|
||||
apply_faster_cache(pipe.transformer, config)
|
||||
self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)
|
||||
|
||||
# Check if error raised when unsupported tensor format used
|
||||
pipe = self.pipeline_class(**components)
|
||||
with self.assertRaises(ValueError):
|
||||
config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
|
||||
apply_faster_cache(pipe.transformer, config)
|
||||
|
||||
def test_faster_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# Run inference without FasterCache
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FasterCache enabled
|
||||
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.faster_cache_config)
|
||||
output = run_forward(pipe).flatten().flatten()
|
||||
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FasterCache disabled
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
|
||||
), "FasterCache outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
def test_faster_cache_state(self):
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 0
|
||||
num_single_layers = 0
|
||||
dummy_component_kwargs = {}
|
||||
dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
|
||||
if "num_layers" in dummy_component_parameters:
|
||||
num_layers = 2
|
||||
dummy_component_kwargs["num_layers"] = num_layers
|
||||
if "num_single_layers" in dummy_component_parameters:
|
||||
num_single_layers = 2
|
||||
dummy_component_kwargs["num_single_layers"] = num_single_layers
|
||||
|
||||
components = self.get_dummy_components(**dummy_component_kwargs)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
pipe.transformer.enable_cache(self.faster_cache_config)
|
||||
|
||||
expected_hooks = 0
|
||||
if self.faster_cache_config.spatial_attention_block_skip_range is not None:
|
||||
expected_hooks += num_layers + num_single_layers
|
||||
if self.faster_cache_config.temporal_attention_block_skip_range is not None:
|
||||
expected_hooks += num_layers + num_single_layers
|
||||
|
||||
# Check if faster_cache denoiser hook is attached
|
||||
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
|
||||
self.assertTrue(
|
||||
hasattr(denoiser, "_diffusers_hook")
|
||||
and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook),
|
||||
"Hook should be of type FasterCacheDenoiserHook.",
|
||||
)
|
||||
|
||||
# Check if all blocks have faster_cache block hook attached
|
||||
count = 0
|
||||
for name, module in denoiser.named_modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
if name == "":
|
||||
# Skip the root denoiser module
|
||||
continue
|
||||
count += 1
|
||||
self.assertTrue(
|
||||
isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook),
|
||||
"Hook should be of type FasterCacheBlockHook.",
|
||||
)
|
||||
self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.")
|
||||
|
||||
# Perform inference to ensure that states are updated correctly
|
||||
def faster_cache_state_check_callback(pipe, i, t, kwargs):
|
||||
for name, module in denoiser.named_modules():
|
||||
if not hasattr(module, "_diffusers_hook"):
|
||||
continue
|
||||
if name == "":
|
||||
# Root denoiser module
|
||||
state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
|
||||
if not self.faster_cache_config.is_guidance_distilled:
|
||||
self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.")
|
||||
self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.")
|
||||
else:
|
||||
# Internal blocks
|
||||
state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
|
||||
self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.")
|
||||
self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.")
|
||||
return {}
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
inputs["callback_on_step_end"] = faster_cache_state_check_callback
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
# After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
|
||||
for name, module in denoiser.named_modules():
|
||||
if not hasattr(module, "_diffusers_hook"):
|
||||
continue
|
||||
|
||||
if name == "":
|
||||
# Root denoiser module
|
||||
state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
|
||||
self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
|
||||
self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.")
|
||||
self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.")
|
||||
else:
|
||||
# Internal blocks
|
||||
state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
|
||||
self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
|
||||
self.assertTrue(state.batch_size is None, "Batch size should be reset to None.")
|
||||
self.assertTrue(state.cache is None, "Cache should be reset to None.")
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
@@ -303,6 +303,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
decoder_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
@@ -407,6 +407,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
pipe.super_res_first.config.sample_size,
|
||||
pipe.super_res_first.config.sample_size,
|
||||
)
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
super_res_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DMo
|
||||
from diffusers.utils import is_accelerate_version, logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
@@ -35,7 +36,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_bitsandbytes_version_greater,
|
||||
require_peft_backend,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -66,7 +67,7 @@ if is_bitsandbytes_available():
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class Base4bitTests(unittest.TestCase):
|
||||
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
|
||||
@@ -84,13 +85,16 @@ class Base4bitTests(unittest.TestCase):
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
|
||||
torch_device,
|
||||
)
|
||||
pooled_prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
|
||||
torch_device,
|
||||
)
|
||||
latent_model_input = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
|
||||
torch_device,
|
||||
)
|
||||
|
||||
input_dict_for_transformer = {
|
||||
@@ -106,7 +110,7 @@ class Base4bitTests(unittest.TestCase):
|
||||
class BnB4BitBasicTests(Base4bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Models
|
||||
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -128,7 +132,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
del self.model_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
@@ -224,7 +228,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.assertTrue(module.weight.dtype == torch.uint8)
|
||||
|
||||
# test if inference works.
|
||||
with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
|
||||
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
|
||||
input_dict_for_transformer = self.get_dummy_inputs()
|
||||
model_inputs = {
|
||||
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
|
||||
@@ -266,9 +270,9 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
# Move back to CUDA device
|
||||
for device in [0, "cuda", "cuda:0", "call()"]:
|
||||
for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
|
||||
if device == "call()":
|
||||
self.model_4bit.cuda(0)
|
||||
self.model_4bit.to(f"{torch_device}:0")
|
||||
else:
|
||||
self.model_4bit.to(device)
|
||||
self.assertEqual(self.model_4bit.device, torch.device(0))
|
||||
@@ -286,7 +290,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device` and `dtype`
|
||||
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
|
||||
self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
@@ -297,7 +301,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.model_4bit.half()
|
||||
|
||||
# This should work
|
||||
self.model_4bit.to("cuda")
|
||||
self.model_4bit.to(torch_device)
|
||||
|
||||
# Test if we did not break anything
|
||||
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
|
||||
@@ -321,7 +325,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
_ = self.model_fp16.float()
|
||||
|
||||
# Check that this does not throw an error
|
||||
_ = self.model_fp16.cuda()
|
||||
_ = self.model_fp16.to(torch_device)
|
||||
|
||||
def test_bnb_4bit_wrong_config(self):
|
||||
r"""
|
||||
@@ -398,7 +402,7 @@ class BnB4BitTrainingTests(Base4bitTests):
|
||||
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
|
||||
|
||||
# Step 4: Check if the gradient is not None
|
||||
with torch.amp.autocast("cuda", dtype=torch.float16):
|
||||
with torch.amp.autocast(torch_device, dtype=torch.float16):
|
||||
out = self.model_4bit(**model_inputs)[0]
|
||||
out.norm().backward()
|
||||
|
||||
@@ -412,7 +416,7 @@ class BnB4BitTrainingTests(Base4bitTests):
|
||||
class SlowBnb4BitTests(Base4bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -431,7 +435,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
del self.pipeline_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
output = self.pipeline_4bit(
|
||||
@@ -501,7 +505,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
|
||||
strict=True,
|
||||
)
|
||||
def test_pipeline_cuda_placement_works_with_nf4(self):
|
||||
def test_pipeline_device_placement_works_with_nf4(self):
|
||||
transformer_nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@@ -532,7 +536,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
transformer=transformer_4bit,
|
||||
text_encoder_3=text_encoder_3_4bit,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
).to(torch_device)
|
||||
|
||||
# Check if inference works.
|
||||
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
|
||||
@@ -696,7 +700,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
|
||||
class BaseBnb4BitSerializationTests(Base4bitTests):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||
r"""
|
||||
|
||||
@@ -31,6 +31,7 @@ from diffusers import (
|
||||
from diffusers.utils import is_accelerate_version
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
@@ -40,7 +41,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_bitsandbytes_version_greater,
|
||||
require_peft_version_greater,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -71,7 +72,7 @@ if is_bitsandbytes_available():
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class Base8bitTests(unittest.TestCase):
|
||||
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
|
||||
@@ -89,13 +90,16 @@ class Base8bitTests(unittest.TestCase):
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
|
||||
map_location="cpu",
|
||||
)
|
||||
pooled_prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
|
||||
map_location="cpu",
|
||||
)
|
||||
latent_model_input = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
|
||||
map_location="cpu",
|
||||
)
|
||||
|
||||
input_dict_for_transformer = {
|
||||
@@ -111,7 +115,7 @@ class Base8bitTests(unittest.TestCase):
|
||||
class BnB8bitBasicTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Models
|
||||
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -129,7 +133,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
@@ -279,7 +283,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_8bit.to(torch.device("cuda:0"))
|
||||
self.model_8bit.to(torch.device(f"{torch_device}:0"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
@@ -317,7 +321,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
class Bnb8bitDeviceTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
self.model_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
@@ -331,7 +335,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_buffers_device_assignment(self):
|
||||
for buffer_name, buffer in self.model_8bit.named_buffers():
|
||||
@@ -345,7 +349,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
|
||||
class BnB8bitTrainingTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
self.model_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -389,7 +393,7 @@ class BnB8bitTrainingTests(Base8bitTests):
|
||||
class SlowBnb8bitTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -404,7 +408,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
del self.pipeline_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
output = self.pipeline_8bit(
|
||||
@@ -616,7 +620,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
|
||||
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
@@ -633,7 +637,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
del self.pipeline_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
# keep the resolution and max tokens to a lower number for faster execution.
|
||||
@@ -680,7 +684,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
@@ -693,7 +697,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
del self.model_0
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_serialization(self):
|
||||
r"""
|
||||
|
||||
@@ -57,7 +57,7 @@ class GGUFSingleFileTesterMixin:
|
||||
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
|
||||
assert module.weight.dtype == torch.uint8
|
||||
if module.bias is not None:
|
||||
assert module.bias.dtype == torch.float32
|
||||
assert module.bias.dtype == self.torch_dtype
|
||||
|
||||
def test_gguf_memory_usage(self):
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
|
||||
@@ -64,7 +64,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
|
||||
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
|
||||
else:
|
||||
@@ -96,7 +96,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
|
||||
else:
|
||||
@@ -127,7 +127,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
|
||||
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
|
||||
else:
|
||||
@@ -159,7 +159,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user