Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 278b3b8825 | |||
| d7f369cbab |
@@ -265,7 +265,7 @@ jobs:
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -505,7 +505,7 @@ jobs:
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# env:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
@@ -561,7 +561,7 @@ jobs:
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# env:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
|
||||
- name: Run Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
@@ -235,7 +235,7 @@ jobs:
|
||||
|
||||
- name: Run ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
@@ -283,7 +283,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
@@ -326,7 +326,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
@@ -372,7 +372,7 @@ jobs:
|
||||
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install timm
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -186,7 +186,7 @@ jobs:
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -241,7 +241,7 @@ jobs:
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
@@ -289,7 +289,7 @@ jobs:
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
@@ -337,7 +337,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
@@ -380,7 +380,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
@@ -426,7 +426,7 @@ jobs:
|
||||
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install timm
|
||||
|
||||
@@ -598,8 +598,6 @@
|
||||
title: Attention Processor
|
||||
- local: api/activations
|
||||
title: Custom activation functions
|
||||
- local: api/cache
|
||||
title: Caching methods
|
||||
- local: api/normalization
|
||||
title: Custom normalization layers
|
||||
- local: api/utilities
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
@@ -309,53 +309,6 @@ image.save("output.png")
|
||||
|
||||
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
|
||||
|
||||
</Tip>
|
||||
|
||||
An IP-Adapter lets you prompt Flux with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))
|
||||
|
||||
pipe.load_ip_adapter(
|
||||
"XLabs-AI/flux-ip-adapter",
|
||||
weight_name="ip_adapter.safetensors",
|
||||
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
|
||||
)
|
||||
pipe.set_ip_adapter_scale(1.0)
|
||||
|
||||
image = pipe(
|
||||
width=1024,
|
||||
height=1024,
|
||||
prompt="wearing sunglasses",
|
||||
negative_prompt="",
|
||||
true_cfg=4.0,
|
||||
generator=torch.Generator().manual_seed(4444),
|
||||
ip_adapter_image=image,
|
||||
).images[0]
|
||||
|
||||
image.save('flux_ip_adapter_output.jpg')
|
||||
```
|
||||
|
||||
<div class="justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_output.jpg"/>
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "wearing sunglasses"</figcaption>
|
||||
</div>
|
||||
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -41,7 +41,3 @@ Utility and helper functions for working with 🤗 Diffusers.
|
||||
## randn_tensor
|
||||
|
||||
[[autodoc]] utils.torch_utils.randn_tensor
|
||||
|
||||
## apply_layerwise_casting
|
||||
|
||||
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
|
||||
|
||||
@@ -23,60 +23,32 @@ You should install 🤗 Diffusers in a [virtual environment](https://docs.python
|
||||
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
|
||||
|
||||
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
Start by creating a virtual environment in your project directory:
|
||||
|
||||
```bash
|
||||
uv venv my-env
|
||||
source my-env/bin/activate
|
||||
python -m venv .env
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Python">
|
||||
Activate the virtual environment:
|
||||
|
||||
```bash
|
||||
python -m venv my-env
|
||||
source my-env/bin/activate
|
||||
source .env/bin/activate
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
|
||||
```bash
|
||||
pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
|
||||
Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
@@ -158,43 +158,6 @@ In order to properly offload models after they're called, it is required to run
|
||||
|
||||
</Tip>
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
|
||||
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "THUDM/CogVideoX-5b"
|
||||
|
||||
# Load the model in bfloat16 and enable layerwise casting
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
|
||||
# Load the pipeline
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
|
||||
|
||||
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
|
||||
|
||||
## Channels-last memory format
|
||||
|
||||
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
|
||||
|
||||
@@ -29,7 +29,7 @@ However, it is hard to decide when to reuse the cache to ensure quality generate
|
||||
This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
|
||||
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
|
||||
<img src="https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
|
||||
<figcaption>How AdaCache works, First Block Cache is a variant of it</figcaption>
|
||||
</figure>
|
||||
|
||||
|
||||
Regular → Executable
+2
-90
@@ -77,7 +77,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
|
||||
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -4586,8 +4585,8 @@ image = pipe(
|
||||
```
|
||||
|
||||
|  |  |  |
|
||||
| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
|
||||
| Gradient | Input | Output |
|
||||
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
|
||||
| Gradient | Input | Output |
|
||||
|
||||
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
|
||||
|
||||
@@ -4635,93 +4634,6 @@ make_image_grid(image, rows=1, cols=len(image))
|
||||
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
|
||||
```
|
||||
|
||||
### Stable Diffusion XL Attentive Eraser Pipeline
|
||||
<img src="https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png" width="600" />
|
||||
|
||||
**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
|
||||
|
||||
#### Key features
|
||||
|
||||
- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
|
||||
- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
|
||||
- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
|
||||
|
||||
#### Usage example
|
||||
To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import to_tensor, gaussian_blur
|
||||
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
|
||||
scheduler=scheduler,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
|
||||
def preprocess_image(image_path, device):
|
||||
image = to_tensor((load_image(image_path)))
|
||||
image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
|
||||
if image.shape[1] != 3:
|
||||
image = image.expand(-1, 3, -1, -1)
|
||||
image = F.interpolate(image, (1024, 1024))
|
||||
image = image.to(dtype).to(device)
|
||||
return image
|
||||
|
||||
def preprocess_mask(mask_path, device):
|
||||
mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
|
||||
mask = mask.unsqueeze_(0).float() # 0 or 1
|
||||
mask = F.interpolate(mask, (1024, 1024))
|
||||
mask = gaussian_blur(mask, kernel_size=(77, 77))
|
||||
mask[mask < 0.1] = 0
|
||||
mask[mask >= 0.1] = 1
|
||||
mask = mask.to(dtype).to(device)
|
||||
return mask
|
||||
|
||||
prompt = "" # Set prompt to null
|
||||
seed=123
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
|
||||
mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
|
||||
source_image = preprocess_image(source_image_path, device)
|
||||
mask = preprocess_mask(mask_path, device)
|
||||
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
image=source_image,
|
||||
mask_image=mask,
|
||||
height=1024,
|
||||
width=1024,
|
||||
AAS=True, # enable AAS
|
||||
strength=0.8, # inpainting strength
|
||||
rm_guidance_scale=9, # removal guidance scale
|
||||
ss_steps = 9, # similarity suppression steps
|
||||
ss_scale = 0.3, # similarity suppression scale
|
||||
AAS_start_step=0, # AAS start step
|
||||
AAS_start_layer=34, # AAS start layer
|
||||
AAS_end_layer=70, # AAS end layer
|
||||
num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
|
||||
generator=generator,
|
||||
guidance_scale=1,
|
||||
).images[0]
|
||||
image.save('./removed_img.png')
|
||||
print("Object removal completed")
|
||||
```
|
||||
|
||||
| Source Image | Mask | Output |
|
||||
| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
|  |  |  |
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
@@ -80,6 +80,7 @@ from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -868,7 +869,23 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1013,6 +1030,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1021,7 +1049,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -1159,7 +1192,23 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1233,6 +1282,10 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
]
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1312,8 +1365,19 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
# Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -1321,6 +1385,7 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -2659,6 +2724,10 @@ class MatryoshkaUNet2DConditionModel(
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -193,8 +193,7 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
|
||||
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
@@ -224,11 +223,6 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
return ref_image_latents
|
||||
|
||||
def prepare_ref_image(
|
||||
|
||||
@@ -139,8 +139,7 @@ def retrieve_timesteps(
|
||||
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
@@ -170,11 +169,6 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
return ref_image_latents
|
||||
|
||||
def prepare_ref_image(
|
||||
|
||||
@@ -742,29 +742,3 @@ accelerate launch train_dreambooth.py \
|
||||
## Stable Diffusion XL
|
||||
|
||||
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
|
||||
|
||||
## Dataset
|
||||
|
||||
We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
|
||||
|
||||
The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
|
||||
|
||||
We need to create a file `metadata.jsonl` in the directory with our images:
|
||||
|
||||
```
|
||||
{"file_name": "01.jpg", "prompt": "prompt 01"}
|
||||
{"file_name": "02.jpg", "prompt": "prompt 02"}
|
||||
```
|
||||
|
||||
If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
|
||||
|
||||
```sh
|
||||
python convert_to_imagefolder.py --path my_dataset/
|
||||
```
|
||||
|
||||
We use `--dataset_name` and `--caption_column` with training scripts.
|
||||
|
||||
```
|
||||
--dataset_name=my_dataset/
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to folder with image-text pairs.",
|
||||
)
|
||||
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = pathlib.Path(args.path)
|
||||
if not path.exists():
|
||||
raise RuntimeError(f"`--path` '{args.path}' does not exist.")
|
||||
|
||||
all_files = list(path.glob("*"))
|
||||
captions = list(path.glob("*.txt"))
|
||||
images = set(all_files) - set(captions)
|
||||
images = {image.stem: image for image in images}
|
||||
caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}
|
||||
|
||||
metadata = path.joinpath("metadata.jsonl")
|
||||
|
||||
with metadata.open("w", encoding="utf-8") as f:
|
||||
for caption, image in caption_image.items():
|
||||
caption_text = caption.read_text(encoding="utf-8")
|
||||
json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
|
||||
f.write("\n")
|
||||
@@ -63,7 +63,6 @@ from diffusers.utils import (
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -75,9 +74,6 @@ check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_torch_npu_available():
|
||||
torch.npu.config.allow_internal_format = False
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
@@ -605,7 +601,6 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -929,7 +924,8 @@ def main(args):
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -992,13 +988,6 @@ def main(args):
|
||||
# because Gemma2 is particularly suited for bfloat16.
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
if args.enable_npu_flash_attention:
|
||||
if is_torch_npu_available():
|
||||
logger.info("npu flash attention enabled.")
|
||||
transformer.enable_npu_flash_attention()
|
||||
else:
|
||||
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# AutoencoderKL training example
|
||||
|
||||
## Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
## Training on CIFAR10
|
||||
|
||||
Please replace the validation image with your own image.
|
||||
|
||||
```bash
|
||||
accelerate launch train_autoencoderkl.py \
|
||||
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
|
||||
--dataset_name=cifar10 \
|
||||
--image_column=img \
|
||||
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
|
||||
--num_train_epochs 100 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--learning_rate 4.5e-6 \
|
||||
--lr_scheduler cosine \
|
||||
--report_to wandb \
|
||||
```
|
||||
|
||||
## Training on ImageNet
|
||||
|
||||
```bash
|
||||
accelerate launch train_autoencoderkl.py \
|
||||
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
|
||||
--num_train_epochs 100 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--learning_rate 4.5e-6 \
|
||||
--lr_scheduler cosine \
|
||||
--report_to wandb \
|
||||
--mixed_precision bf16 \
|
||||
--train_data_dir /path/to/ImageNet/train \
|
||||
--validation_image ./image.png \
|
||||
--decoder_only
|
||||
```
|
||||
@@ -1,15 +0,0 @@
|
||||
accelerate>=0.16.0
|
||||
bitsandbytes
|
||||
datasets
|
||||
huggingface_hub
|
||||
lpips
|
||||
numpy
|
||||
packaging
|
||||
Pillow
|
||||
taming_transformers
|
||||
torch
|
||||
torchvision
|
||||
tqdm
|
||||
transformers
|
||||
wandb
|
||||
xformers
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@ from diffusers.models import PixArtTransformer2DModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils.torch_utils import is_torch_version
|
||||
|
||||
|
||||
class PixArtControlNetAdapterBlock(nn.Module):
|
||||
@@ -150,6 +151,10 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
self.transformer = transformer
|
||||
self.controlnet = controlnet
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -215,8 +220,18 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
|
||||
exit(1)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -224,6 +239,7 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
None,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
# the control nets are only used for the blocks 1 to self.blocks_num
|
||||
|
||||
@@ -515,6 +515,10 @@ def main():
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Freeze the unet parameters before adding adapters
|
||||
for param in unet.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
"""
|
||||
This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
|
||||
|
||||
To make it work for other models:
|
||||
|
||||
* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
|
||||
for example. (TODO: more reason to add `AutoModel`).
|
||||
* Spply path to the base checkpoint via `base_ckpt_path`.
|
||||
* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
|
||||
* Change the `--rank` as needed.
|
||||
|
||||
Example usage:
|
||||
|
||||
```bash
|
||||
python extract_lora_from_model.py \
|
||||
--base_ckpt_path=THUDM/CogVideoX-5b \
|
||||
--finetune_ckpt_path=finetrainers/cakeify-v0 \
|
||||
--lora_out_path=cakeify_lora.safetensors
|
||||
```
|
||||
|
||||
Script is adapted from
|
||||
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
|
||||
RANK = 64
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
|
||||
# Comes from
|
||||
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
|
||||
def extract_lora(diff, rank):
|
||||
# Important to use CUDA otherwise, very slow!
|
||||
if torch.cuda.is_available():
|
||||
diff = diff.to("cuda")
|
||||
|
||||
is_conv2d = len(diff.shape) == 4
|
||||
kernel_size = None if not is_conv2d else diff.size()[2:4]
|
||||
is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = diff.size()[0:2]
|
||||
rank = min(rank, in_dim, out_dim)
|
||||
|
||||
if is_conv2d:
|
||||
if is_conv2d_3x3:
|
||||
diff = diff.flatten(start_dim=1)
|
||||
else:
|
||||
diff = diff.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(diff.float())
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
if is_conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
return (U.cpu(), Vh.cpu())
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--base_ckpt_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_subfolder",
|
||||
default="transformer",
|
||||
type=str,
|
||||
help="subfolder to load the base checkpoint from if any.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--finetune_ckpt_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--finetune_subfolder",
|
||||
default=None,
|
||||
type=str,
|
||||
help="subfolder to load the fulle finetuned checkpoint from if any.",
|
||||
)
|
||||
parser.add_argument("--rank", default=64, type=int)
|
||||
parser.add_argument("--lora_out_path", default=None, type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.lora_out_path.endswith(".safetensors"):
|
||||
raise ValueError("`lora_out_path` must end with `.safetensors`.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
|
||||
args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
state_dict_ft = model_finetuned.state_dict()
|
||||
|
||||
# Change the `subfolder` as needed.
|
||||
base_model = CogVideoXTransformer3DModel.from_pretrained(
|
||||
args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
state_dict = base_model.state_dict()
|
||||
output_dict = {}
|
||||
|
||||
for k in tqdm(state_dict, desc="Extracting LoRA..."):
|
||||
original_param = state_dict[k]
|
||||
finetuned_param = state_dict_ft[k]
|
||||
if len(original_param.shape) >= 2:
|
||||
diff = finetuned_param.float() - original_param.float()
|
||||
out = extract_lora(diff, RANK)
|
||||
name = k
|
||||
|
||||
if name.endswith(".weight"):
|
||||
name = name[: -len(".weight")]
|
||||
down_key = "{}.lora_A.weight".format(name)
|
||||
up_key = "{}.lora_B.weight".format(name)
|
||||
|
||||
output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
|
||||
output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)
|
||||
|
||||
prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
|
||||
output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
|
||||
save_file(output_dict, args.lora_out_path)
|
||||
print(f"LoRA saved and it contains {len(output_dict)} keys.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -28,7 +28,6 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
@@ -76,13 +75,6 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -98,7 +90,6 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CacheMixin",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
@@ -597,7 +588,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -612,7 +602,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CacheMixin,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
@@ -1,236 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
self.fn_ref: "HookFunctionReference" = None
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is initialized.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is deinitalized.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
||||
r"""
|
||||
Hook that is executed just after the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass been executed just before this event.
|
||||
output (`Any`):
|
||||
The output of the module.
|
||||
Returns:
|
||||
`Any`: The processed `output`.
|
||||
"""
|
||||
return output
|
||||
|
||||
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when the hook is detached from a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module detached from this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def reset_state(self, module: torch.nn.Module):
|
||||
if self._is_stateful:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
"""A container class that maintains mutable references to forward pass functions in a hook chain.
|
||||
|
||||
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
|
||||
entire forward pass structure.
|
||||
|
||||
Attributes:
|
||||
pre_forward: A callable that processes inputs before the main forward pass.
|
||||
post_forward: A callable that processes outputs after the main forward pass.
|
||||
forward: The current forward function in the hook chain.
|
||||
original_forward: The original forward function, stored when a hook provides a custom new_forward.
|
||||
|
||||
The class enables hook removal by allowing updates to the forward chain through reference modification rather
|
||||
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
|
||||
be updated, preserving the execution order of the remaining hooks.
|
||||
"""
|
||||
self.pre_forward = None
|
||||
self.post_forward = None
|
||||
self.forward = None
|
||||
self.original_forward = None
|
||||
|
||||
|
||||
class HookRegistry:
|
||||
def __init__(self, module_ref: torch.nn.Module) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hooks: Dict[str, ModelHook] = {}
|
||||
|
||||
self._module_ref = module_ref
|
||||
self._hook_order = []
|
||||
self._fn_refs = []
|
||||
|
||||
def register_hook(self, hook: ModelHook, name: str) -> None:
|
||||
if name in self.hooks.keys():
|
||||
raise ValueError(
|
||||
f"Hook with name {name} already exists in the registry. Please use a different name or "
|
||||
f"first remove the existing hook and then add a new one."
|
||||
)
|
||||
|
||||
self._module_ref = hook.initialize_hook(self._module_ref)
|
||||
|
||||
def create_new_forward(function_reference: HookFunctionReference):
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
|
||||
output = function_reference.forward(*args, **kwargs)
|
||||
return function_reference.post_forward(module, output)
|
||||
|
||||
return new_forward
|
||||
|
||||
forward = self._module_ref.forward
|
||||
|
||||
fn_ref = HookFunctionReference()
|
||||
fn_ref.pre_forward = hook.pre_forward
|
||||
fn_ref.post_forward = hook.post_forward
|
||||
fn_ref.forward = forward
|
||||
|
||||
if hasattr(hook, "new_forward"):
|
||||
fn_ref.original_forward = forward
|
||||
fn_ref.forward = functools.update_wrapper(
|
||||
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
|
||||
)
|
||||
|
||||
rewritten_forward = create_new_forward(fn_ref)
|
||||
self._module_ref.forward = functools.update_wrapper(
|
||||
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
|
||||
)
|
||||
|
||||
hook.fn_ref = fn_ref
|
||||
self.hooks[name] = hook
|
||||
self._hook_order.append(name)
|
||||
self._fn_refs.append(fn_ref)
|
||||
|
||||
def get_hook(self, name: str) -> Optional[ModelHook]:
|
||||
return self.hooks.get(name, None)
|
||||
|
||||
def remove_hook(self, name: str, recurse: bool = True) -> None:
|
||||
if name in self.hooks.keys():
|
||||
num_hooks = len(self._hook_order)
|
||||
hook = self.hooks[name]
|
||||
index = self._hook_order.index(name)
|
||||
fn_ref = self._fn_refs[index]
|
||||
|
||||
old_forward = fn_ref.forward
|
||||
if fn_ref.original_forward is not None:
|
||||
old_forward = fn_ref.original_forward
|
||||
|
||||
if index == num_hooks - 1:
|
||||
self._module_ref.forward = old_forward
|
||||
else:
|
||||
self._fn_refs[index + 1].forward = old_forward
|
||||
|
||||
self._module_ref = hook.deinitalize_hook(self._module_ref)
|
||||
del self.hooks[name]
|
||||
self._hook_order.pop(index)
|
||||
self._fn_refs.pop(index)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.remove_hook(name, recurse=False)
|
||||
|
||||
def reset_stateful_hooks(self, recurse: bool = True) -> None:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook.reset_state(self._module_ref)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
||||
|
||||
@classmethod
|
||||
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
|
||||
if not hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook = cls(module)
|
||||
return module._diffusers_hook
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
|
||||
hook_repr = self.hooks[hook_name].__repr__()
|
||||
else:
|
||||
hook_repr = self.hooks[hook_name].__class__.__name__
|
||||
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
|
||||
if i < len(self._hook_order) - 1:
|
||||
registry_repr += "\n"
|
||||
return f"HookRegistry(\n{registry_repr}\n)"
|
||||
@@ -1,191 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# fmt: off
|
||||
SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
torch.nn.Linear,
|
||||
)
|
||||
|
||||
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class LayerwiseCastingHook(ModelHook):
|
||||
r"""
|
||||
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
|
||||
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
|
||||
footprint.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
||||
self.storage_dtype = storage_dtype
|
||||
self.compute_dtype = compute_dtype
|
||||
self.non_blocking = non_blocking
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module):
|
||||
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
||||
return module
|
||||
|
||||
def deinitalize_hook(self, module: torch.nn.Module):
|
||||
raise NotImplementedError(
|
||||
"LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
|
||||
"have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
|
||||
"will lead to precision loss, which might have an impact on the model's generation quality. The model should "
|
||||
"be re-initialized and loaded in the original dtype."
|
||||
)
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output):
|
||||
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
||||
return output
|
||||
|
||||
|
||||
def apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
|
||||
nn.Module using diffusers layers or pytorch primitives.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
>>> apply_layerwise_casting(
|
||||
... transformer,
|
||||
... storage_dtype=torch.float8_e4m3fn,
|
||||
... compute_dtype=torch.bfloat16,
|
||||
... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
|
||||
... non_blocking=True,
|
||||
... )
|
||||
```
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
|
||||
precision dtype for storage.
|
||||
storage_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to before/after the forward pass for storage.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to during the forward pass for computation.
|
||||
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
|
||||
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
|
||||
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
|
||||
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
|
||||
instead of its internal submodules.
|
||||
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
|
||||
A list of module classes to skip during the layerwise casting process.
|
||||
non_blocking (`bool`, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
"""
|
||||
if skip_modules_pattern == "auto":
|
||||
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
|
||||
|
||||
if skip_modules_classes is None and skip_modules_pattern is None:
|
||||
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
|
||||
return
|
||||
|
||||
_apply_layerwise_casting(
|
||||
module,
|
||||
storage_dtype,
|
||||
compute_dtype,
|
||||
skip_modules_pattern,
|
||||
skip_modules_classes,
|
||||
non_blocking,
|
||||
)
|
||||
|
||||
|
||||
def _apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
_prefix: str = "",
|
||||
) -> None:
|
||||
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
|
||||
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
|
||||
)
|
||||
if should_skip:
|
||||
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
|
||||
return
|
||||
|
||||
if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
|
||||
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
|
||||
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
|
||||
return
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
layer_name = f"{_prefix}.{name}" if _prefix else name
|
||||
_apply_layerwise_casting(
|
||||
submodule,
|
||||
storage_dtype,
|
||||
compute_dtype,
|
||||
skip_modules_pattern,
|
||||
skip_modules_classes,
|
||||
non_blocking,
|
||||
_prefix=layer_name,
|
||||
)
|
||||
|
||||
|
||||
def apply_layerwise_casting_hook(
|
||||
module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
|
||||
) -> None:
|
||||
r"""
|
||||
Applies a `LayerwiseCastingHook` to a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach the hook to.
|
||||
storage_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to before the forward pass.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to during the forward pass.
|
||||
non_blocking (`bool`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
|
||||
registry.register_hook(hook, "layerwise_casting")
|
||||
@@ -1,314 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..utils import logging
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyramidAttentionBroadcastConfig:
|
||||
r"""
|
||||
Configuration for Pyramid Attention Broadcast.
|
||||
|
||||
Args:
|
||||
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific spatial attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be re-used) before computing the new attention states again.
|
||||
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific temporal attention broadcast is skipped before computing the attention
|
||||
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
|
||||
(i.e., old attention states will be re-used) before computing the new attention states again.
|
||||
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
||||
The number of times a specific cross-attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be re-used) before computing the new attention states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the spatial attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the temporal attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the cross-attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
|
||||
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
|
||||
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
|
||||
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
|
||||
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
|
||||
"""
|
||||
|
||||
spatial_attention_block_skip_range: Optional[int] = None
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
cross_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
|
||||
# so not added for now)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PyramidAttentionBroadcastConfig("
|
||||
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
||||
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
||||
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
|
||||
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
||||
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
||||
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
|
||||
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
||||
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
||||
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
|
||||
f" current_timestep_callback={self.current_timestep_callback}\n"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastState:
|
||||
r"""
|
||||
State for Pyramid Attention Broadcast.
|
||||
|
||||
Attributes:
|
||||
iteration (`int`):
|
||||
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
|
||||
called before starting a new inference forward pass for PAB to work correctly.
|
||||
cache (`Any`):
|
||||
The cached output from the previous forward pass. This is used to re-use the attention states when the
|
||||
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.iteration = 0
|
||||
self.cache = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.cache = None
|
||||
|
||||
def __repr__(self):
|
||||
cache_repr = ""
|
||||
if self.cache is None:
|
||||
cache_repr = "None"
|
||||
else:
|
||||
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
|
||||
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastHook(ModelHook):
|
||||
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.timestep_skip_range = timestep_skip_range
|
||||
self.block_skip_range = block_skip_range
|
||||
self.current_timestep_callback = current_timestep_callback
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self.state = PyramidAttentionBroadcastState()
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
is_within_timestep_range = (
|
||||
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
||||
)
|
||||
should_compute_attention = (
|
||||
self.state.cache is None
|
||||
or self.state.iteration == 0
|
||||
or not is_within_timestep_range
|
||||
or self.state.iteration % self.block_skip_range == 0
|
||||
)
|
||||
|
||||
if should_compute_attention:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
output = self.state.cache
|
||||
|
||||
self.state.cache = output
|
||||
self.state.iteration += 1
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> None:
|
||||
self.state.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
module: torch.nn.Module,
|
||||
config: PyramidAttentionBroadcastConfig,
|
||||
):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
||||
|
||||
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
|
||||
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
|
||||
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
|
||||
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
|
||||
than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`):
|
||||
The configuration to use for Pyramid Attention Broadcast.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = PyramidAttentionBroadcastConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(100, 800),
|
||||
... current_timestep_callback=lambda: pipe.current_timestep,
|
||||
... )
|
||||
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
if config.current_timestep_callback is None:
|
||||
raise ValueError(
|
||||
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
|
||||
)
|
||||
|
||||
if (
|
||||
config.spatial_attention_block_skip_range is None
|
||||
and config.temporal_attention_block_skip_range is None
|
||||
and config.cross_attention_block_skip_range is None
|
||||
):
|
||||
logger.warning(
|
||||
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
|
||||
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
|
||||
"To avoid this warning, please set one of the above parameters."
|
||||
)
|
||||
config.spatial_attention_block_skip_range = 2
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES):
|
||||
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
|
||||
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
|
||||
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
|
||||
continue
|
||||
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_on_attention_class(
|
||||
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
|
||||
) -> bool:
|
||||
is_spatial_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
is_cross_attention = (
|
||||
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
|
||||
and config.cross_attention_block_skip_range is not None
|
||||
and getattr(module, "is_cross_attention", False)
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range, block_type = None, None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
block_type = "spatial"
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
block_type = "temporal"
|
||||
elif is_cross_attention:
|
||||
block_skip_range = config.cross_attention_block_skip_range
|
||||
timestep_skip_range = config.cross_attention_timestep_skip_range
|
||||
block_type = "cross"
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.info(
|
||||
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
|
||||
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
|
||||
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
||||
f"block identifiers in the configuration."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
|
||||
_apply_pyramid_attention_broadcast_hook(
|
||||
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_hook(
|
||||
module: Union[Attention, MochiAttention],
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
block_skip_range: int,
|
||||
current_timestep_callback: Callable[[], int],
|
||||
):
|
||||
r"""
|
||||
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
timestep_skip_range (`Tuple[int, int]`):
|
||||
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
|
||||
skipped if the current timestep is within the specified range.
|
||||
block_skip_range (`int`):
|
||||
The number of times a specific attention broadcast is skipped before computing the attention states to
|
||||
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
|
||||
attention states will be re-used) before computing the new attention states again.
|
||||
current_timestep_callback (`Callable[[], int]`):
|
||||
A callback function that returns the current inference timestep.
|
||||
"""
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
||||
registry.register_hook(hook, "pyramid_attention_broadcast")
|
||||
@@ -39,7 +39,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["cache_utils"] = ["CacheMixin"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_hunyuan"] = [
|
||||
@@ -110,7 +109,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
)
|
||||
from .cache_utils import CacheMixin
|
||||
from .controlnets import (
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
|
||||
@@ -1215,10 +1215,20 @@ class FeedForward(nn.Module):
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
self._dim = dim
|
||||
self._dim_out = dim_out
|
||||
self._mult = mult
|
||||
self._dropout = dropout
|
||||
self._activation_fn = activation_fn
|
||||
self._final_dropout = final_dropout
|
||||
self._inner_dim = inner_dim
|
||||
self._bias = bias
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
if activation_fn == "gelu-approximate":
|
||||
|
||||
@@ -3154,11 +3154,6 @@ class AttnProcessorNPU:
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
|
||||
if attention_mask.dtype == torch.bool:
|
||||
attention_mask = torch.logical_not(attention_mask.bool())
|
||||
else:
|
||||
attention_mask = attention_mask.bool()
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
@@ -60,8 +60,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["decoder"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -138,6 +138,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, Decoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
|
||||
@@ -507,12 +507,19 @@ class AllegroEncoder3D(nn.Module):
|
||||
sample = sample + residual
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# Down blocks
|
||||
for down_block in self.down_blocks:
|
||||
sample = self._gradient_checkpointing_func(down_block, sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
||||
|
||||
# Mid block
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
else:
|
||||
# Down blocks
|
||||
for down_block in self.down_blocks:
|
||||
@@ -640,12 +647,19 @@ class AllegroDecoder3D(nn.Module):
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# Mid block
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
|
||||
# Up blocks
|
||||
for up_block in self.up_blocks:
|
||||
sample = self._gradient_checkpointing_func(up_block, sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
|
||||
|
||||
else:
|
||||
# Mid block
|
||||
@@ -795,6 +809,10 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
||||
sample_size - self.tile_overlap_w,
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(self) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
|
||||
@@ -421,8 +421,15 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
@@ -516,8 +523,15 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -623,8 +637,15 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
@@ -753,11 +774,18 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# 1. Down
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
down_block,
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
@@ -765,8 +793,8 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
|
||||
self.mid_block,
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
@@ -912,9 +940,16 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# 1. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
|
||||
self.mid_block,
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
@@ -924,8 +959,8 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
# 2. Up
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
up_block,
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
@@ -1087,6 +1122,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_overlap_factor_height = 1 / 6
|
||||
self.tile_overlap_factor_width = 1 / 5
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention
|
||||
@@ -252,7 +252,21 @@ class HunyuanVideoMidBlock3D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
@@ -264,7 +278,9 @@ class HunyuanVideoMidBlock3D(nn.Module):
|
||||
hidden_states = attn(hidden_states, attention_mask=attention_mask)
|
||||
hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
@@ -334,8 +350,22 @@ class HunyuanVideoDownBlock3D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
@@ -396,8 +426,22 @@ class HunyuanVideoUpBlock3D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
@@ -501,10 +545,26 @@ class HunyuanVideoEncoder3D(nn.Module):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
@@ -607,10 +667,26 @@ class HunyuanVideoDecoder3D(nn.Module):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, **ckpt_kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
@@ -710,7 +786,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
self.use_tiling = False
|
||||
|
||||
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
|
||||
# at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
|
||||
# at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
|
||||
self.use_framewise_encoding = True
|
||||
self.use_framewise_decoding = True
|
||||
|
||||
@@ -724,6 +800,10 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
self.tile_sample_stride_width = 192
|
||||
self.tile_sample_stride_num_frames = 12
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
@@ -788,7 +868,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = x.shape
|
||||
|
||||
if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
|
||||
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
|
||||
return self._temporal_tiled_encode(x)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
|
||||
@@ -338,7 +338,16 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
@@ -429,7 +438,16 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
@@ -555,7 +573,16 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
@@ -670,10 +697,17 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
@@ -804,10 +838,19 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb
|
||||
)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states, temb)
|
||||
|
||||
@@ -974,6 +1017,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_width = 448
|
||||
self.tile_sample_stride_num_frames = 8
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
|
||||
@@ -207,8 +207,15 @@ class MochiDownBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
@@ -305,8 +312,15 @@ class MochiMidBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -379,8 +393,15 @@ class MochiUpBlock3D(nn.Module):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
)
|
||||
@@ -510,14 +531,21 @@ class MochiEncoder3D(nn.Module):
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
|
||||
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
)
|
||||
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache["block_in"] = self.block_in(
|
||||
@@ -620,14 +648,21 @@ class MochiDecoder3D(nn.Module):
|
||||
|
||||
# 1. Mid
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
|
||||
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
)
|
||||
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache["block_in"] = self.block_in(
|
||||
@@ -784,6 +819,10 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (MochiEncoder3D, MochiDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
@@ -96,21 +97,47 @@ class TemporalDecoder(nn.Module):
|
||||
|
||||
upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(
|
||||
self.mid_block,
|
||||
sample,
|
||||
image_only_indicator,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = self._gradient_checkpointing_func(
|
||||
up_block,
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
image_only_indicator,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
image_only_indicator,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
image_only_indicator,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
image_only_indicator,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
|
||||
@@ -202,6 +229,10 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, TemporalDecoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
|
||||
@@ -154,6 +154,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self.register_to_config(block_out_channels=decoder_block_out_channels)
|
||||
self.register_to_config(force_upcast=False)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (EncoderTiny, DecoderTiny)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""raw latents -> [0, 1]"""
|
||||
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
|
||||
|
||||
@@ -18,7 +18,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import BaseOutput
|
||||
from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
@@ -156,11 +156,28 @@ class Encoder(nn.Module):
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = self._gradient_checkpointing_func(down_block, sample)
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
for down_block in self.down_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), sample, use_reentrant=False
|
||||
)
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
|
||||
else:
|
||||
# down
|
||||
@@ -288,13 +305,41 @@ class Decoder(nn.Module):
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
@@ -513,28 +558,72 @@ class MaskConditionDecoder(nn.Module):
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = self._gradient_checkpointing_func(
|
||||
self.condition_encoder,
|
||||
masked_image,
|
||||
mask,
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder),
|
||||
masked_image,
|
||||
mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
else:
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder),
|
||||
masked_image,
|
||||
mask,
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
@@ -801,7 +890,17 @@ class EncoderTiny(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `EncoderTiny` class."""
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(self.layers, x)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
|
||||
else:
|
||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||
@@ -877,7 +976,18 @@ class DecoderTiny(nn.Module):
|
||||
x = torch.tanh(x / 3) * 3
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(self.layers, x)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
|
||||
else:
|
||||
x = self.layers(x)
|
||||
|
||||
|
||||
@@ -71,8 +71,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["quantize"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CacheMixin:
|
||||
r"""
|
||||
A class for enable/disabling caching techniques on diffusion models.
|
||||
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
|
||||
@property
|
||||
def is_cache_enabled(self) -> bool:
|
||||
return self._cache_config is not None
|
||||
|
||||
def enable_cache(self, config) -> None:
|
||||
r"""
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`Union[PyramidAttentionBroadcastConfig]`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> config = PyramidAttentionBroadcastConfig(
|
||||
... spatial_attention_block_skip_range=2,
|
||||
... spatial_attention_timestep_skip_range=(100, 800),
|
||||
... current_timestep_callback=lambda: pipe.current_timestep,
|
||||
... )
|
||||
>>> pipe.transformer.enable_cache(config)
|
||||
```
|
||||
"""
|
||||
|
||||
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
return
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
self._cache_config = None
|
||||
|
||||
def _reset_stateful_cache(self, recurse: bool = True) -> None:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
||||
@@ -31,6 +31,8 @@ from ..attention_processor import (
|
||||
from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
@@ -657,6 +659,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.attention_processor import AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -178,6 +178,10 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
@@ -326,12 +330,24 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -348,11 +364,23 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import JointTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
@@ -262,6 +262,10 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
|
||||
# we should have handled this in conversion script
|
||||
def _get_pos_embed_from_transformer(self, transformer):
|
||||
@@ -378,16 +382,30 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
if self.context_embedder is not None:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if self.context_embedder is not None:
|
||||
|
||||
@@ -590,6 +590,10 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -29,6 +29,8 @@ from ..attention_processor import (
|
||||
from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
@@ -597,6 +599,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -864,6 +864,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
for u in self.up_blocks:
|
||||
u.freeze_base_params()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
@@ -1446,6 +1450,15 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
base_blocks = list(zip(self.base_resnets, self.base_attentions))
|
||||
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
|
||||
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
|
||||
):
|
||||
@@ -1455,7 +1468,13 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
# apply base subblock
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
h_base = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(b_res),
|
||||
h_base,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
h_base = b_res(h_base, temb)
|
||||
|
||||
@@ -1472,7 +1491,13 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
# apply ctrl subblock
|
||||
if apply_control:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
h_ctrl = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(c_res),
|
||||
h_ctrl,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
h_ctrl = c_res(h_ctrl, temb)
|
||||
if c_attn is not None:
|
||||
@@ -1837,6 +1862,15 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
@@ -1866,7 +1900,13 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
|
||||
@@ -1787,7 +1787,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def forward(self, timestep, caption_feat, caption_mask):
|
||||
# timestep embedding:
|
||||
time_freq = self.time_proj(timestep)
|
||||
time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
|
||||
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
|
||||
|
||||
# caption condition embedding:
|
||||
caption_mask_float = caption_mask.float().unsqueeze(-1)
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import logging
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU
|
||||
from .attention import FeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class _MemoryOptimizedFeedForward(torch.nn.Module):
|
||||
r"""
|
||||
See [`~models.attention.FeedForward`] parameter documentation. This class is a copy of the FeedForward class. The
|
||||
only difference is that this module is optimized for memory.
|
||||
|
||||
This method achieves memory savings by applying the ideas of tensor-parallelism sequentially. Input projection
|
||||
layers are split column-wise and output projection layers are split row-wise. This allows for the computation of
|
||||
the feedforward pass to occur without ever materializing the full intermediate tensor. Typically, the intermediate
|
||||
tensor takes 4x-8x more memory than the input tensor. This method reduces that with a small performance tradeoff.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
num_splits: int = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
dim_split = inner_dim // num_splits
|
||||
if inner_dim % dim_split != 0:
|
||||
raise ValueError(f"inner_dim must be divisible by {mult=}, or {num_splits=} if provided.")
|
||||
|
||||
self._dim = dim
|
||||
self._dim_out = dim_out
|
||||
self._mult = mult
|
||||
self._dropout = dropout
|
||||
self._activation_fn = activation_fn
|
||||
self._final_dropout = final_dropout
|
||||
self._inner_dim = inner_dim
|
||||
self._bias = bias
|
||||
self._num_splits = num_splits
|
||||
|
||||
def get_activation_fn(dim_: int, inner_dim_: int):
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim_, inner_dim_, bias=bias)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim_, inner_dim_, approximate="tanh", bias=bias)
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim_, inner_dim_, bias=bias)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim_, inner_dim_, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim_, inner_dim_, bias=bias)
|
||||
elif activation_fn == "linear-silu":
|
||||
act_fn = LinearActivation(dim_, inner_dim_, bias=bias, activation="silu")
|
||||
return act_fn
|
||||
|
||||
# Split column-wise
|
||||
self.proj_in = torch.nn.ModuleList([get_activation_fn(dim, dim_split) for _ in range(inner_dim // dim_split)])
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
# Split row-wise
|
||||
self.proj_out = torch.nn.ModuleList(
|
||||
[torch.nn.Linear(dim_split, dim_out, bias=False) for _ in range(inner_dim // dim_split)]
|
||||
)
|
||||
|
||||
self.bias = None
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.zeros(dim_out))
|
||||
|
||||
self.final_dropout = None
|
||||
if final_dropout:
|
||||
self.final_dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# Output tensor for "all_reduce" operation
|
||||
output = hidden_states.new_zeros(hidden_states.shape)
|
||||
|
||||
# Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
|
||||
for proj_in, proj_out in zip(self.proj_in, self.proj_out):
|
||||
out = proj_in(hidden_states)
|
||||
out = self.dropout(out)
|
||||
out = proj_out(out)
|
||||
# Perform "all_reduce"
|
||||
output += out
|
||||
|
||||
if self.bias is not None:
|
||||
output += self.bias
|
||||
if self.final_dropout is not None:
|
||||
output = self.final_dropout(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Optional[int] = None) -> torch.nn.Module:
|
||||
module_dict = dict(module.named_modules())
|
||||
|
||||
for name, submodule in module_dict.items():
|
||||
if not isinstance(submodule, FeedForward):
|
||||
continue
|
||||
|
||||
logger.debug(f"Applying memory optimized feedforward to layer '{name}'")
|
||||
state_dict = submodule.state_dict()
|
||||
num_splits = submodule._mult if num_splits is None else num_splits
|
||||
|
||||
# remap net.0.proj.weight
|
||||
if isinstance(submodule.net[0], (GEGLU, SwiGLU)):
|
||||
net_0_proj = state_dict.pop("net.0.proj.weight")
|
||||
proj, gate = net_0_proj.chunk(2, dim=0)
|
||||
proj = proj.chunk(num_splits, dim=0)
|
||||
gate = gate.chunk(num_splits, dim=0)
|
||||
for i in range(num_splits):
|
||||
state_dict[f"proj_in.{i}.proj.weight"] = torch.cat([proj[i], gate[i]], dim=0)
|
||||
else:
|
||||
net_0_proj = state_dict.pop("net.0.proj.weight")
|
||||
net_0_proj = net_0_proj.chunk(num_splits, dim=0)
|
||||
for i in range(num_splits):
|
||||
state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i]
|
||||
|
||||
# remap net.0.proj.bias
|
||||
if "net.0.proj.bias" in state_dict:
|
||||
net_0_proj_bias = state_dict.pop("net.0.proj.bias")
|
||||
net_0_proj_bias = net_0_proj_bias.chunk(num_splits, dim=0)
|
||||
for i in range(num_splits):
|
||||
state_dict[f"proj_in.{i}.proj.bias"] = net_0_proj_bias[i]
|
||||
|
||||
# remap net.2.weight
|
||||
net_2_weight = state_dict.pop("net.2.weight")
|
||||
net_2_weight = net_2_weight.chunk(num_splits, dim=1)
|
||||
for i in range(num_splits):
|
||||
state_dict[f"proj_out.{i}.weight"] = net_2_weight[i]
|
||||
|
||||
# remap net.2.bias
|
||||
if "net.2.bias" in state_dict:
|
||||
net_2_bias = state_dict.pop("net.2.bias")
|
||||
state_dict["bias"] = net_2_bias
|
||||
|
||||
with torch.device("meta"):
|
||||
new_ff = _MemoryOptimizedFeedForward(
|
||||
dim=submodule._dim,
|
||||
dim_out=submodule._dim_out,
|
||||
mult=submodule._mult,
|
||||
dropout=submodule._dropout,
|
||||
activation_fn=submodule._activation_fn,
|
||||
final_dropout=submodule._final_dropout,
|
||||
inner_dim=submodule._inner_dim,
|
||||
bias=submodule._bias,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
new_ff.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
parent_module_name, _, submodule_name = name.rpartition(".")
|
||||
parent_module = module_dict[parent_module_name]
|
||||
setattr(parent_module, submodule_name, new_ff)
|
||||
|
||||
return module
|
||||
@@ -21,19 +21,17 @@ import json
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .. import __version__
|
||||
from ..hooks import apply_layerwise_casting
|
||||
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
@@ -50,7 +48,6 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -105,17 +102,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
"""
|
||||
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||
"""
|
||||
# 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
|
||||
if isinstance(parameter, nn.Module):
|
||||
for name, submodule in parameter.named_modules():
|
||||
if not hasattr(submodule, "_diffusers_hook"):
|
||||
continue
|
||||
registry = submodule._diffusers_hook
|
||||
hook = registry.get_hook("layerwise_casting")
|
||||
if hook is not None:
|
||||
return hook.compute_dtype
|
||||
|
||||
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
|
||||
last_dtype = None
|
||||
for param in parameter.parameters():
|
||||
last_dtype = param.dtype
|
||||
@@ -164,13 +150,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_keys_to_ignore_on_load_unexpected = None
|
||||
_no_split_modules = None
|
||||
_keep_in_fp32_modules = None
|
||||
_skip_layerwise_casting_patterns = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self._gradient_checkpointing_func = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
||||
@@ -196,35 +179,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"""
|
||||
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
||||
|
||||
def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
|
||||
def enable_gradient_checkpointing(self) -> None:
|
||||
"""
|
||||
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
||||
*checkpoint activations* in other frameworks).
|
||||
|
||||
Args:
|
||||
gradient_checkpointing_func (`Callable`, *optional*):
|
||||
The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
|
||||
is used (`torch.utils.checkpoint.checkpoint`).
|
||||
"""
|
||||
if not self._supports_gradient_checkpointing:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
|
||||
f"`_supports_gradient_checkpointing` to `True` in the class definition."
|
||||
)
|
||||
|
||||
if gradient_checkpointing_func is None:
|
||||
|
||||
def _gradient_checkpointing_func(module, *args):
|
||||
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
module.__call__,
|
||||
*args,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
gradient_checkpointing_func = _gradient_checkpointing_func
|
||||
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
|
||||
def disable_gradient_checkpointing(self) -> None:
|
||||
"""
|
||||
@@ -232,7 +194,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
*checkpoint activations* in other frameworks).
|
||||
"""
|
||||
if self._supports_gradient_checkpointing:
|
||||
self._set_gradient_checkpointing(enable=False)
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||
|
||||
def set_use_npu_flash_attention(self, valid: bool) -> None:
|
||||
r"""
|
||||
@@ -352,90 +314,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_layerwise_casting(
|
||||
self,
|
||||
storage_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
compute_dtype: Optional[torch.dtype] = None,
|
||||
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates layerwise casting for the current model.
|
||||
|
||||
Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
|
||||
upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
|
||||
memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
|
||||
are negligible, mostly stemming from weight casting in normalization and modulation layers.
|
||||
|
||||
By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
|
||||
embedding, positional embedding and normalization layers. This is because these layers are most likely
|
||||
precision-critical for quality. If you wish to change this behavior, you can set the
|
||||
`_skip_layerwise_casting_patterns` attribute to `None`, or call
|
||||
[`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
|
||||
|
||||
Example:
|
||||
Using [`~models.ModelMixin.enable_layerwise_casting`]:
|
||||
|
||||
```python
|
||||
>>> from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
>>> # Enable layerwise casting via the model, which ignores certain modules by default
|
||||
>>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
Args:
|
||||
storage_dtype (`torch.dtype`):
|
||||
The dtype to which the model should be cast for storage.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to which the model weights should be cast during the forward pass.
|
||||
skip_modules_pattern (`Tuple[str, ...]`, *optional*):
|
||||
A list of patterns to match the names of the modules to skip during the layerwise casting process. If
|
||||
set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
|
||||
layers.
|
||||
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
|
||||
A list of module classes to skip during the layerwise casting process.
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
"""
|
||||
|
||||
user_provided_patterns = True
|
||||
if skip_modules_pattern is None:
|
||||
from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
|
||||
|
||||
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
|
||||
user_provided_patterns = False
|
||||
if self._keep_in_fp32_modules is not None:
|
||||
skip_modules_pattern += tuple(self._keep_in_fp32_modules)
|
||||
if self._skip_layerwise_casting_patterns is not None:
|
||||
skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
|
||||
skip_modules_pattern = tuple(set(skip_modules_pattern))
|
||||
|
||||
if is_peft_available() and not user_provided_patterns:
|
||||
# By default, we want to skip all peft layers because they have a very low memory footprint.
|
||||
# If users want to apply layerwise casting on peft layers as well, they can utilize the
|
||||
# `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
|
||||
# them with more flexibility and control.
|
||||
|
||||
from peft.tuners.loha.layer import LoHaLayer
|
||||
from peft.tuners.lokr.layer import LoKrLayer
|
||||
from peft.tuners.lora.layer import LoraLayer
|
||||
|
||||
for layer in (LoHaLayer, LoKrLayer, LoraLayer):
|
||||
skip_modules_pattern += tuple(layer.adapter_layer_names)
|
||||
|
||||
if compute_dtype is None:
|
||||
logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
|
||||
compute_dtype = self.dtype
|
||||
|
||||
apply_layerwise_casting(
|
||||
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@@ -1476,24 +1354,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mem = mem + mem_bufs
|
||||
return mem
|
||||
|
||||
def _set_gradient_checkpointing(
|
||||
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
|
||||
) -> None:
|
||||
is_gradient_checkpointing_set = False
|
||||
|
||||
for name, module in self.named_modules():
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
|
||||
module._gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = enable
|
||||
is_gradient_checkpointing_set = True
|
||||
|
||||
if not is_gradient_checkpointing_set:
|
||||
raise ValueError(
|
||||
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
|
||||
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
|
||||
)
|
||||
|
||||
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
||||
deprecated_attention_block_paths = []
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
@@ -276,7 +276,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
"""
|
||||
|
||||
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
@@ -444,6 +443,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -465,11 +468,23 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
# MMDiT blocks.
|
||||
for index_block, block in enumerate(self.joint_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -484,10 +499,22 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
combined_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
combined_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
combined_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -20,11 +20,10 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -157,7 +156,7 @@ class CogVideoXBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
@@ -213,7 +212,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
|
||||
|
||||
@@ -331,6 +329,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
@@ -486,13 +487,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
# 3. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
attention_kwargs,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
|
||||
@@ -595,6 +595,9 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def _init_face_inputs(self):
|
||||
self.local_facial_extractor = LocalFacialExtractor(
|
||||
id_dim=self.LFE_id_dim,
|
||||
@@ -742,13 +745,22 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# 3. Transformer blocks
|
||||
ca_idx = 0
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -64,7 +64,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
A small constant added to the denominator in normalization layers to prevent division by zero.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
@@ -144,6 +143,10 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -182,8 +185,19 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
None,
|
||||
@@ -191,6 +205,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -244,8 +244,6 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -20,14 +19,13 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
@@ -67,8 +65,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
The number of frames in the video-like data.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -166,6 +162,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -240,7 +239,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
zip(self.transformer_blocks, self.temporal_transformer_blocks)
|
||||
):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
@@ -249,6 +248,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
timestep_spatial,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
@@ -272,7 +272,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
hidden_states = hidden_states + self.temp_pos_embed
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
@@ -281,6 +281,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
timestep_temp,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = temp_block(
|
||||
|
||||
@@ -221,8 +221,6 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
overall scale of the model's operations.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -17,7 +17,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
@@ -79,7 +79,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -184,6 +183,10 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
in_features=self.config.caption_channels, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
@@ -384,8 +387,19 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -393,6 +407,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
None,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
@@ -236,7 +236,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -308,6 +307,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
@@ -434,9 +437,21 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
# 2. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -444,6 +459,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
timestep,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -29,7 +29,7 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_2d import Transformer2DModelOutput
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
@@ -211,7 +211,6 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -346,6 +345,10 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.set_attn_processor(StableAudioAttnProcessor2_0())
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -412,13 +415,25 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cross_attention_hidden_states,
|
||||
encoder_attention_mask,
|
||||
rotary_embedding,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import LegacyConfigMixin, register_to_config
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -66,7 +66,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -321,6 +320,10 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
in_features=self.caption_channels, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -413,8 +416,19 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -422,6 +436,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -13,18 +13,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import AllegroAttnProcessor2_0, Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -173,7 +172,7 @@ class AllegroTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
@@ -223,7 +222,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -304,6 +302,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -373,14 +374,23 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
# TODO(aryan): Implement gradient checkpointing
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -27,7 +27,7 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
|
||||
@@ -166,7 +166,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
||||
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
|
||||
|
||||
@register_to_config
|
||||
@@ -289,6 +288,10 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -340,11 +343,20 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
|
||||
@@ -32,10 +32,9 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
@@ -228,7 +227,7 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
@@ -263,7 +262,6 @@ class FluxTransformer2DModel(
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -423,6 +421,10 @@ class FluxTransformer2DModel(
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -517,12 +519,24 @@ class FluxTransformer2DModel(
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -549,11 +563,23 @@ class FluxTransformer2DModel(
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -22,10 +22,9 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
@@ -503,7 +502,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
@@ -543,7 +542,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
||||
_no_split_modules = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoSingleTransformerBlock",
|
||||
@@ -672,6 +670,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -730,24 +732,38 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
@@ -295,7 +295,6 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -361,6 +360,10 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -413,13 +416,25 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -21,11 +21,10 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -306,7 +305,7 @@ class MochiRoPE(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
@@ -337,7 +336,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MochiTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -404,6 +402,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -456,13 +458,22 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
|
||||
@@ -28,7 +28,7 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -127,7 +127,6 @@ class SD3Transformer2DModel(
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -329,6 +328,10 @@ class SD3Transformer2DModel(
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -400,12 +403,24 @@ class SD3Transformer2DModel(
|
||||
is_skip = True if skip_layers is not None and index_block in skip_layers else False
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
joint_attention_kwargs,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
elif not is_skip:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
|
||||
@@ -67,8 +67,6 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -343,11 +341,19 @@ class TransformerSpatioTemporalModel(nn.Module):
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, None, encoder_hidden_states, None
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
@@ -71,8 +71,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +223,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
|
||||
@@ -90,7 +90,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -248,6 +247,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
@@ -737,9 +737,25 @@ class UNetMidBlock2D(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
@@ -867,6 +883,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -875,7 +902,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -1124,7 +1156,23 @@ class AttnDownBlock2D(nn.Module):
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
else:
|
||||
@@ -1256,7 +1304,23 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1354,7 +1418,21 @@ class DownBlock2D(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1828,7 +1906,21 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1966,7 +2058,17 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2051,7 +2153,21 @@ class KDownBlock2D(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -2146,10 +2262,22 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -2295,7 +2423,23 @@ class AttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -2444,7 +2588,23 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2561,7 +2721,21 @@ class UpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -3077,7 +3251,21 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -3221,7 +3409,17 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -3314,7 +3512,21 @@ class KUpBlock2D(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -3428,10 +3640,22 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
|
||||
@@ -166,7 +166,6 @@ class UNet2DConditionModel(
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -834,6 +833,10 @@ class UNet2DConditionModel(
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..resnet import (
|
||||
@@ -1078,14 +1078,31 @@ class UNetMidBlockSpatioTemporal(nn.Module):
|
||||
)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_only_indicator=image_only_indicator,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -1093,7 +1110,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
|
||||
image_only_indicator=image_only_indicator,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
||||
hidden_states = resnet(
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1148,9 +1169,34 @@ class DownBlockSpatioTemporal(nn.Module):
|
||||
output_states = ()
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
||||
hidden_states = resnet(
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1235,8 +1281,25 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
for resnet, attn in blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -1245,7 +1308,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
||||
hidden_states = resnet(
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1318,9 +1385,34 @@ class UpBlockSpatioTemporal(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
||||
hidden_states = resnet(
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1403,8 +1495,25 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1412,7 +1521,11 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
||||
hidden_states = resnet(
|
||||
hidden_states,
|
||||
temb,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
|
||||
@@ -37,7 +37,11 @@ from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
DownBlock3D,
|
||||
UNetMidBlock3DCrossAttn,
|
||||
UpBlock3D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
@@ -93,7 +97,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
_skip_layerwise_casting_patterns = ["norm", "time_embedding"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -468,6 +471,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -35,7 +35,11 @@ from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
DownBlock3D,
|
||||
UNetMidBlock3DCrossAttn,
|
||||
UpBlock3D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
@@ -432,6 +436,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -205,6 +205,10 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ...utils import BaseOutput, deprecate, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
@@ -324,7 +324,25 @@ class DownBlockMotion(nn.Module):
|
||||
blocks = zip(self.resnets, self.motion_modules)
|
||||
for resnet, motion_module in blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
@@ -496,7 +514,23 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
|
||||
for i, (resnet, attn, motion_module) in enumerate(blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
@@ -509,7 +543,10 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
@@ -696,7 +733,23 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
@@ -709,7 +762,10 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -840,7 +896,24 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
@@ -1007,12 +1080,34 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
)[0]
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
motion_module, hidden_states, None, None, None, num_frames, None
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
else:
|
||||
hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -1206,7 +1301,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -1871,6 +1965,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -320,6 +320,10 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
|
||||
@@ -387,6 +387,9 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
@@ -453,18 +456,29 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = self._gradient_checkpointing_func(block, x)
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = self._gradient_checkpointing_func(block, x, clip)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = self._gradient_checkpointing_func(block, x, r_embed)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = self._gradient_checkpointing_func(block)
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
@@ -491,6 +505,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
@@ -502,13 +523,19 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = x.to(orig_type)
|
||||
x = self._gradient_checkpointing_func(block, x, skip)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, skip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = self._gradient_checkpointing_func(block, x, clip)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = self._gradient_checkpointing_func(block, x, r_embed)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = self._gradient_checkpointing_func(block, x)
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
|
||||
@@ -148,6 +148,9 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
pass
|
||||
|
||||
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
|
||||
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
|
||||
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
|
||||
|
||||
@@ -683,10 +683,6 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -819,7 +815,6 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
@@ -897,7 +892,6 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -939,8 +933,6 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.decode_latents(latents)
|
||||
|
||||
@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ...models.transformers.transformer_2d import Transformer2DModel
|
||||
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -673,6 +673,11 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
@@ -1109,7 +1114,23 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for i in range(num_layers):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.resnets[i]),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
||||
if cross_attention_dim is not None and idx <= 1:
|
||||
forward_encoder_hidden_states = encoder_hidden_states
|
||||
@@ -1120,8 +1141,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
else:
|
||||
forward_encoder_hidden_states = None
|
||||
forward_encoder_attention_mask = None
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
self.attentions[i * num_attention_per_layer + idx],
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
|
||||
hidden_states,
|
||||
forward_encoder_hidden_states,
|
||||
None, # timestep
|
||||
@@ -1129,6 +1150,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs,
|
||||
attention_mask,
|
||||
forward_encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = self.resnets[i](hidden_states, temb)
|
||||
@@ -1270,6 +1292,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
for i in range(len(self.resnets[1:])):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
||||
if cross_attention_dim is not None and idx <= 1:
|
||||
forward_encoder_hidden_states = encoder_hidden_states
|
||||
@@ -1280,8 +1313,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
else:
|
||||
forward_encoder_hidden_states = None
|
||||
forward_encoder_attention_mask = None
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
self.attentions[i * num_attention_per_layer + idx],
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
|
||||
hidden_states,
|
||||
forward_encoder_hidden_states,
|
||||
None, # timestep
|
||||
@@ -1289,8 +1322,14 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
cross_attention_kwargs,
|
||||
attention_mask,
|
||||
forward_encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)[0]
|
||||
hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.resnets[i + 1]),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
||||
if cross_attention_dim is not None and idx <= 1:
|
||||
@@ -1427,7 +1466,23 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.resnets[i]),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
||||
if cross_attention_dim is not None and idx <= 1:
|
||||
forward_encoder_hidden_states = encoder_hidden_states
|
||||
@@ -1438,8 +1493,8 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
forward_encoder_hidden_states = None
|
||||
forward_encoder_attention_mask = None
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
self.attentions[i * num_attention_per_layer + idx],
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
|
||||
hidden_states,
|
||||
forward_encoder_hidden_states,
|
||||
None, # timestep
|
||||
@@ -1447,6 +1502,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
cross_attention_kwargs,
|
||||
attention_mask,
|
||||
forward_encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = self.resnets[i](hidden_states, temb)
|
||||
|
||||
@@ -174,16 +174,19 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module,
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, past_key_value, output_attentions, query_length)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
|
||||
@@ -494,10 +494,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -631,7 +627,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -710,7 +705,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -769,8 +763,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -540,10 +540,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -684,7 +680,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -771,7 +766,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -824,8 +818,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -591,10 +591,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -732,7 +728,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
@@ -820,7 +815,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -883,8 +877,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -564,10 +564,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -704,7 +700,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -791,7 +786,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -850,8 +844,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -766,6 +766,26 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
|
||||
@@ -1350,8 +1370,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
control_image = control_image.copy()
|
||||
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
@@ -741,6 +741,26 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
|
||||
@@ -1140,8 +1160,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
control_image = control_image.copy()
|
||||
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
@@ -746,6 +746,26 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
|
||||
@@ -1286,8 +1306,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
control_image = control_image.copy()
|
||||
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
@@ -34,7 +34,7 @@ from ....models.resnet import ResnetBlockCondNorm2D
|
||||
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ....models.transformers.transformer_2d import Transformer2DModel
|
||||
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ....utils.torch_utils import apply_freeu
|
||||
|
||||
|
||||
@@ -963,6 +963,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -1593,7 +1597,21 @@ class DownBlockFlat(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1716,7 +1734,23 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1842,7 +1876,21 @@ class UpBlockFlat(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1987,7 +2035,23 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2166,9 +2230,25 @@ class UNetMidBlockFlat(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
@@ -2297,6 +2377,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2305,7 +2396,12 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
|
||||
@@ -28,7 +28,8 @@ from transformers import (
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, FluxTransformer2DModel
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -619,10 +620,6 @@ class FluxPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -778,7 +775,6 @@ class FluxPipeline(
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
@@ -903,7 +899,6 @@ class FluxPipeline(
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@@ -962,10 +957,9 @@ class FluxPipeline(
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
@@ -930,8 +930,8 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
if isinstance(self.controlnet, FluxControlNetModel):
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
width=height,
|
||||
height=width,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
|
||||
@@ -456,10 +456,6 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -581,7 +577,6 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
@@ -649,7 +644,6 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
@@ -684,8 +678,6 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -605,7 +605,7 @@ class GLMTransformer(torch.nn.Module):
|
||||
|
||||
layer = self._get_layer(index)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
layer_ret = self._gradient_checkpointing_func(
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
|
||||
)
|
||||
else:
|
||||
@@ -666,6 +666,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
||||
return position_ids
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, GLMTransformer):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
def default_init(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user