Compare commits

..

34 Commits

Author SHA1 Message Date
sayakpaul 1f67e4e7f6 add a test for checking effective custom gc. 2025-01-29 11:04:12 +05:30
Dimitri Barbot 196aef5a6f Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)
Fix pipeline dtype unexpected change when using SDXL reference community pipelines
2025-01-28 10:46:41 -03:00
Sayak Paul 7b100ce589 [Tests] conditionally check fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory (#10669)
* conditionally check if compute capability is met.

* log info.

* fix condition.

* updates

* updates

* updates

* updates
2025-01-28 12:00:14 +05:30
Aryan c4d4ac21e7 Refactor gradient checkpointing (#10611)
* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
2025-01-28 06:51:46 +05:30
Hanch Han f295e2eefc [fix] refer use_framewise_encoding on AutoencoderKLHunyuanVideo._encode (#10600)
* fix: refer to use_framewise_encoding on AutoencoderKLHunyuanVideo._encode

* fix: comment about tile_sample_min_num_frames

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-28 06:51:27 +05:30
Aryan 658e24e86c [core] Pyramid Attention Broadcast (#9562)
* start pyramid attention broadcast

* add coauthor

Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------

Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-01-28 05:09:04 +05:30
Giuseppe Catalano fb42066489 Revert RePaint scheduler 'fix' (#10644)
Co-authored-by: Giuseppe Catalano <giuseppelorenzo.catalano@unito.it>
2025-01-27 11:16:45 -10:00
Teriks e89ab5bc26 SDXL ControlNet Union pipelines, make control_image argument immutible (#10663)
controlnet union XL, make control_image immutible

when this argument is passed a list, __call__
modifies its content, since it is pass by reference
the list passed by the caller gets its content
modified unexpectedly

make a copy at method intro so this does not happen

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
2025-01-27 10:53:30 -10:00
victolee0 8ceec90d76 fix check_inputs func in LuminaText2ImgPipeline (#10651) 2025-01-27 09:47:01 -10:00
hlky 158c5c4d08 Add provider_options to OnnxRuntimeModel (#10661) 2025-01-27 09:46:17 -10:00
hlky 41571773d9 [training] Convert to ImageFolder script (#10664)
* [training] Convert to ImageFolder script

* make
2025-01-27 09:43:51 -10:00
hlky 18f7d1d937 ControlNet Union controlnet_conditioning_scale for multiple control inputs (#10666) 2025-01-27 08:15:25 -10:00
Marlon May f7f36c7d3d Add community pipeline for semantic guidance for FLUX (#10610)
* add community pipeline for semantic guidance for flux

* fix imports in community pipeline for semantic guidance for flux

* Update examples/community/pipeline_flux_semantic_guidance.py

Co-authored-by: hlky <hlky@hlky.ac>

* fix community pipeline for semantic guidance for flux

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-01-27 16:19:46 +02:00
Yuqian Hong 4fa24591a3 create a script to train autoencoderkl (#10605)
* create a script to train vae

* update main.py

* update train_autoencoderkl.py

* update train_autoencoderkl.py

* add a check of --pretrained_model_name_or_path and --model_config_name_or_path

* remove the comment, remove diffusers in requiremnets.txt, add validation_image ote

* update autoencoderkl.py

* quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-27 16:41:34 +05:30
Jacob Helwig 4f3ec5364e Add sigmoid scheduler in scheduling_ddpm.py docs (#10648)
Sigmoid scheduler in scheduling_ddpm.py docs
2025-01-26 15:37:20 -08:00
Leo Jiang 07860f9916 NPU Adaption for Sanna (#10409)
* NPU Adaption for Sanna


---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-24 09:08:52 -10:00
Wenhao Sun 87252d80c3 Add pipeline_stable_diffusion_xl_attentive_eraser (#10579)
* add pipeline_stable_diffusion_xl_attentive_eraser

* add pipeline_stable_diffusion_xl_attentive_eraser_make_style

* make style and add example output

* update Docs

Co-authored-by: Other Contributor <a457435687@126.com>

* add Oral

Co-authored-by: Other Contributor <a457435687@126.com>

* update_review

Co-authored-by: Other Contributor <a457435687@126.com>

* update_review_ms

Co-authored-by: Other Contributor <a457435687@126.com>

---------

Co-authored-by: Other Contributor <a457435687@126.com>
2025-01-24 13:52:45 +00:00
Sayak Paul 5897137397 [chore] add a script to extract loras from full fine-tuned models (#10631)
* feat: add a lora extraction script.

* updates
2025-01-24 11:50:36 +05:30
Yaniv Galron a451c0ed14 removing redundant requires_grad = False (#10628)
We already set the unet to requires grad false at line 506

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-24 03:25:33 +05:30
hlky 37c9697f5b Add IP-Adapter example to Flux docs (#10633)
* Add IP-Adapter example to Flux docs

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-23 22:15:33 +05:30
Raul Ciotescu 9684c52adf width and height are mixed-up (#10629)
vars mixed-up
2025-01-23 06:40:22 -10:00
Steven Liu 5483162d12 [docs] uv installation (#10622)
* uv

* feedback
2025-01-23 08:34:51 -08:00
Sayak Paul d77c53b6d2 [docs] fix image path in para attention docs (#10632)
fix image path in para attention docs
2025-01-23 08:22:42 -08:00
Sayak Paul 78bc824729 [Tests] modify the test slices for the failing flax test (#10630)
* fixes

* fixes

* fixes

* updates
2025-01-23 12:10:24 +05:30
kahmed10 04d40920a7 add onnxruntime-migraphx as part of check for onnxruntime in import_utils.py (#10624)
add onnxruntime-migraphx to import_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-23 07:49:51 +05:30
Dhruv Nair 8d6f6d6b66 [CI] Update HF_TOKEN in all workflows (#10613)
update
2025-01-22 20:03:41 +05:30
Aryan ca60ad8e55 Improve TorchAO error message (#10627)
improve error message
2025-01-22 19:50:02 +05:30
Aryan beacaa5528 [core] Layerwise Upcasting (#10347)
* update

* update

* make style

* remove dynamo disable

* add coauthor

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-01-22 19:49:37 +05:30
Lucain a647682224 Remove cache migration script (#10619) 2025-01-21 07:22:59 -10:00
YiYi Xu a1f9a71238 fix offload gpu tests etc (#10366)
* add

* style
2025-01-21 18:52:36 +05:30
Fanli Lin ec37e20972 [tests] make tests device-agnostic (part 3) (#10437)
* initial comit

* fix empty cache

* fix one more

* fix style

* update device functions

* update

* update

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* with gc.collect

* update

* make style

* check_torch_dependencies

* add mps empty cache

* bug fix

* Apply suggestions from code review

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-01-21 12:15:45 +00:00
Muyang Li 158a5a87fb Remove the FP32 Wrapper when evaluating (#10617)
Remove the FP32 Wrapper

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-01-21 16:16:54 +05:30
jiqing-feng 012d08b1bc Enable dreambooth lora finetune example on other devices (#10602)
* enable dreambooth_lora on other devices

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable xpu

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check cuda device before empty cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* import free_memory

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-01-21 14:09:45 +05:30
Sayak Paul 4ace7d0483 [chore] change licensing to 2025 from 2024. (#10615)
change licensing to 2025 from 2024.
2025-01-20 16:57:27 -10:00
278 changed files with 8118 additions and 2333 deletions
+2
View File
@@ -598,6 +598,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities
+49
View File
@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
+47
View File
@@ -309,6 +309,53 @@ image.save("output.png")
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
## IP-Adapter
<Tip>
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
</Tip>
An IP-Adapter lets you prompt Flux with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.
```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))
pipe.load_ip_adapter(
"XLabs-AI/flux-ip-adapter",
weight_name="ip_adapter.safetensors",
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
)
pipe.set_ip_adapter_scale(1.0)
image = pipe(
width=1024,
height=1024,
prompt="wearing sunglasses",
negative_prompt="",
true_cfg=4.0,
generator=torch.Generator().manual_seed(4444),
ip_adapter_image=image,
).images[0]
image.save('flux_ip_adapter_output.jpg')
```
<div class="justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_output.jpg"/>
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "wearing sunglasses"</figcaption>
</div>
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
+4
View File
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## randn_tensor
[[autodoc]] utils.torch_utils.randn_tensor
## apply_layerwise_casting
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
+34 -6
View File
@@ -23,32 +23,60 @@ You should install 🤗 Diffusers in a [virtual environment](https://docs.python
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
Start by creating a virtual environment in your project directory:
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
<hfoptions id="install">
<hfoption id="uv">
```bash
python -m venv .env
uv venv my-env
source my-env/bin/activate
```
Activate the virtual environment:
</hfoption>
<hfoption id="Python">
```bash
source .env/bin/activate
python -m venv my-env
source my-env/bin/activate
```
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
</hfoption>
</hfoptions>
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
<frameworkcontent>
<pt>
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
```bash
uv install diffusers["torch"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["torch"] transformers
```
</pt>
<jax>
Install Diffusers with uv.
```bash
uv pip install diffusers["flax"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["flax"] transformers
```
</jax>
</frameworkcontent>
+37
View File
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
</Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video
model_id = "THUDM/CogVideoX-5b"
# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
+1 -1
View File
@@ -29,7 +29,7 @@ However, it is hard to decide when to reuse the cache to ensure quality generate
This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
<figure>
<img src="https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
<figcaption>How AdaCache works, First Block Cache is a variant of it</figcaption>
</figure>
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Executable → Regular
+90 -2
View File
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -4585,8 +4586,8 @@ image = pipe(
```
| ![Gradient](https://github.com/user-attachments/assets/e38ce4d5-1ae6-4df0-ab43-adc1b45716b5) | ![Input](https://github.com/user-attachments/assets/9c95679c-e9d7-4f5a-90d6-560203acd6b3) | ![Output](https://github.com/user-attachments/assets/5313ff64-a0c4-418b-8b55-a38f1a5e7532) |
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
| Gradient | Input | Output |
| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
| Gradient | Input | Output |
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
### Stable Diffusion XL Attentive Eraser Pipeline
<img src="https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png" width="600" />
**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the models self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
#### Key features
- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
#### Usage example
To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
```py
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import load_image
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, gaussian_blur
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
scheduler=scheduler,
variant="fp16",
use_safetensors=True,
torch_dtype=dtype,
).to(device)
def preprocess_image(image_path, device):
image = to_tensor((load_image(image_path)))
image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
if image.shape[1] != 3:
image = image.expand(-1, 3, -1, -1)
image = F.interpolate(image, (1024, 1024))
image = image.to(dtype).to(device)
return image
def preprocess_mask(mask_path, device):
mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
mask = mask.unsqueeze_(0).float() # 0 or 1
mask = F.interpolate(mask, (1024, 1024))
mask = gaussian_blur(mask, kernel_size=(77, 77))
mask[mask < 0.1] = 0
mask[mask >= 0.1] = 1
mask = mask.to(dtype).to(device)
return mask
prompt = "" # Set prompt to null
seed=123
generator = torch.Generator(device=device).manual_seed(seed)
source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
source_image = preprocess_image(source_image_path, device)
mask = preprocess_mask(mask_path, device)
image = pipeline(
prompt=prompt,
image=source_image,
mask_image=mask,
height=1024,
width=1024,
AAS=True, # enable AAS
strength=0.8, # inpainting strength
rm_guidance_scale=9, # removal guidance scale
ss_steps = 9, # similarity suppression steps
ss_scale = 0.3, # similarity suppression scale
AAS_start_step=0, # AAS start step
AAS_start_layer=34, # AAS start layer
AAS_end_layer=70, # AAS end layer
num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
generator=generator,
guidance_scale=1,
).images[0]
image.save('./removed_img.png')
print("Object removal completed")
```
| Source Image | Mask | Output |
| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| ![Source Image](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png) | ![Mask](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png) | ![Output](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/AE_step40_layer34.png) |
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
+5 -74
View File
@@ -80,7 +80,6 @@ from diffusers.utils import (
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -869,23 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1192,23 +1159,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
]
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -1365,19 +1312,8 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
# Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1385,7 +1321,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2724,10 +2659,6 @@ class MatryoshkaUNet2DConditionModel(
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(
@@ -139,7 +139,8 @@ def retrieve_timesteps(
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+26
View File
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
## Dataset
We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
We need to create a file `metadata.jsonl` in the directory with our images:
```
{"file_name": "01.jpg", "prompt": "prompt 01"}
{"file_name": "02.jpg", "prompt": "prompt 02"}
```
If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
```sh
python convert_to_imagefolder.py --path my_dataset/
```
We use `--dataset_name` and `--caption_column` with training scripts.
```
--dataset_name=my_dataset/
--caption_column=prompt
```
@@ -0,0 +1,32 @@
import argparse
import json
import pathlib
parser = argparse.ArgumentParser()
parser.add_argument(
"--path",
type=str,
required=True,
help="Path to folder with image-text pairs.",
)
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
args = parser.parse_args()
path = pathlib.Path(args.path)
if not path.exists():
raise RuntimeError(f"`--path` '{args.path}' does not exist.")
all_files = list(path.glob("*"))
captions = list(path.glob("*.txt"))
images = set(all_files) - set(captions)
images = {image.stem: image for image in images}
caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}
metadata = path.joinpath("metadata.jsonl")
with metadata.open("w", encoding="utf-8") as f:
for caption, image in caption_image.items():
caption_text = caption.read_text(encoding="utf-8")
json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
f.write("\n")
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+4 -4
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1716,9 +1716,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
+12 -9
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,7 +54,11 @@ from diffusers import (
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
free_memory,
)
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
if args.validation_images is None:
images = []
for _ in range(args.num_validation_images):
with torch.cuda.amp.autocast():
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
with torch.cuda.amp.autocast():
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
@@ -177,7 +181,7 @@ def log_validation(
)
del pipeline
torch.cuda.empty_cache()
free_memory()
return images
@@ -793,7 +797,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
@@ -829,8 +833,7 @@ def main(args):
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def main(args):
tokenizer = None
gc.collect()
torch.cuda.empty_cache()
free_memory()
else:
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,6 +63,7 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -74,6 +75,9 @@ check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
torch.npu.config.allow_internal_format = False
def save_model_card(
repo_id: str,
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -924,8 +929,7 @@ def main(args):
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -988,6 +992,13 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
transformer.enable_npu_flash_attention()
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -0,0 +1,59 @@
# AutoencoderKL training example
## Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
## Training on CIFAR10
Please replace the validation image with your own image.
```bash
accelerate launch train_autoencoderkl.py \
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--dataset_name=cifar10 \
--image_column=img \
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
--num_train_epochs 100 \
--gradient_accumulation_steps 2 \
--learning_rate 4.5e-6 \
--lr_scheduler cosine \
--report_to wandb \
```
## Training on ImageNet
```bash
accelerate launch train_autoencoderkl.py \
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--num_train_epochs 100 \
--gradient_accumulation_steps 2 \
--learning_rate 4.5e-6 \
--lr_scheduler cosine \
--report_to wandb \
--mixed_precision bf16 \
--train_data_dir /path/to/ImageNet/train \
--validation_image ./image.png \
--decoder_only
```
@@ -0,0 +1,15 @@
accelerate>=0.16.0
bitsandbytes
datasets
huggingface_hub
lpips
numpy
packaging
Pillow
taming_transformers
torch
torchvision
tqdm
transformers
wandb
xformers
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -8,7 +8,6 @@ from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
self.transformer = transformer
self.controlnet = controlnet
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -220,18 +215,8 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -239,7 +224,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -515,10 +515,6 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+151
View File
@@ -0,0 +1,151 @@
"""
This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
To make it work for other models:
* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
for example. (TODO: more reason to add `AutoModel`).
* Spply path to the base checkpoint via `base_ckpt_path`.
* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
* Change the `--rank` as needed.
Example usage:
```bash
python extract_lora_from_model.py \
--base_ckpt_path=THUDM/CogVideoX-5b \
--finetune_ckpt_path=finetrainers/cakeify-v0 \
--lora_out_path=cakeify_lora.safetensors
```
Script is adapted from
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
"""
import argparse
import torch
from safetensors.torch import save_file
from tqdm.auto import tqdm
from diffusers import CogVideoXTransformer3DModel
RANK = 64
CLAMP_QUANTILE = 0.99
# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
# Important to use CUDA otherwise, very slow!
if torch.cuda.is_available():
diff = diff.to("cuda")
is_conv2d = len(diff.shape) == 4
kernel_size = None if not is_conv2d else diff.size()[2:4]
is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if is_conv2d:
if is_conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if is_conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U.cpu(), Vh.cpu())
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_ckpt_path",
default=None,
type=str,
required=True,
help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
)
parser.add_argument(
"--base_subfolder",
default="transformer",
type=str,
help="subfolder to load the base checkpoint from if any.",
)
parser.add_argument(
"--finetune_ckpt_path",
default=None,
type=str,
required=True,
help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
)
parser.add_argument(
"--finetune_subfolder",
default=None,
type=str,
help="subfolder to load the fulle finetuned checkpoint from if any.",
)
parser.add_argument("--rank", default=64, type=int)
parser.add_argument("--lora_out_path", default=None, type=str, required=True)
args = parser.parse_args()
if not args.lora_out_path.endswith(".safetensors"):
raise ValueError("`lora_out_path` must end with `.safetensors`.")
return args
@torch.no_grad()
def main(args):
model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
)
state_dict_ft = model_finetuned.state_dict()
# Change the `subfolder` as needed.
base_model = CogVideoXTransformer3DModel.from_pretrained(
args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
)
state_dict = base_model.state_dict()
output_dict = {}
for k in tqdm(state_dict, desc="Extracting LoRA..."):
original_param = state_dict[k]
finetuned_param = state_dict_ft[k]
if len(original_param.shape) >= 2:
diff = finetuned_param.float() - original_param.float()
out = extract_lora(diff, RANK)
name = k
if name.endswith(".weight"):
name = name[: -len(".weight")]
down_key = "{}.lora_A.weight".format(name)
up_key = "{}.lora_B.weight".format(name)
output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)
prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
save_file(output_dict, args.lora_out_path)
print(f"LoRA saved and it contains {len(output_dict)} keys.")
if __name__ == "__main__":
args = parse_args()
main(args)
+11
View File
@@ -28,6 +28,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -90,6 +98,7 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsisIDTransformer3DModel",
@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsisIDTransformer3DModel,
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
+7
View File
@@ -0,0 +1,7 @@
from ..utils import is_torch_available
if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

Some files were not shown because too many files have changed in this diff Show More