Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f67e4e7f6 | |||
| 196aef5a6f | |||
| 7b100ce589 | |||
| c4d4ac21e7 | |||
| f295e2eefc | |||
| 658e24e86c | |||
| fb42066489 | |||
| e89ab5bc26 | |||
| 8ceec90d76 | |||
| 158c5c4d08 | |||
| 41571773d9 | |||
| 18f7d1d937 | |||
| f7f36c7d3d | |||
| 4fa24591a3 | |||
| 4f3ec5364e | |||
| 07860f9916 | |||
| 87252d80c3 | |||
| 5897137397 | |||
| a451c0ed14 | |||
| 37c9697f5b | |||
| 9684c52adf | |||
| 5483162d12 | |||
| d77c53b6d2 | |||
| 78bc824729 | |||
| 04d40920a7 | |||
| 8d6f6d6b66 | |||
| ca60ad8e55 | |||
| beacaa5528 | |||
| a647682224 | |||
| a1f9a71238 | |||
| ec37e20972 | |||
| 158a5a87fb | |||
| 012d08b1bc | |||
| 4ace7d0483 |
@@ -598,6 +598,8 @@
|
||||
title: Attention Processor
|
||||
- local: api/activations
|
||||
title: Custom activation functions
|
||||
- local: api/cache
|
||||
title: Caching methods
|
||||
- local: api/normalization
|
||||
title: Custom normalization layers
|
||||
- local: api/utilities
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
@@ -309,6 +309,53 @@ image.save("output.png")
|
||||
|
||||
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
|
||||
|
||||
</Tip>
|
||||
|
||||
An IP-Adapter lets you prompt Flux with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))
|
||||
|
||||
pipe.load_ip_adapter(
|
||||
"XLabs-AI/flux-ip-adapter",
|
||||
weight_name="ip_adapter.safetensors",
|
||||
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
|
||||
)
|
||||
pipe.set_ip_adapter_scale(1.0)
|
||||
|
||||
image = pipe(
|
||||
width=1024,
|
||||
height=1024,
|
||||
prompt="wearing sunglasses",
|
||||
negative_prompt="",
|
||||
true_cfg=4.0,
|
||||
generator=torch.Generator().manual_seed(4444),
|
||||
ip_adapter_image=image,
|
||||
).images[0]
|
||||
|
||||
image.save('flux_ip_adapter_output.jpg')
|
||||
```
|
||||
|
||||
<div class="justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_output.jpg"/>
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "wearing sunglasses"</figcaption>
|
||||
</div>
|
||||
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
|
||||
## randn_tensor
|
||||
|
||||
[[autodoc]] utils.torch_utils.randn_tensor
|
||||
|
||||
## apply_layerwise_casting
|
||||
|
||||
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
|
||||
|
||||
@@ -23,32 +23,60 @@ You should install 🤗 Diffusers in a [virtual environment](https://docs.python
|
||||
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
|
||||
|
||||
Start by creating a virtual environment in your project directory:
|
||||
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
|
||||
```bash
|
||||
python -m venv .env
|
||||
uv venv my-env
|
||||
source my-env/bin/activate
|
||||
```
|
||||
|
||||
Activate the virtual environment:
|
||||
</hfoption>
|
||||
<hfoption id="Python">
|
||||
|
||||
```bash
|
||||
source .env/bin/activate
|
||||
python -m venv my-env
|
||||
source my-env/bin/activate
|
||||
```
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
|
||||
Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
|
||||
|
||||
</Tip>
|
||||
|
||||
## FP8 layerwise weight-casting
|
||||
|
||||
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
|
||||
|
||||
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "THUDM/CogVideoX-5b"
|
||||
|
||||
# Load the model in bfloat16 and enable layerwise casting
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
|
||||
# Load the pipeline
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
|
||||
|
||||
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
|
||||
|
||||
## Channels-last memory format
|
||||
|
||||
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
|
||||
|
||||
@@ -29,7 +29,7 @@ However, it is hard to decide when to reuse the cache to ensure quality generate
|
||||
This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
|
||||
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
|
||||
<figcaption>How AdaCache works, First Block Cache is a variant of it</figcaption>
|
||||
</figure>
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
Executable → Regular
+90
-2
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
|
||||
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -4585,8 +4586,8 @@ image = pipe(
|
||||
```
|
||||
|
||||
|  |  |  |
|
||||
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
|
||||
| Gradient | Input | Output |
|
||||
| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
|
||||
| Gradient | Input | Output |
|
||||
|
||||
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
|
||||
|
||||
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
|
||||
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
|
||||
```
|
||||
|
||||
### Stable Diffusion XL Attentive Eraser Pipeline
|
||||
<img src="https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png" width="600" />
|
||||
|
||||
**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
|
||||
|
||||
#### Key features
|
||||
|
||||
- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
|
||||
- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
|
||||
- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
|
||||
|
||||
#### Usage example
|
||||
To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import to_tensor, gaussian_blur
|
||||
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
|
||||
scheduler=scheduler,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
|
||||
def preprocess_image(image_path, device):
|
||||
image = to_tensor((load_image(image_path)))
|
||||
image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
|
||||
if image.shape[1] != 3:
|
||||
image = image.expand(-1, 3, -1, -1)
|
||||
image = F.interpolate(image, (1024, 1024))
|
||||
image = image.to(dtype).to(device)
|
||||
return image
|
||||
|
||||
def preprocess_mask(mask_path, device):
|
||||
mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
|
||||
mask = mask.unsqueeze_(0).float() # 0 or 1
|
||||
mask = F.interpolate(mask, (1024, 1024))
|
||||
mask = gaussian_blur(mask, kernel_size=(77, 77))
|
||||
mask[mask < 0.1] = 0
|
||||
mask[mask >= 0.1] = 1
|
||||
mask = mask.to(dtype).to(device)
|
||||
return mask
|
||||
|
||||
prompt = "" # Set prompt to null
|
||||
seed=123
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
|
||||
mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
|
||||
source_image = preprocess_image(source_image_path, device)
|
||||
mask = preprocess_mask(mask_path, device)
|
||||
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
image=source_image,
|
||||
mask_image=mask,
|
||||
height=1024,
|
||||
width=1024,
|
||||
AAS=True, # enable AAS
|
||||
strength=0.8, # inpainting strength
|
||||
rm_guidance_scale=9, # removal guidance scale
|
||||
ss_steps = 9, # similarity suppression steps
|
||||
ss_scale = 0.3, # similarity suppression scale
|
||||
AAS_start_step=0, # AAS start step
|
||||
AAS_start_layer=34, # AAS start layer
|
||||
AAS_end_layer=70, # AAS end layer
|
||||
num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
|
||||
generator=generator,
|
||||
guidance_scale=1,
|
||||
).images[0]
|
||||
image.save('./removed_img.png')
|
||||
print("Object removal completed")
|
||||
```
|
||||
|
||||
| Source Image | Mask | Output |
|
||||
| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
|  |  |  |
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
@@ -80,7 +80,6 @@ from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -869,23 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1030,17 +1013,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1049,12 +1021,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -1192,23 +1159,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1282,10 +1233,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
]
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1365,19 +1312,8 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
# Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -1385,7 +1321,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -2724,10 +2659,6 @@ class MatryoshkaUNet2DConditionModel(
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
|
||||
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
@@ -223,6 +224,11 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
return ref_image_latents
|
||||
|
||||
def prepare_ref_image(
|
||||
|
||||
@@ -139,7 +139,8 @@ def retrieve_timesteps(
|
||||
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
@@ -169,6 +170,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
return ref_image_latents
|
||||
|
||||
def prepare_ref_image(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
|
||||
## Stable Diffusion XL
|
||||
|
||||
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
|
||||
|
||||
## Dataset
|
||||
|
||||
We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
|
||||
|
||||
The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
|
||||
|
||||
We need to create a file `metadata.jsonl` in the directory with our images:
|
||||
|
||||
```
|
||||
{"file_name": "01.jpg", "prompt": "prompt 01"}
|
||||
{"file_name": "02.jpg", "prompt": "prompt 02"}
|
||||
```
|
||||
|
||||
If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
|
||||
|
||||
```sh
|
||||
python convert_to_imagefolder.py --path my_dataset/
|
||||
```
|
||||
|
||||
We use `--dataset_name` and `--caption_column` with training scripts.
|
||||
|
||||
```
|
||||
--dataset_name=my_dataset/
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import argparse
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to folder with image-text pairs.",
|
||||
)
|
||||
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = pathlib.Path(args.path)
|
||||
if not path.exists():
|
||||
raise RuntimeError(f"`--path` '{args.path}' does not exist.")
|
||||
|
||||
all_files = list(path.glob("*"))
|
||||
captions = list(path.glob("*.txt"))
|
||||
images = set(all_files) - set(captions)
|
||||
images = {image.stem: image for image in images}
|
||||
caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}
|
||||
|
||||
metadata = path.joinpath("metadata.jsonl")
|
||||
|
||||
with metadata.open("w", encoding="utf-8") as f:
|
||||
for caption, image in caption_image.items():
|
||||
caption_text = caption.read_text(encoding="utf-8")
|
||||
json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
|
||||
f.write("\n")
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -1716,9 +1716,9 @@ def main(args):
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
|
||||
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -54,7 +54,11 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import StableDiffusionLoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_state_dict_to_diffusers,
|
||||
@@ -151,14 +155,14 @@ def log_validation(
|
||||
if args.validation_images is None:
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(accelerator.device.type):
|
||||
image = pipeline(**pipeline_args, generator=generator).images[0]
|
||||
images.append(image)
|
||||
else:
|
||||
images = []
|
||||
for image in args.validation_images:
|
||||
image = Image.open(image)
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(accelerator.device.type):
|
||||
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -177,7 +181,7 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -793,7 +797,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
@@ -829,8 +833,7 @@ def main(args):
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1085,7 +1088,7 @@ def main(args):
|
||||
tokenizer = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
else:
|
||||
pre_computed_encoder_hidden_states = None
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -63,6 +63,7 @@ from diffusers.utils import (
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -74,6 +75,9 @@ check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_torch_npu_available():
|
||||
torch.npu.config.allow_internal_format = False
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -924,8 +929,7 @@ def main(args):
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -988,6 +992,13 @@ def main(args):
|
||||
# because Gemma2 is particularly suited for bfloat16.
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
if args.enable_npu_flash_attention:
|
||||
if is_torch_npu_available():
|
||||
logger.info("npu flash attention enabled.")
|
||||
transformer.enable_npu_flash_attention()
|
||||
else:
|
||||
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
# AutoencoderKL training example
|
||||
|
||||
## Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
## Training on CIFAR10
|
||||
|
||||
Please replace the validation image with your own image.
|
||||
|
||||
```bash
|
||||
accelerate launch train_autoencoderkl.py \
|
||||
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
|
||||
--dataset_name=cifar10 \
|
||||
--image_column=img \
|
||||
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
|
||||
--num_train_epochs 100 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--learning_rate 4.5e-6 \
|
||||
--lr_scheduler cosine \
|
||||
--report_to wandb \
|
||||
```
|
||||
|
||||
## Training on ImageNet
|
||||
|
||||
```bash
|
||||
accelerate launch train_autoencoderkl.py \
|
||||
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
|
||||
--num_train_epochs 100 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--learning_rate 4.5e-6 \
|
||||
--lr_scheduler cosine \
|
||||
--report_to wandb \
|
||||
--mixed_precision bf16 \
|
||||
--train_data_dir /path/to/ImageNet/train \
|
||||
--validation_image ./image.png \
|
||||
--decoder_only
|
||||
```
|
||||
@@ -0,0 +1,15 @@
|
||||
accelerate>=0.16.0
|
||||
bitsandbytes
|
||||
datasets
|
||||
huggingface_hub
|
||||
lpips
|
||||
numpy
|
||||
packaging
|
||||
Pillow
|
||||
taming_transformers
|
||||
torch
|
||||
torchvision
|
||||
tqdm
|
||||
transformers
|
||||
wandb
|
||||
xformers
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -8,7 +8,6 @@ from diffusers.models import PixArtTransformer2DModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils.torch_utils import is_torch_version
|
||||
|
||||
|
||||
class PixArtControlNetAdapterBlock(nn.Module):
|
||||
@@ -151,10 +150,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
self.transformer = transformer
|
||||
self.controlnet = controlnet
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -220,18 +215,8 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
|
||||
exit(1)
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
@@ -239,7 +224,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
None,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
# the control nets are only used for the blocks 1 to self.blocks_num
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -515,10 +515,6 @@ def main():
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Freeze the unet parameters before adding adapters
|
||||
for param in unet.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
|
||||
|
||||
To make it work for other models:
|
||||
|
||||
* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
|
||||
for example. (TODO: more reason to add `AutoModel`).
|
||||
* Spply path to the base checkpoint via `base_ckpt_path`.
|
||||
* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
|
||||
* Change the `--rank` as needed.
|
||||
|
||||
Example usage:
|
||||
|
||||
```bash
|
||||
python extract_lora_from_model.py \
|
||||
--base_ckpt_path=THUDM/CogVideoX-5b \
|
||||
--finetune_ckpt_path=finetrainers/cakeify-v0 \
|
||||
--lora_out_path=cakeify_lora.safetensors
|
||||
```
|
||||
|
||||
Script is adapted from
|
||||
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
|
||||
RANK = 64
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
|
||||
# Comes from
|
||||
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
|
||||
def extract_lora(diff, rank):
|
||||
# Important to use CUDA otherwise, very slow!
|
||||
if torch.cuda.is_available():
|
||||
diff = diff.to("cuda")
|
||||
|
||||
is_conv2d = len(diff.shape) == 4
|
||||
kernel_size = None if not is_conv2d else diff.size()[2:4]
|
||||
is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = diff.size()[0:2]
|
||||
rank = min(rank, in_dim, out_dim)
|
||||
|
||||
if is_conv2d:
|
||||
if is_conv2d_3x3:
|
||||
diff = diff.flatten(start_dim=1)
|
||||
else:
|
||||
diff = diff.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(diff.float())
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
if is_conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
return (U.cpu(), Vh.cpu())
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--base_ckpt_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_subfolder",
|
||||
default="transformer",
|
||||
type=str,
|
||||
help="subfolder to load the base checkpoint from if any.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--finetune_ckpt_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--finetune_subfolder",
|
||||
default=None,
|
||||
type=str,
|
||||
help="subfolder to load the fulle finetuned checkpoint from if any.",
|
||||
)
|
||||
parser.add_argument("--rank", default=64, type=int)
|
||||
parser.add_argument("--lora_out_path", default=None, type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.lora_out_path.endswith(".safetensors"):
|
||||
raise ValueError("`lora_out_path` must end with `.safetensors`.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
|
||||
args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
state_dict_ft = model_finetuned.state_dict()
|
||||
|
||||
# Change the `subfolder` as needed.
|
||||
base_model = CogVideoXTransformer3DModel.from_pretrained(
|
||||
args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
state_dict = base_model.state_dict()
|
||||
output_dict = {}
|
||||
|
||||
for k in tqdm(state_dict, desc="Extracting LoRA..."):
|
||||
original_param = state_dict[k]
|
||||
finetuned_param = state_dict_ft[k]
|
||||
if len(original_param.shape) >= 2:
|
||||
diff = finetuned_param.float() - original_param.float()
|
||||
out = extract_lora(diff, RANK)
|
||||
name = k
|
||||
|
||||
if name.endswith(".weight"):
|
||||
name = name[: -len(".weight")]
|
||||
down_key = "{}.lora_A.weight".format(name)
|
||||
up_key = "{}.lora_B.weight".format(name)
|
||||
|
||||
output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
|
||||
output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)
|
||||
|
||||
prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
|
||||
output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
|
||||
save_file(output_dict, args.lora_out_path)
|
||||
print(f"LoRA saved and it contains {len(output_dict)} keys.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -28,6 +28,7 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -90,6 +98,7 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CacheMixin",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"ConsisIDTransformer3DModel",
|
||||
@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .models import (
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CacheMixin,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user