Compare commits

..

21 Commits

Author SHA1 Message Date
Aryan a1da7752e5 Merge branch 'main' into feature/guiders 2025-04-05 10:55:02 +02:00
Aryan b30cf5d452 spatio temporal guidance 2025-04-05 10:39:46 +02:00
Aryan 357f4f056b update 2025-04-04 22:44:02 +02:00
Aryan 53b6b9fcb6 perturbed attention guidance 2025-04-04 04:16:46 +02:00
Aryan 46643564a3 refactor 2025-04-04 01:41:34 +02:00
Aryan 77324c40c4 adaptive projected guidance 2025-04-03 05:01:54 +02:00
Aryan 05d74ef3e7 cfg zero star 2025-04-03 04:21:24 +02:00
Aryan 9997c223a8 more slg improvements 2025-04-03 03:50:30 +02:00
Aryan d91d10737a update slg docstring 2025-04-03 03:43:10 +02:00
Aryan 5ac7f360b0 skip layer guidance 2025-04-03 03:26:55 +02:00
Aryan 594e8d663f classifier-free guidance 2025-04-03 00:13:15 +02:00
Aryan c76e1cc17e update 2025-04-02 21:52:33 +02:00
Aryan 315e357a18 Merge branch 'main' into integrations/first-block-cache-2 2025-04-02 01:21:22 +02:00
Aryan 1f33ca276d support flux, ltx i2v, ltx condition 2025-04-02 01:21:09 +02:00
Aryan 41b0c473d2 fix controlnet flux 2025-04-02 01:20:53 +02:00
Aryan 0e232ac8c0 fix hs residual bug for single return outputs; support ltx 2025-04-02 00:38:11 +02:00
Aryan 2557238b4d cache context for different batches of data 2025-04-01 19:40:23 +02:00
Aryan d71fe55895 update 2025-04-01 17:06:45 +02:00
Aryan 7ab424a15a remove debug logs 2025-04-01 01:39:00 +02:00
Aryan dd69b41834 modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) 2025-04-01 01:28:09 +02:00
Aryan 406b1062f8 update 2025-03-31 04:27:35 +02:00
266 changed files with 7291 additions and 7522 deletions
+1 -1
View File
@@ -417,7 +417,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft"]
additional_deps: []
- backend: "torchao"
test_location: "torchao"
additional_deps: []
+56 -33
View File
@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
pipe.transformer.enable_cache(config)
```
## First Block Cache
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
@@ -14,7 +14,6 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
## Overview
-1
View File
@@ -14,7 +14,6 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
-1
View File
@@ -14,7 +14,6 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)
@@ -16,7 +16,6 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
-1
View File
@@ -16,7 +16,6 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+1 -1
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SANA-Sprint
# SanaSprintPipeline
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
@@ -14,7 +14,6 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
@@ -14,7 +14,6 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
-4
View File
@@ -178,9 +178,6 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# Uncomment the following to also allow recording the current streams.
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
@@ -208,7 +205,6 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
+1 -12
View File
@@ -12,9 +12,6 @@ specific language governing permissions and limitations under the License.
# Metal Performance Shaders (MPS)
> [!TIP]
> Pipelines with a <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.
🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. You'll need to have:
- macOS computer with Apple silicon (M1/M2) hardware
@@ -40,7 +37,7 @@ image
<Tip warning={true}>
The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
</Tip>
@@ -62,10 +59,6 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
## Troubleshoot
This section lists some common issues with using the `mps` backend and how to solve them.
### Attention slicing
M1/M2 performance is very sensitive to memory pressure. When this occurs, the system automatically swaps if it needs to which significantly degrades performance.
To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
@@ -79,7 +72,3 @@ pipeline.enable_attention_slicing()
```
Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% in computers without universal memory, but we've observed *better performance* in most Apple silicon computers unless you have 64GB of RAM or more.
### Batch inference
Generating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching.
+1 -1
View File
@@ -105,7 +105,7 @@ import torch
pipe = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
@@ -194,59 +194,6 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support
</Tip>
### Hotswapping LoRA adapters
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time.
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter.
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter, (`default_0` is the default adapter name), to be swapped. If you loaded the first adapter with a different name, use that name instead.
```python
pipe = ...
# load adapter 1 as normal
pipeline.load_lora_weights(file_name_adapter_1)
# generate some images with adapter 1
...
# now hot swap the 2nd adapter
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
<Tip warning={true}>
Hotswapping is not currently supported for LoRA adapters that target the text encoder.
</Tip>
For compiled models, it is often (though not always if the second adapter targets identical LoRA ranks and scales) necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation. Use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and `torch.compile` should be called _after_ loading the first adapter.
```python
pipe = ...
# call this extra method
pipe.enable_lora_hotswap(target_rank=max_rank)
# now load adapter 1
pipe.load_lora_weights(file_name_adapter_1)
# now compile the unet of the pipeline
pipe.unet = torch.compile(pipeline.unet, ...)
# generate some images with adapter 1
...
# now hot swap adapter 2
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
<Tip>
Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.
</Tip>
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
@@ -1,8 +1,7 @@
accelerate>=0.31.0
accelerate>=0.16.0
torchvision
transformers>=4.41.2
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
peft==0.7.0
@@ -24,7 +24,7 @@ import re
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import torch
@@ -228,20 +228,10 @@ def log_validation(
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -667,7 +657,6 @@ def parse_args(input_args=None):
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
@@ -677,7 +666,6 @@ def parse_args(input_args=None):
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -750,15 +738,6 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
@@ -839,9 +818,9 @@ class TokenEmbeddingsHandler:
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all(isinstance(tok, str) for tok in inserting_toks), (
"All elements in inserting_toks should be strings."
)
assert all(
isinstance(tok, str) for tok in inserting_toks
), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1168,7 +1147,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
return text_input_ids
def _encode_prompt_with_t5(
def _get_t5_prompt_embeds(
text_encoder,
tokenizer,
max_sequence_length=512,
@@ -1197,10 +1176,7 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -1212,7 +1188,7 @@ def _encode_prompt_with_t5(
return prompt_embeds
def _encode_prompt_with_clip(
def _get_clip_prompt_embeds(
text_encoder,
tokenizer,
prompt: str,
@@ -1241,13 +1217,9 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1266,35 +1238,136 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
batch_size = len(prompt)
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
pooled_prompt_embeds = _get_clip_prompt_embeds(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
)
prompt_embeds = _encode_prompt_with_t5(
prompt_embeds = _get_t5_prompt_embeds(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, pooled_prompt_embeds, text_ids
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now
# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()
return weights
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = (1 - t) * 1000
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
return timesteps
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -1426,7 +1499,7 @@ def main(args):
)
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -1546,6 +1619,7 @@ def main(args):
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
@@ -1605,7 +1679,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1653,6 +1727,7 @@ def main(args):
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
# if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1662,8 +1737,7 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1673,8 +1747,7 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1755,7 +1828,6 @@ def main(args):
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
@@ -1949,7 +2021,6 @@ def main(args):
lr_scheduler,
)
else:
print("I SHOULD BE HERE")
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
@@ -2054,7 +2125,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
@@ -2066,11 +2137,6 @@ def main(args):
pivoted_tr = True
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if not freeze_text_encoder:
models_to_accumulate.extend([text_encoder_one])
if args.enable_t5_ti:
models_to_accumulate.extend([text_encoder_two])
if pivoted_te:
# stopping optimization of text_encoder params
optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2079,7 +2145,7 @@ def main(args):
logger.info(f"PIVOT TRANSFORMER {epoch}")
optimizer.param_groups[0]["lr"] = 0.0
with accelerator.accumulate(models_to_accumulate):
with accelerator.accumulate(transformer):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -2123,7 +2189,7 @@ def main(args):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
@@ -2162,7 +2228,7 @@ def main(args):
)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
if transformer.config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -2222,26 +2288,16 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
if not freeze_text_encoder:
if args.train_text_encoder: # text encoder tuning
if args.train_text_encoder:
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
elif pure_textual_inversion:
if args.enable_t5_ti:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(text_encoder_one.parameters())
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
if args.enable_t5_ti:
params_to_clip = itertools.chain(
transformer.parameters(),
text_encoder_one.parameters(),
text_encoder_two.parameters(),
)
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters()
)
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(transformer.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2283,10 +2339,6 @@ def main(args):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2299,16 +2351,14 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if freeze_text_encoder: # no text encoder one, two optimizations
if freeze_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -2322,21 +2372,21 @@ def main(args):
epoch=epoch,
torch_dtype=weight_dtype,
)
images = None
del pipeline
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
images = None
del pipeline
elif args.train_text_encoder:
del text_encoder_two
free_memory()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
@@ -2378,8 +2428,8 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
is_final_validation=True,
)
save_model_card(
@@ -2402,7 +2452,6 @@ def main(args):
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
@@ -200,8 +200,7 @@ Special VAE used for training: {vae_path}.
"diffusers",
"diffusers-training",
lora,
"template:sd-lora",
"stable-diffusion",
"template:sd-lora" "stable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
@@ -725,9 +724,9 @@ class TokenEmbeddingsHandler:
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all(isinstance(tok, str) for tok in inserting_toks), (
"All elements in inserting_toks should be strings."
)
assert all(
isinstance(tok, str) for tok in inserting_toks
), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -747,9 +746,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
)
self.embeddings_settings[
f"original_embeddings_{idx}"
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1323,7 +1322,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all(isinstance(tok, str) for tok in inserting_toks), (
"All elements in inserting_toks should be strings."
)
assert all(
isinstance(tok, str) for tok in inserting_toks
), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
)
self.embeddings_settings[
f"original_embeddings_{idx}"
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1647,7 +1647,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
+1 -1
View File
@@ -720,7 +720,7 @@ def main(args):
# Train!
logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -1138,7 +1138,7 @@ def main(args):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+1 -1
View File
@@ -1159,7 +1159,7 @@ def main(args):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1103,7 +1103,7 @@ class AdaptiveMaskInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `default_mask_image` or `image` input."
)
elif num_channels_unet != 4:
+1 -1
View File
@@ -686,7 +686,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
+1 -1
View File
@@ -362,7 +362,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
+2 -2
View File
@@ -1120,7 +1120,7 @@ class LLMGroundedDiffusionPipeline(
if verbose:
logger.info(
f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
)
try:
@@ -1184,7 +1184,7 @@ class LLMGroundedDiffusionPipeline(
if verbose:
logger.info(
f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
)
finally:
@@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
raise ValueError("`max_tile_size` cannot be None.")
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
raise ValueError(
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}."
)
if tile_gaussian_sigma is None:
raise ValueError("`tile_gaussian_sigma` cannot be None.")
@@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+6 -6
View File
@@ -907,12 +907,12 @@ def create_controller(
# reweight
if edit_type == "reweight":
assert equalizer_words is not None and equalizer_strengths is not None, (
"To use reweight edit, please specify equalizer_words and equalizer_strengths."
)
assert len(equalizer_words) == len(equalizer_strengths), (
"equalizer_words and equalizer_strengths must be of same length."
)
assert (
equalizer_words is not None and equalizer_strengths is not None
), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
assert len(equalizer_words) == len(
equalizer_strengths
), "equalizer_words and equalizer_strengths must be of same length."
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight(
prompts,
@@ -1738,7 +1738,7 @@ class StyleAlignedSDXLPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -689,7 +689,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents + num_channels_image}. Please verify the config of"
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
@@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline(
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -2050,7 +2050,7 @@ class StableDiffusionXL_AE_Pipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1578,7 +1578,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
+2 -1
View File
@@ -288,7 +288,8 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f"`timesteps` must start before `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
@@ -901,7 +901,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = {
f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
+7 -18
View File
@@ -927,22 +927,17 @@ def main(args):
)
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -967,14 +962,8 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+3 -44
View File
@@ -17,7 +17,6 @@ import argparse
import contextlib
import copy
import functools
import gc
import logging
import math
import os
@@ -53,7 +52,6 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.testing_utils import backend_empty_cache
from diffusers.utils.torch_utils import is_compiled_module
@@ -76,9 +74,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=None,
controlnet=controlnet,
safety_checker=None,
transformer=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -105,55 +102,18 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)
with torch.no_grad():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
validation_prompts,
prompt_2=None,
prompt_3=None,
)
del pipeline
gc.collect()
backend_empty_cache(accelerator.device.type)
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
safety_checker=None,
text_encoder=None,
text_encoder_2=None,
text_encoder_3=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.enable_model_cpu_offload(device=accelerator.device.type)
pipeline.set_progress_bar_config(disable=True)
image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
for i, validation_image in enumerate(validation_images):
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_prompt = validation_prompts[i]
images = []
for _ in range(args.num_validation_images):
with inference_ctx:
image = pipeline(
prompt_embeds=prompt_embeds[i].unsqueeze(0),
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
control_image=validation_image,
num_inference_steps=20,
generator=generator,
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
).images[0]
images.append(image)
@@ -695,7 +655,6 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
trust_remote_code=True,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+3 -5
View File
@@ -50,11 +50,9 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
with (
open(f"{class_data_dir}/caption.txt", "w") as f1,
open(f"{class_data_dir}/urls.txt", "w") as f2,
open(f"{class_data_dir}/images.txt", "w") as f3,
):
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
f"{class_data_dir}/images.txt", "w"
) as f3:
while total < num_class_images:
images = class_images[count]
count += 1
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior:
assert (class_images_dir / "images").exists(), (
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
)
assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
)
assert (class_images_dir / "caption.txt").exists(), (
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
)
assert (class_images_dir / "images.txt").exists(), (
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
)
assert (
class_images_dir / "images"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
assert (
len(list((class_images_dir / "images").iterdir())) == args.num_class_images
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
assert (
class_images_dir / "caption.txt"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
assert (
class_images_dir / "images.txt"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept
+1 -1
View File
@@ -1014,7 +1014,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
+7 -19
View File
@@ -895,10 +895,7 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -939,13 +936,9 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -965,12 +958,7 @@ def encode_prompt(
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
dtype = text_encoders[0].dtype
device = device if device is not None else text_encoders[1].device
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1602,7 +1590,7 @@ def main(args):
)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1728,9 +1716,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
+1 -1
View File
@@ -982,7 +982,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -177,25 +177,16 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -212,7 +203,8 @@ def log_validation(
)
del pipeline
free_memory()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return images
@@ -940,10 +932,7 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -984,13 +973,9 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1009,11 +994,7 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1294,7 +1275,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1638,7 +1619,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
@@ -1729,7 +1710,7 @@ def main(args):
)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1847,9 +1828,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -1053,7 +1053,7 @@ def main(args):
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1064,7 +1064,7 @@ def main(args):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1355,7 +1355,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
<Gallery />
@@ -669,16 +669,6 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
@@ -800,12 +790,7 @@ class DreamBoothDataset(Dataset):
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
@@ -1286,7 +1271,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype,
)
pipeline.load_lora_weights(args.output_dir)
assert pipeline.transformer.config.in_channels == initial_channels * 2, (
f"{pipeline.transformer.config.in_channels=}"
)
assert (
pipeline.transformer.config.in_channels == initial_channels * 2
), f"{pipeline.transformer.config.in_channels=}"
pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def main(args):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = {
f"{k.replace('transformer.', '')}": v
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k
}
+3 -3
View File
@@ -1081,9 +1081,9 @@ class AutoConfig:
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
)
pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
textual_inversion_path.model_path
)
pretrained_model_name_or_paths[
pretrained_model_name_or_paths.index(search_word)
] = textual_inversion_path.model_path
self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors="pt",
)
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, (
f"String '{string}' maps to more than a single token. Please use another string"
)
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], (
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
)
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
@@ -619,7 +619,7 @@ def main(args):
optimizer.step()
lr_scheduler.step()
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
@@ -803,20 +803,21 @@ def parse_args(input_args=None):
"--control_type",
type=str,
default="canny",
help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
)
parser.add_argument(
"--transformer_layers_per_block",
type=str,
default=None,
help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
)
parser.add_argument(
"--old_style_controlnet",
action="store_true",
default=False,
help=(
"Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
"Use the old style controlnet, which is a single transformer layer with"
" a single head. Defaults to False."
),
)
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -783,7 +783,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
File diff suppressed because one or more lines are too long
+23 -18
View File
@@ -26,7 +26,8 @@
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from diffusers import StableDiffusionGLIGENPipeline"
"import torch\n",
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
]
},
{
@@ -35,25 +36,28 @@
"metadata": {},
"outputs": [],
"source": [
"from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"import os\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
" EulerDiscreteScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
")\n",
"\n",
"\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
"pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
"vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
@@ -67,7 +71,9 @@
"metadata": {},
"outputs": [],
"source": [
"unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
"unet = UNet2DConditionModel.from_pretrained(\n",
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
]
},
{
@@ -102,9 +108,6 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
@@ -114,8 +117,10 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
"prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
"gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
@@ -161,7 +166,7 @@
"metadata": {},
"outputs": [],
"source": [
"diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
]
},
{
@@ -174,7 +179,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "densecaption",
"language": "python",
"name": "python3"
},
@@ -192,5 +197,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 2
}
@@ -15,8 +15,8 @@
# limitations under the License.
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
"Instance data dir and prompt inputs are not of the same length."
)
assert all(
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
), "Instance data dir and prompt inputs are not of the same length."
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
assert num_of_validation_prompts == len(negative_validation_prompts), (
"The length of negative prompts for validation is greater than the number of validation prompts."
)
assert num_of_validation_prompts == len(
negative_validation_prompts
), "The length of negative prompts for validation is greater than the number of validation prompts."
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -663,7 +663,8 @@ class PromptDiffusionPipeline(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
f"You have passed a list of images of length {len(image_pair)}."
f"Make sure the list size equals to two."
)
# Check `controlnet_conditioning_scale`
@@ -173,7 +173,7 @@ class TrainSD:
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
@@ -622,7 +622,7 @@ def main(args):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
@@ -1057,7 +1057,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
@@ -1021,7 +1021,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
<Gallery />
@@ -1336,7 +1336,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -750,7 +750,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -767,7 +767,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
orig_embeds_params_2[index_no_updates_2]
)
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
index_no_updates_2
] = orig_embeds_params_2[index_no_updates_2]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
+3 -3
View File
@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--seed=0
""".split()
@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--use_ema
--seed=0
@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--checkpoints_total_limit=2
--seed=0
""".split()
+6 -6
View File
@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
)
assert isinstance(
timm_centercrop_transform, transforms.CenterCrop
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
assert isinstance(timm_model_normalization, transforms.Normalize), (
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
)
assert isinstance(
timm_model_normalization, transforms.Normalize
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
+1 -1
View File
@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
ignore = ["C901", "E501", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
+1 -1
View File
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
return new_vae
+2 -2
View File
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer - 1}.1"
old_prefix = f"output_blocks.{current_layer-1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer - 1}.2"
old_prefix = f"output_blocks.{current_layer-1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
@@ -261,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
assert model_name == args.model_path, (
f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
)
assert (
model_name == args.model_path
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -290,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
)
assert (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
if key == "time_proj.weight":
value = value.squeeze()
@@ -52,18 +52,18 @@ for i in range(3):
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
@@ -75,12 +75,12 @@ for i in range(3):
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -47,36 +47,36 @@ for i in range(4):
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
)
model_config["use_style_cond_and_image_meta_size"] = (
args.use_style_cond_and_image_meta_size
) ### version <= v1.1: True; version >= v1.2: False
model_config[
"use_style_cond_and_image_meta_size"
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
print(model_config)
for key in state_dict:
+5 -4
View File
@@ -13,14 +13,15 @@ def main(args):
state_dict = state_dict[args.load_key]
except KeyError:
raise KeyError(
f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
f"{args.load_key} not found in the checkpoint."
f"Please load from the following keys:{state_dict.keys()}"
)
device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
model_config["use_style_cond_and_image_meta_size"] = (
args.use_style_cond_and_image_meta_size
) ### version <= v1.1: True; version >= v1.2: False
model_config[
"use_style_cond_and_image_meta_size"
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict:
+5 -5
View File
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx }"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx }"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"]
assert len(set(original_config["depths"])) == 1, (
"UNet2DConditionModel currently do not support blocks with different number of layers"
)
assert (
len(set(original_config["depths"])) == 1
), "UNet2DConditionModel currently do not support blocks with different number of layers"
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
+58 -60
View File
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.0.weight"
f"blocks.0.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.0.bias"
f"blocks.0.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.2.weight"
f"blocks.0.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.2.bias"
f"blocks.0.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.3.weight"
f"blocks.0.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.3.bias"
f"blocks.0.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.5.weight"
f"blocks.0.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i + 1}.stack.5.bias"
f"blocks.0.{i+1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
@@ -197,35 +197,33 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.bias"
f"blocks.{block+1}.proj.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
@@ -269,133 +267,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.0.weight"
f"layers.{i+1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.0.bias"
f"layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.2.weight"
f"layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.2.bias"
f"layers.{i+1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.3.weight"
f"layers.{i+1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.3.bias"
f"layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.5.weight"
f"layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i + 1}.stack.5.bias"
f"layers.{i+1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.0.weight"
f"layers.{block+4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.0.bias"
f"layers.{block+4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
)
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
)
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
)
# Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.0.weight"
f"layers.{i+7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.0.bias"
f"layers.{i+7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.2.weight"
f"layers.{i+7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.2.bias"
f"layers.{i+7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.3.weight"
f"layers.{i+7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.3.bias"
f"layers.{i+7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.5.weight"
f"layers.{i+7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.stack.5.bias"
f"layers.{i+7}.stack.5.bias"
)
# Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.attn_block.attn.out.weight"
f"layers.{i+7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.attn_block.attn.out.bias"
f"layers.{i+7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i + 7}.attn_block.norm.weight"
f"layers.{i+7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i + 7}.attn_block.norm.bias"
f"layers.{i+7}.attn_block.norm.bias"
)
# Convert output layers
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
+9 -9
View File
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
if "encoder" in new_key:
for i in range(3):
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
else:
for i in range(2, 5):
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
new_key = new_key.replace(f"block.{idx - 1}", "snake1")
new_key = new_key.replace(f"block.{idx-1}", "snake1")
elif idx == num_autoencoder_layers + 2:
new_key = new_key.replace(f"block.{idx - 1}", "conv2")
new_key = new_key.replace(f"block.{idx-1}", "conv2")
else:
new_key = new_key
+6 -6
View File
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
)
new_checkpoint[
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
)
new_checkpoint[
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
+12 -12
View File
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
def vqvae_model_from_original_config(original_config):
assert original_config["target"] in PORTED_VQVAES, (
f"{original_config['target']} has not yet been ported to diffusers."
)
assert (
original_config["target"] in PORTED_VQVAES
), f"{original_config['target']} has not yet been ported to diffusers."
original_config = original_config["params"]
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
f"{original_diffusion_config['target']} has not yet been ported to diffusers."
)
assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
f"{original_transformer_config['target']} has not yet been ported to diffusers."
)
assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
)
assert (
original_diffusion_config["target"] in PORTED_DIFFUSIONS
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
assert (
original_transformer_config["target"] in PORTED_TRANSFORMERS
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
assert (
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
+1 -1
View File
@@ -122,7 +122,7 @@ _deps = [
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"ruff==0.9.10",
"ruff==0.1.5",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
+25 -4
View File
@@ -33,6 +33,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
@@ -129,12 +130,25 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -155,7 +169,6 @@ else:
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
@@ -198,7 +211,6 @@ else:
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
"TransformerTemporalModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
@@ -710,11 +722,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
@@ -733,7 +756,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
@@ -775,7 +797,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
+1 -1
View File
@@ -29,7 +29,7 @@ deps = {
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ruff": "ruff==0.9.10",
"ruff": "ruff==0.1.5",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
+24
View File
@@ -0,0 +1,24 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
@@ -0,0 +1,145 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class AdaptiveProjectedGuidance(GuidanceMixin):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, *args):
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return super().prepare_inputs(*args)
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeGuidance(GuidanceMixin):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
):
super().__init__()
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
@@ -0,0 +1,110 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.type_as(cond)

Some files were not shown because too many files have changed in this diff Show More