Compare commits

...

11 Commits

Author SHA1 Message Date
DN6 7e13c18b29 update 2025-04-08 20:38:58 +05:30
Linoy Tsaban 71f34fc5a4 [Flux LoRA] fix issues in flux lora scripts (#11111)
* remove custom scheduler

* update requirements.txt

* log_validation with mixed precision

* add intermediate embeddings saving when checkpointing is enabled

* remove comment

* fix validation

* add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script

* revert unwrap_model change temp

* add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap model

* changes to align advanced script with canonical script

* make changes for distributed training + unify unwrap_model calls in advanced script

* add module.dtype fix to dreambooth script

* unify unwrap_model calls in dreambooth script

* fix condition in validation run

* mixed precision

* Update examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* smol style change

* change autocast

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-08 17:40:30 +03:00
Yao Matrix c51b6bd837 introduce compute arch specific expectations and fix test_sd3_img2img_inference failure (#11227)
* add arch specfic expectations support, to support different arch's numerical characteristics

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix typo

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Apply suggestions from code review

* Apply style fixes

* Update src/diffusers/utils/testing_utils.py

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-08 14:57:49 +01:00
Benjamin Bossan fb54499614 [LoRA] Implement hot-swapping of LoRA (#9453)
* [WIP][LoRA] Implement hot-swapping of LoRA

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the
model is compiled, it is necessary to prevent recompilation. See #9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.

* Reviewer feedback

* Reviewer feedback, adjust test

* Fix, doc

* Make fix

* Fix for possible g++ error

* Add test for recompilation w/o hotswapping

* Make hotswap work

Requires https://github.com/huggingface/peft/pull/2366

More changes to make hotswapping work. Together with the mentioned PEFT
PR, the tests pass for me locally.

List of changes:

- docstring for hotswap
- remove code copied from PEFT, import from PEFT now
- adjustments to PeftAdapterMixin.load_lora_adapter (unfortunately, some
  state dict renaming was necessary, LMK if there is a better solution)
- adjustments to UNet2DConditionLoadersMixin._process_lora: LMK if this
  is even necessary or not, I'm unsure what the overall relationship is
  between this and PeftAdapterMixin.load_lora_adapter
- also in UNet2DConditionLoadersMixin._process_lora, I saw that there is
  no LoRA unloading when loading the adapter fails, so I added it
  there (in line with what happens in PeftAdapterMixin.load_lora_adapter)
- rewritten tests to avoid shelling out, make the test more precise by
  making sure that the outputs align, parametrize it
- also checked the pipeline code mentioned in this comment:
  https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871;
  when running this inside the with
  torch._dynamo.config.patch(error_on_recompile=True) context, there is
  no error, so I think hotswapping is now working with pipelines.

* Address reviewer feedback:

- Revert deprecated method
- Fix PEFT doc link to main
- Don't use private function
- Clarify magic numbers
- Add pipeline test

Moreover:
- Extend docstrings
- Extend existing test for outputs != 0
- Extend existing test for wrong adapter name

* Change order of test decorators

parameterized.expand seems to ignore skip decorators if added in last
place (i.e. innermost decorator).

* Split model and pipeline tests

Also increase test coverage by also targeting conv2d layers (support of
which was added recently on the PEFT PR).

* Reviewer feedback: Move decorator to test classes

... instead of having them on each test method.

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* Reviewer feedback: version check, TODO comment

* Add enable_lora_hotswap method

* Reviewer feedback: check _lora_loadable_modules

* Revert changes in unet.py

* Add possibility to ignore enabled at wrong time

* Fix docstrings

* Log possible PEFT error, test

* Raise helpful error if hotswap not supported

I.e. for the text encoder

* Formatting

* More linter

* More ruff

* Doc-builder complaint

* Update docstring:

- mention no text encoder support yet
- make it clear that LoRA is meant
- mention that same adapter name should be passed

* Fix error in docstring

* Update more methods with hotswap argument

- SDXL
- SD3
- Flux

No changes were made to load_lora_into_transformer.

* Add hotswap argument to load_lora_into_transformer

For SD3 and Flux. Use shorter docstring for brevity.

* Extend docstrings

* Add version guards to tests

* Formatting

* Fix LoRA loading call to add prefix=None

See:
https://github.com/huggingface/diffusers/pull/10187#issuecomment-2717571064

* Run make fix-copies

* Add hot swap documentation to the docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-04-08 17:05:31 +05:30
Álvaro Somoza 723dbdd363 [Training] Better image interpolation in training scripts (#11206)
* initial

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: hlky <hlky@hlky.ac>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-04-08 12:26:07 +05:30
Bhavay Malhotra fbf61f465b [train_controlnet.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8461)
* Create diffusers.yml

* fix num_train_epochs

* Delete diffusers.yml

* Fixed Changes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-08 12:10:09 +05:30
Inigo Goiri 841504bb1a Add support to pass image embeddings to the WAN I2V pipeline. (#11175)
* Add support to pass image embeddings to the pipeline.



---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-07 15:47:06 -10:00
Steven Liu fc7a867ae5 [docs] MPS update (#11212)
mps
2025-04-07 14:32:27 -10:00
alex choi 5ded26cdc7 ensure dtype match between diffused latents and vae weights (#8391) 2025-04-07 12:59:10 -10:00
Yao Matrix 506f39af3a enable 1 case on XPU (#11219)
enable case on XPU: 1. tests/quantization/bnb/test_mixed_int8.py::BnB8bitTrainingTests::test_training

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-04-07 08:24:21 +01:00
Mikko Tukiainen 8ad68c1393 Add missing MochiEncoder3D.gradient_checkpointing attribute (#11146)
* Add missing 'gradient_checkpointing = False' attr

* Add (limited) tests for Mochi autoencoder

* Apply style fixes

* pass 'conv_cache' as arg instead of kwarg

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-06 02:46:45 +05:30
28 changed files with 1953 additions and 249 deletions
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
## Overview
+1
View File
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
+1
View File
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)
@@ -16,6 +16,7 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
+1
View File
@@ -16,6 +16,7 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
+12 -1
View File
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Metal Performance Shaders (MPS)
> [!TIP]
> Pipelines with a <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.
🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. You'll need to have:
- macOS computer with Apple silicon (M1/M2) hardware
@@ -37,7 +40,7 @@ image
<Tip warning={true}>
Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
</Tip>
@@ -59,6 +62,10 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
## Troubleshoot
This section lists some common issues with using the `mps` backend and how to solve them.
### Attention slicing
M1/M2 performance is very sensitive to memory pressure. When this occurs, the system automatically swaps if it needs to which significantly degrades performance.
To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
@@ -72,3 +79,7 @@ pipeline.enable_attention_slicing()
```
Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% in computers without universal memory, but we've observed *better performance* in most Apple silicon computers unless you have 64GB of RAM or more.
### Batch inference
Generating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching.
@@ -194,6 +194,59 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support
</Tip>
### Hotswapping LoRA adapters
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time.
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter.
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter, (`default_0` is the default adapter name), to be swapped. If you loaded the first adapter with a different name, use that name instead.
```python
pipe = ...
# load adapter 1 as normal
pipeline.load_lora_weights(file_name_adapter_1)
# generate some images with adapter 1
...
# now hot swap the 2nd adapter
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
<Tip warning={true}>
Hotswapping is not currently supported for LoRA adapters that target the text encoder.
</Tip>
For compiled models, it is often (though not always if the second adapter targets identical LoRA ranks and scales) necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation. Use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and `torch.compile` should be called _after_ loading the first adapter.
```python
pipe = ...
# call this extra method
pipe.enable_lora_hotswap(target_rank=max_rank)
# now load adapter 1
pipe.load_lora_weights(file_name_adapter_1)
# now compile the unet of the pipeline
pipe.unet = torch.compile(pipeline.unet, ...)
# generate some images with adapter 1
...
# now hot swap adapter 2
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
<Tip>
Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.
</Tip>
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
@@ -1,7 +1,8 @@
accelerate>=0.16.0
accelerate>=0.31.0
torchvision
transformers>=4.25.1
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft==0.7.0
peft>=0.11.1
sentencepiece
@@ -24,7 +24,7 @@ import re
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional
import numpy as np
import torch
@@ -228,10 +228,20 @@ def log_validation(
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -657,6 +667,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
@@ -666,6 +677,7 @@ def parse_args(input_args=None):
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
return text_input_ids
def _get_t5_prompt_embeds(
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length=512,
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
return prompt_embeds
def _get_clip_prompt_embeds(
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1238,136 +1266,35 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _get_clip_prompt_embeds(
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
)
prompt_embeds = _get_t5_prompt_embeds(
prompt_embeds = _encode_prompt_with_t5(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now
# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()
return weights
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = (1 - t) * 1000
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
return timesteps
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -1499,7 +1426,7 @@ def main(args):
)
# Load scheduler and models
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -1619,7 +1546,6 @@ def main(args):
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
@@ -1727,7 +1653,6 @@ def main(args):
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
# if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1737,7 +1662,8 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1747,7 +1673,8 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1828,6 +1755,7 @@ def main(args):
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
@@ -2021,6 +1949,7 @@ def main(args):
lr_scheduler,
)
else:
print("I SHOULD BE HERE")
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
@@ -2125,7 +2054,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
@@ -2137,6 +2066,11 @@ def main(args):
pivoted_tr = True
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if not freeze_text_encoder:
models_to_accumulate.extend([text_encoder_one])
if args.enable_t5_ti:
models_to_accumulate.extend([text_encoder_two])
if pivoted_te:
# stopping optimization of text_encoder params
optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2145,7 +2079,7 @@ def main(args):
logger.info(f"PIVOT TRANSFORMER {epoch}")
optimizer.param_groups[0]["lr"] = 0.0
with accelerator.accumulate(transformer):
with accelerator.accumulate(models_to_accumulate):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -2189,7 +2123,7 @@ def main(args):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
@@ -2228,7 +2162,7 @@ def main(args):
)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -2288,16 +2222,26 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
if not freeze_text_encoder:
if args.train_text_encoder:
if args.train_text_encoder: # text encoder tuning
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
elif pure_textual_inversion:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
if args.enable_t5_ti:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(text_encoder_one.parameters())
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
)
if args.enable_t5_ti:
params_to_clip = itertools.chain(
transformer.parameters(),
text_encoder_one.parameters(),
text_encoder_two.parameters(),
)
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters()
)
else:
params_to_clip = itertools.chain(transformer.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2339,6 +2283,10 @@ def main(args):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2351,14 +2299,16 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if freeze_text_encoder:
if freeze_text_encoder: # no text encoder one, two optimizations
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -2372,21 +2322,21 @@ def main(args):
epoch=epoch,
torch_dtype=weight_dtype,
)
images = None
del pipeline
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
elif args.train_text_encoder:
del text_encoder_two
free_memory()
images = None
del pipeline
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(weight_dtype)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
@@ -2428,8 +2378,8 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
is_final_validation=True,
torch_dtype=weight_dtype,
)
save_model_card(
@@ -2452,6 +2402,7 @@ def main(args):
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
+18 -7
View File
@@ -927,17 +927,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -962,8 +967,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+19 -7
View File
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
device = device if device is not None else text_encoders[1].device
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1590,7 +1602,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1716,9 +1728,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -177,16 +177,25 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +212,7 @@ def log_validation(
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
return images
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -994,7 +1009,11 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1619,7 +1638,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
@@ -1710,7 +1729,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1828,9 +1847,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
@@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset):
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
+25
View File
@@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -341,6 +342,10 @@ def _load_lora_into_text_encoder(
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
@@ -908,3 +913,23 @@ class LoraBaseMixin:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def enable_lora_hotswap(self, **kwargs) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
+567 -18
View File
@@ -79,10 +79,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
"""Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
@@ -105,6 +108,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -135,6 +161,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -146,6 +173,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -265,7 +293,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -287,6 +322,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -307,6 +365,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -320,6 +379,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -345,6 +405,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -356,6 +439,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -700,7 +784,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -722,6 +813,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -742,6 +856,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -756,6 +871,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -781,6 +897,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -792,6 +931,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -1035,7 +1175,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -1058,6 +1202,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -1087,6 +1254,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1097,6 +1265,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1107,11 +1276,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1129,6 +1299,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1143,6 +1336,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -1157,6 +1351,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1182,6 +1377,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -1193,6 +1411,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -1476,7 +1695,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -1501,6 +1724,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1569,6 +1812,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if len(transformer_norm_state_dict) > 0:
@@ -1587,11 +1831,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1613,6 +1865,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -1627,6 +1902,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -1695,6 +1971,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1720,6 +1997,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -1731,6 +2031,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -2141,7 +2442,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2163,6 +2471,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -2177,6 +2508,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -2191,6 +2523,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2216,6 +2549,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -2227,6 +2583,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -2443,7 +2800,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2461,6 +2818,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -2475,6 +2855,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -2750,7 +3131,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2768,6 +3149,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -2782,6 +3186,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -3059,7 +3464,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3077,6 +3482,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3091,6 +3519,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -3368,7 +3797,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3386,6 +3815,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3400,6 +3852,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -3680,7 +4133,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3698,6 +4151,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3712,6 +4188,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -3993,7 +4470,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4011,6 +4488,29 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4025,6 +4525,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -4333,7 +4834,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4351,6 +4852,29 @@ class WanLoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4365,6 +4889,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@@ -4642,7 +5167,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4660,6 +5185,29 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4674,6 +5222,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
+127 -5
View File
@@ -16,7 +16,7 @@ import inspect
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
@@ -128,6 +128,8 @@ class PeftAdapterMixin:
"""
_hf_peft_config_loaded = False
# kwargs for prepare_model_for_compiled_hotswap, if required
_prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -145,7 +147,9 @@ class PeftAdapterMixin:
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
):
r"""
Loads a LoRA adapter into the underlying model.
@@ -189,6 +193,29 @@ class PeftAdapterMixin:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -239,10 +266,15 @@ class PeftAdapterMixin:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
@@ -302,11 +334,68 @@ class PeftAdapterMixin:
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
if is_peft_version(">", "0.14.0"):
from peft.utils.hotswap import (
check_hotswap_configs_compatible,
hotswap_adapter_from_state_dict,
prepare_model_for_compiled_hotswap,
)
else:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg)
if hotswap:
def map_state_dict_for_hotswap(sd):
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"
new_sd[k] = v
return new_sd
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if hotswap:
state_dict = map_state_dict_for_hotswap(state_dict)
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
raise
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
# For hotswapping of compiled models or adapters with different ranks.
# If the user called enable_lora_hotswap, we need to ensure it is called:
# - after the first adapter was loaded
# - before the model is compiled and the 2nd adapter is being hotswapped in
# Therefore, it needs to be called here
prepare_model_for_compiled_hotswap(
self, config=lora_config, **self._prepare_lora_hotswap_kwargs
)
# We only want to call prepare_model_for_compiled_hotswap once
self._prepare_lora_hotswap_kwargs = None
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
@@ -769,3 +858,36 @@ class PeftAdapterMixin:
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def enable_lora_hotswap(
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`, *optional*, defaults to `128`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
if getattr(self, "peft_config", {}):
if check_compiled == "error":
raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
elif check_compiled == "warn":
logger.warning(
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
elif check_compiled != "ignore":
raise ValueError(
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
)
self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
@@ -210,7 +210,7 @@ class MochiDownBlock3D(nn.Module):
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -306,7 +306,7 @@ class MochiMidBlock3D(nn.Module):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
resnet, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -382,7 +382,7 @@ class MochiUpBlock3D(nn.Module):
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -497,6 +497,8 @@ class MochiEncoder3D(nn.Module):
self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
@@ -513,13 +515,13 @@ class MochiEncoder3D(nn.Module):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
down_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -623,13 +625,13 @@ class MochiDecoder3D(nn.Module):
# 1. Mid
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
up_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -868,7 +868,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
@@ -321,9 +321,19 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
width,
prompt_embeds=None,
negative_prompt_embeds=None,
image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
if image is not None and image_embeds is not None:
raise ValueError(
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
" only forward one of the two."
)
if image is None and image_embeds is None:
raise ValueError(
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
)
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -463,6 +473,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -512,6 +523,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `negative_prompt` input argument.
image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
image embeddings are generated from the `image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
width,
prompt_embeds,
negative_prompt_embeds,
image_embeds,
callback_on_step_end_tensor_inputs,
)
@@ -599,7 +617,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
image_embeds = self.encode_image(image, device)
if image_embeds is None:
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
+188 -1
View File
@@ -14,10 +14,11 @@ import tempfile
import time
import unittest
import urllib.parse
from collections import UserDict
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
@@ -48,6 +49,17 @@ from .import_utils import (
from .logging import get_logger
if is_torch_available():
import torch
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
else:
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False
IS_XPU_SYSTEM = False
global_rng = random.Random()
logger = get_logger(__name__)
@@ -1275,3 +1287,178 @@ if is_torch_available():
update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
# Type definition of key used in `Expectations` class.
DeviceProperties = Tuple[Union[str, None], Union[int, None]]
@functools.lru_cache
def get_device_properties() -> DeviceProperties:
"""
Get environment device properties.
"""
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
import torch
major, _ = torch.cuda.get_device_capability()
if IS_ROCM_SYSTEM:
return ("rocm", major)
else:
return ("cuda", major)
elif IS_XPU_SYSTEM:
import torch
# To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
arch = torch.xpu.get_device_capability()["architecture"]
gen_mask = 0x000000FF00000000
gen = (arch & gen_mask) >> 32
return ("xpu", gen)
else:
return (torch_device, None)
if TYPE_CHECKING:
DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
else:
DevicePropertiesUserDict = UserDict
class Expectations(DevicePropertiesUserDict):
def get_expectation(self) -> Any:
"""
Find best matching expectation based on environment device properties.
"""
return self.find_expectation(get_device_properties())
@staticmethod
def is_default(key: DeviceProperties) -> bool:
return all(p is None for p in key)
@staticmethod
def score(key: DeviceProperties, other: DeviceProperties) -> int:
"""
Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using
bits, but documented as int. Rules are as follows:
* Matching `type` gives 8 points.
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
* Matching `major` (compute capability major version) gives 2 points.
* Default expectation (if present) gives 1 points.
"""
(device_type, major) = key
(other_device_type, other_major) = other
score = 0b0
if device_type == other_device_type:
score |= 0b1000
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
score |= 0b100
if major == other_major and other_major is not None:
score |= 0b10
if Expectations.is_default(other):
score |= 0b1
return int(score)
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
"""
Find best matching expectation based on provided device properties.
"""
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
if Expectations.score(key, result_key) == 0:
raise ValueError(f"No matching expectation found for {key}")
return result
def __repr__(self):
return f"{self.data}"
def dynamic_slice_test(func):
"""
Decorator that injects an expected_slice parameter into a test function.
On the first run, it will capture the actual slice output and cache it.
On subsequent runs, it provides the cached slice as the expected slice.
Example:
```python
@dynamic_slice_test
def test_stable_diffusion_ddim(self, expected_slice=None):
# Run the pipeline
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
inputs = self.get_dummy_inputs("cpu")
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
# If expected_slice is provided (from cache), assert against it
if expected_slice is not None:
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# Always return the current slice for caching
return image_slice
```
"""
# Check if the function has the expected_slice parameter
sig = inspect.signature(func)
if "expected_slice" not in sig.parameters:
raise ValueError("The decorated function must have an 'expected_slice' parameter")
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Get the test name from pytest
# pytest sets this environment variable to the current test
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
if test_name:
# Format is: test_file.py::TestClass::test_method (call)
test_name = test_name.split(" ")[0]
else:
# Fallback if not running in pytest
test_name = f"{func.__module__}.{func.__qualname__}"
# Create a unique filename based on hardware details
device_props = get_device_properties()
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
# Setup cache directory
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
os.makedirs(cache_dir, exist_ok=True)
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
# Check for cached expected slice
cached_slice = None
if os.path.exists(cache_path):
try:
cached_slice = np.load(cache_path)
print(f"Using cached slice from {cache_path}")
except Exception as e:
print(f"Error loading cached slice: {e}")
# Run the test function with the expected slice injected
kwargs["expected_slice"] = cached_slice
actual_slice = func(*args, **kwargs)
# If the function returned a slice and there's no cached slice yet, cache it
if actual_slice is not None and cached_slice is None:
# Convert torch tensor to numpy if needed
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
actual_slice_np = actual_slice.detach().cpu().numpy()
else:
actual_slice_np = actual_slice
# Save the slice
try:
np.save(cache_path, actual_slice_np)
print(f"Saved slice to cache: {cache_path}")
except Exception as e:
print(f"Error saving slice to cache: {e}")
return actual_slice
return wrapper
+111
View File
@@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLMochi
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_mochi_config(self):
return {
"in_channels": 15,
"out_channels": 3,
"latent_channels": 4,
"encoder_block_out_channels": (32, 32, 32, 32),
"decoder_block_out_channels": (32, 32, 32, 32),
"layers_per_block": (1, 1, 1, 1, 1),
"act_fn": "silu",
"scaling_factor": 1,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 7
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 7, 16, 16)
@property
def output_shape(self):
return (3, 7, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_mochi_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"MochiDecoder3D",
"MochiDownBlock3D",
"MochiEncoder3D",
"MochiMidBlock3D",
"MochiUpBlock3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
"""
pass
@unittest.skip("Unsupported test.")
def test_model_parallelism(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
RuntimeError: values expected sparse tensor layout but got Strided
"""
pass
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
RuntimeError: values expected sparse tensor layout but got Strided
"""
pass
@unittest.skip("Unsupported test.")
def test_sharded_checkpoints_device_map(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map -
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5!
"""
+237
View File
@@ -24,6 +24,7 @@ import traceback
import unittest
import unittest.mock as mock
import uuid
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
@@ -56,15 +57,20 @@ from diffusers.utils import (
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
floats_tensor,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
require_peft_backend,
require_peft_version_greater,
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_accelerator,
run_test_in_subprocess,
slow,
torch_all_close,
torch_device,
)
@@ -1659,3 +1665,234 @@ class ModelPushToHubTester(unittest.TestCase):
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
@slow
@require_torch_2
@require_torch_accelerator
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
class TestLoraHotSwappingForModel(unittest.TestCase):
"""Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
recompilation.
See
https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
"""
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
super().tearDown()
torch._dynamo.reset()
gc.collect()
backend_empty_cache(torch_device)
def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 8,
"attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
"sample_size": 16,
}
model = UNet2DConditionModel(**init_dict)
return model.to(torch_device)
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
from peft import LoraConfig
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
def get_dummy_input(self):
# from UNet2DConditionModelTests
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
Check that hotswapping works on a small unet.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
# create 2 adapters with different ranks and alphas
dummy_input = self.get_dummy_input()
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
unet = self.get_small_unet()
unet.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
output0_before = unet(**dummy_input)["sample"]
unet.add_adapter(lora_config1, adapter_name="adapter1")
unet.set_adapter("adapter1")
with torch.inference_mode():
output1_before = unet(**dummy_input)["sample"]
# sanity checks:
tol = 5e-3
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del unet
# load the first adapter
unet = self.get_small_unet()
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
unet.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
unet = torch.compile(unet, mode="reduce-overhead")
with torch.inference_mode():
output0_after = unet(**dummy_input)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = unet(**dummy_input)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with self.assertRaisesRegex(ValueError, msg):
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_model(self, rank0, rank1):
self.check_model_hotswap(
do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_linear(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
unet.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
from diffusers.loaders.peft import logger
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self):
# check the error and log
from diffusers.loaders.peft import logger
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with self.assertRaises(RuntimeError): # peft raises RuntimeError
with self.assertLogs(logger=logger, level="ERROR") as cm:
self.check_model_hotswap(
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+33 -1
View File
@@ -20,7 +20,7 @@ import pytest
from diffusers import __version__
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import str_to_bool
from diffusers.utils.testing_utils import Expectations, str_to_bool
# Used to test the hub
@@ -182,6 +182,38 @@ class DeprecateTester(unittest.TestCase):
assert "diffusers/tests/others/test_utils.py" in warning.filename
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):
def test_expectations(self):
expectations = Expectations(
{
(None, None): 1,
("cuda", 8): 2,
("cuda", 7): 3,
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
("xpu", 3): 7,
}
)
def check(value, key):
assert expectations.find_expectation(key) == value
# npu has no matches so should find default expectation
check(1, ("npu", None))
check(7, ("xpu", 3))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))
check(4, ("rocm", None))
check(2, ("cuda", 2))
expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
@@ -15,6 +15,7 @@ from diffusers import (
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
floats_tensor,
numpy_cosine_similarity_distance,
@@ -208,41 +209,115 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.5435,
0.4673,
0.5732,
0.4438,
0.3557,
0.4912,
0.4331,
0.3491,
0.4915,
0.4287,
0.3477,
0.4849,
0.4355,
0.3469,
0.4871,
0.4431,
0.3538,
0.4912,
0.4521,
0.3643,
0.5059,
0.4587,
0.3730,
0.5166,
0.4685,
0.3845,
0.5264,
0.4746,
0.3914,
0.5342,
]
expected_slices = Expectations(
{
("xpu", 3): np.array(
[
0.5117,
0.4421,
0.3852,
0.5044,
0.4219,
0.3262,
0.5024,
0.4329,
0.3276,
0.4978,
0.4412,
0.3355,
0.4983,
0.4338,
0.3279,
0.4893,
0.4241,
0.3129,
0.4875,
0.4253,
0.3030,
0.4961,
0.4267,
0.2988,
0.5029,
0.4255,
0.3054,
0.5132,
0.4248,
0.3222,
]
),
("cuda", 7): np.array(
[
0.5435,
0.4673,
0.5732,
0.4438,
0.3557,
0.4912,
0.4331,
0.3491,
0.4915,
0.4287,
0.347,
0.4849,
0.4355,
0.3469,
0.4871,
0.4431,
0.3538,
0.4912,
0.4521,
0.3643,
0.5059,
0.4587,
0.373,
0.5166,
0.4685,
0.3845,
0.5264,
0.4746,
0.3914,
0.5342,
]
),
("cuda", 8): np.array(
[
0.5146,
0.4385,
0.3826,
0.5098,
0.4150,
0.3218,
0.5142,
0.4312,
0.3298,
0.5127,
0.4431,
0.3411,
0.5171,
0.4424,
0.3374,
0.5088,
0.4348,
0.3242,
0.5073,
0.4380,
0.3174,
0.5132,
0.4397,
0.3115,
0.5132,
0.4343,
0.3118,
0.5219,
0.4328,
0.3256,
]
),
}
)
expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
+265
View File
@@ -17,12 +17,14 @@ import gc
import json
import os
import random
import re
import shutil
import sys
import tempfile
import traceback
import unittest
import unittest.mock as mock
import warnings
import numpy as np
import PIL.Image
@@ -78,6 +80,8 @@ from diffusers.utils.testing_utils import (
require_flax,
require_hf_hub_version_greater,
require_onnxruntime,
require_peft_backend,
require_peft_version_greater,
require_torch_2,
require_torch_accelerator,
require_transformers_version_greater,
@@ -2175,3 +2179,264 @@ class PipelineNightlyTests(unittest.TestCase):
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
@slow
@require_torch_2
@require_torch_accelerator
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
class TestLoraHotSwappingForPipeline(unittest.TestCase):
"""Test that hotswapping does not result in recompilation in a pipeline.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
recompilation.
See
https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
"""
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
super().tearDown()
torch._dynamo.reset()
gc.collect()
backend_empty_cache(torch_device)
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
from peft import LoraConfig
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
def get_lora_state_dicts(self, modules_to_save, adapter_name):
from peft import get_peft_model_state_dict
state_dicts = {}
for module_name, module in modules_to_save.items():
if module is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(
module, adapter_name=adapter_name
)
return state_dicts
def get_dummy_input(self):
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 5,
"guidance_scale": 6.0,
"output_type": "np",
"return_dict": False,
}
return pipeline_inputs
def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
Check that hotswapping works on a pipeline.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
# create 2 adapters with different ranks and alphas
dummy_input = self.get_dummy_input()
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
torch.manual_seed(0)
pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0")
output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
torch.manual_seed(1)
pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1")
pipeline.unet.set_adapter("adapter1")
output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
# sanity check
tol = 1e-3
assert not np.allclose(output0_before, output1_before, atol=tol, rtol=tol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0")
StableDiffusionPipeline.save_lora_weights(
save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
)
lora1_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter1")
StableDiffusionPipeline.save_lora_weights(
save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
)
del pipeline
# load the first adapter
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
pipeline.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
pipeline.load_lora_weights(file_name0)
if do_compile:
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
# sanity check: still same result
assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
# sanity check: since it's the same LoRA, the results should be identical
assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_pipeline(self, rank0, rank1):
self.check_pipeline_hotswap(
do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
pipeline.unet.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
pipeline.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warns(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
from diffusers.loaders.peft import logger
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
pipeline.unet.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
pipeline.unet.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
pipeline.unet.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self):
# check the error and log
from diffusers.loaders.peft import logger
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with self.assertRaises(RuntimeError): # peft raises RuntimeError
with self.assertLogs(logger=logger, level="ERROR") as cm:
self.check_pipeline_hotswap(
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
def test_hotswap_component_not_supported_raises(self):
# right now, not some components don't support hotswapping, e.g. the text_encoder
from peft import LoraConfig
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
lora_config0 = LoraConfig(target_modules=["q_proj"])
lora_config1 = LoraConfig(target_modules=["q_proj"])
pipeline.text_encoder.add_adapter(lora_config0, adapter_name="adapter0")
pipeline.text_encoder.add_adapter(lora_config1, adapter_name="adapter1")
with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
lora0_state_dicts = self.get_lora_state_dicts(
{"text_encoder": pipeline.text_encoder}, adapter_name="adapter0"
)
StableDiffusionPipeline.save_lora_weights(
save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
)
lora1_state_dicts = self.get_lora_state_dicts(
{"text_encoder": pipeline.text_encoder}, adapter_name="adapter1"
)
StableDiffusionPipeline.save_lora_weights(
save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
)
del pipeline
# load the first adapter
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
pipeline.load_lora_weights(file_name0)
msg = re.escape(
"At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`"
)
with self.assertRaisesRegex(ValueError, msg):
pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
+1 -1
View File
@@ -379,7 +379,7 @@ class BnB8bitTrainingTests(Base8bitTests):
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
# Step 4: Check if the gradient is not None
with torch.amp.autocast("cuda", dtype=torch.float16):
with torch.amp.autocast(torch_device, dtype=torch.float16):
out = self.model_8bit(**model_inputs)[0]
out.norm().backward()