Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e13c18b29 | |||
| 71f34fc5a4 | |||
| c51b6bd837 | |||
| fb54499614 | |||
| 723dbdd363 | |||
| fbf61f465b | |||
| 841504bb1a | |||
| fc7a867ae5 | |||
| 5ded26cdc7 | |||
| 506f39af3a | |||
| 8ad68c1393 |
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||

|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
|
||||
|
||||
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Metal Performance Shaders (MPS)
|
||||
|
||||
> [!TIP]
|
||||
> Pipelines with a <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.
|
||||
|
||||
🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. You'll need to have:
|
||||
|
||||
- macOS computer with Apple silicon (M1/M2) hardware
|
||||
@@ -37,7 +40,7 @@ image
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
|
||||
The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -59,6 +62,10 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
|
||||
|
||||
## Troubleshoot
|
||||
|
||||
This section lists some common issues with using the `mps` backend and how to solve them.
|
||||
|
||||
### Attention slicing
|
||||
|
||||
M1/M2 performance is very sensitive to memory pressure. When this occurs, the system automatically swaps if it needs to which significantly degrades performance.
|
||||
|
||||
To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
|
||||
@@ -72,3 +79,7 @@ pipeline.enable_attention_slicing()
|
||||
```
|
||||
|
||||
Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% in computers without universal memory, but we've observed *better performance* in most Apple silicon computers unless you have 64GB of RAM or more.
|
||||
|
||||
### Batch inference
|
||||
|
||||
Generating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching.
|
||||
@@ -194,6 +194,59 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support
|
||||
|
||||
</Tip>
|
||||
|
||||
### Hotswapping LoRA adapters
|
||||
|
||||
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time.
|
||||
|
||||
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter.
|
||||
|
||||
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter, (`default_0` is the default adapter name), to be swapped. If you loaded the first adapter with a different name, use that name instead.
|
||||
|
||||
```python
|
||||
pipe = ...
|
||||
# load adapter 1 as normal
|
||||
pipeline.load_lora_weights(file_name_adapter_1)
|
||||
# generate some images with adapter 1
|
||||
...
|
||||
# now hot swap the 2nd adapter
|
||||
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
|
||||
# generate images with adapter 2
|
||||
```
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Hotswapping is not currently supported for LoRA adapters that target the text encoder.
|
||||
|
||||
</Tip>
|
||||
|
||||
For compiled models, it is often (though not always if the second adapter targets identical LoRA ranks and scales) necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation. Use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and `torch.compile` should be called _after_ loading the first adapter.
|
||||
|
||||
```python
|
||||
pipe = ...
|
||||
# call this extra method
|
||||
pipe.enable_lora_hotswap(target_rank=max_rank)
|
||||
# now load adapter 1
|
||||
pipe.load_lora_weights(file_name_adapter_1)
|
||||
# now compile the unet of the pipeline
|
||||
pipe.unet = torch.compile(pipeline.unet, ...)
|
||||
# generate some images with adapter 1
|
||||
...
|
||||
# now hot swap adapter 2
|
||||
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
|
||||
# generate images with adapter 2
|
||||
```
|
||||
|
||||
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
|
||||
|
||||
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
|
||||
|
||||
<Tip>
|
||||
|
||||
Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Kohya and TheLastBen
|
||||
|
||||
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
accelerate>=0.31.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
peft>=0.11.1
|
||||
sentencepiece
|
||||
@@ -24,7 +24,7 @@ import re
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -228,10 +228,20 @@ def log_validation(
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
autocast_ctx = nullcontext()
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
|
||||
)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
@@ -657,6 +667,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_layers",
|
||||
type=str,
|
||||
@@ -666,6 +677,7 @@ def parse_args(input_args=None):
|
||||
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upcast_before_saving",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
|
||||
"Defaults to precision dtype used for training to save memory"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
|
||||
return text_input_ids
|
||||
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
def _encode_prompt_with_t5(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
max_sequence_length=512,
|
||||
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
def _encode_prompt_with_clip(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -1238,136 +1266,35 @@ def encode_prompt(
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
pooled_prompt_embeds = _get_clip_prompt_embeds(
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
tokenizer=tokenizers[0],
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoders[0].device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
|
||||
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
prompt_embeds = _get_t5_prompt_embeds(
|
||||
prompt_embeds = _encode_prompt_with_t5(
|
||||
text_encoder=text_encoders[1],
|
||||
tokenizer=tokenizers[1],
|
||||
max_sequence_length=max_sequence_length,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device if device is not None else text_encoders[1].device,
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
|
||||
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
|
||||
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
|
||||
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
# create weights for timesteps
|
||||
num_timesteps = 1000
|
||||
|
||||
# generate the multiplier based on cosmap loss weighing
|
||||
# this is only used on linear timesteps for now
|
||||
|
||||
# cosine map weighing is higher in the middle and lower at the ends
|
||||
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
|
||||
# cosmap_weighing = 2 / (math.pi * bot)
|
||||
|
||||
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
|
||||
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
|
||||
# clip at 1e4 (1e6 is too high)
|
||||
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
|
||||
# bring to a mean of 1
|
||||
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
|
||||
|
||||
self.linear_timesteps = timesteps
|
||||
# self.linear_timesteps_weights = cosmap_weighing
|
||||
self.linear_timesteps_weights = sigma_sqrt_weighing
|
||||
|
||||
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
|
||||
pass
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
# Get the indices of the timesteps
|
||||
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
# Get the weights for the timesteps
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
|
||||
return weights
|
||||
|
||||
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
|
||||
sigmas = self.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = self.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
return sigma
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
|
||||
## Add noise according to flow matching.
|
||||
## zt = (1 - texp) * x + texp * z1
|
||||
|
||||
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# timestep needs to be in [0, 1], we store them in [0, 1000]
|
||||
# noisy_sample = (1 - timestep) * latent + timestep * noise
|
||||
t_01 = (timesteps / 1000).to(original_samples.device)
|
||||
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
|
||||
|
||||
# n_dim = original_samples.ndim
|
||||
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
|
||||
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
|
||||
return noisy_model_input
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device, linear=False):
|
||||
if linear:
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
|
||||
self.timesteps = timesteps
|
||||
return timesteps
|
||||
else:
|
||||
# distribute them closer to center. Inference distributes them as a bias toward first
|
||||
# Generate values from 0 to 1
|
||||
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
|
||||
|
||||
# Scale and reverse the values to go from 1000 to 0
|
||||
timesteps = (1 - t) * 1000
|
||||
|
||||
# Sort the timesteps in descending order
|
||||
timesteps, _ = torch.sort(timesteps, descending=True)
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
return timesteps
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.report_to == "wandb" and args.hub_token is not None:
|
||||
raise ValueError(
|
||||
@@ -1499,7 +1426,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
@@ -1619,7 +1546,6 @@ def main(args):
|
||||
target_modules=target_modules,
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
@@ -1727,7 +1653,6 @@ def main(args):
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||
@@ -1737,7 +1662,8 @@ def main(args):
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
if args.mixed_precision == "fp16":
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1747,7 +1673,8 @@ def main(args):
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "shared" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
if args.mixed_precision == "fp16":
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1828,6 +1755,7 @@ def main(args):
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
@@ -2021,6 +1949,7 @@ def main(args):
|
||||
lr_scheduler,
|
||||
)
|
||||
else:
|
||||
print("I SHOULD BE HERE")
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
@@ -2125,7 +2054,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
|
||||
text_encoder_one.train()
|
||||
if args.enable_t5_ti:
|
||||
@@ -2137,6 +2066,11 @@ def main(args):
|
||||
pivoted_tr = True
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
if not freeze_text_encoder:
|
||||
models_to_accumulate.extend([text_encoder_one])
|
||||
if args.enable_t5_ti:
|
||||
models_to_accumulate.extend([text_encoder_two])
|
||||
if pivoted_te:
|
||||
# stopping optimization of text_encoder params
|
||||
optimizer.param_groups[te_idx]["lr"] = 0.0
|
||||
@@ -2145,7 +2079,7 @@ def main(args):
|
||||
logger.info(f"PIVOT TRANSFORMER {epoch}")
|
||||
optimizer.param_groups[0]["lr"] = 0.0
|
||||
|
||||
with accelerator.accumulate(transformer):
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
prompts = batch["prompts"]
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
@@ -2189,7 +2123,7 @@ def main(args):
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
@@ -2228,7 +2162,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -2288,16 +2222,26 @@ def main(args):
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
if not freeze_text_encoder:
|
||||
if args.train_text_encoder:
|
||||
if args.train_text_encoder: # text encoder tuning
|
||||
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
|
||||
elif pure_textual_inversion:
|
||||
params_to_clip = itertools.chain(
|
||||
text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
if args.enable_t5_ti:
|
||||
params_to_clip = itertools.chain(
|
||||
text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(text_encoder_one.parameters())
|
||||
else:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
|
||||
)
|
||||
if args.enable_t5_ti:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(),
|
||||
text_encoder_one.parameters(),
|
||||
text_encoder_two.parameters(),
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(
|
||||
transformer.parameters(), text_encoder_one.parameters()
|
||||
)
|
||||
else:
|
||||
params_to_clip = itertools.chain(transformer.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
@@ -2339,6 +2283,10 @@ def main(args):
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
|
||||
)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -2351,14 +2299,16 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
# create pipeline
|
||||
if freeze_text_encoder:
|
||||
if freeze_text_encoder: # no text encoder one, two optimizations
|
||||
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
text_encoder=unwrap_model(text_encoder_one),
|
||||
text_encoder_2=unwrap_model(text_encoder_two),
|
||||
transformer=unwrap_model(transformer),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -2372,21 +2322,21 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
if freeze_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
elif args.train_text_encoder:
|
||||
del text_encoder_two
|
||||
free_memory()
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
transformer = unwrap_model(transformer)
|
||||
transformer = transformer.to(weight_dtype)
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
if args.train_text_encoder:
|
||||
@@ -2428,8 +2378,8 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
save_model_card(
|
||||
@@ -2452,6 +2402,7 @@ def main(args):
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
|
||||
@@ -927,17 +927,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -962,8 +967,14 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -958,7 +965,12 @@ def encode_prompt(
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
device = device if device is not None else text_encoders[1].device
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
@@ -1590,7 +1602,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1716,9 +1728,9 @@ def main(args):
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
|
||||
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
|
||||
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
|
||||
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
|
||||
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
|
||||
@@ -177,16 +177,25 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
|
||||
)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
@@ -203,8 +212,7 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
if hasattr(text_encoder, "module"):
|
||||
dtype = text_encoder.module.dtype
|
||||
else:
|
||||
dtype = text_encoder.dtype
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -994,7 +1009,11 @@ def encode_prompt(
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
if hasattr(text_encoders[0], "module"):
|
||||
dtype = text_encoders[0].module.dtype
|
||||
else:
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
@@ -1619,7 +1638,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
@@ -1710,7 +1729,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# handle guidance
|
||||
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
@@ -1828,9 +1847,9 @@ def main(args):
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
text_encoder=unwrap_model(text_encoder_one),
|
||||
text_encoder_2=unwrap_model(text_encoder_two),
|
||||
transformer=unwrap_model(transformer),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
|
||||
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
@@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset):
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
train_resize = transforms.Resize(size, interpolation=interpolation)
|
||||
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
|
||||
@@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -341,6 +342,10 @@ def _load_lora_into_text_encoder(
|
||||
# their prefixes.
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
@@ -908,3 +913,23 @@ class LoraBaseMixin:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
|
||||
@@ -79,10 +79,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
"""Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
|
||||
All kwargs are forwarded to `self.lora_state_dict`.
|
||||
@@ -105,6 +108,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -135,6 +161,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -146,6 +173,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -265,7 +293,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls,
|
||||
state_dict,
|
||||
network_alphas,
|
||||
unet,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
@@ -287,6 +322,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -307,6 +365,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -320,6 +379,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -345,6 +405,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -356,6 +439,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -700,7 +784,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls,
|
||||
state_dict,
|
||||
network_alphas,
|
||||
unet,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
@@ -722,6 +813,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -742,6 +856,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -756,6 +871,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -781,6 +897,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -792,6 +931,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1035,7 +1175,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
@@ -1058,6 +1202,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -1087,6 +1254,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -1097,6 +1265,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -1107,11 +1276,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -1129,6 +1299,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -1143,6 +1336,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1157,6 +1351,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1182,6 +1377,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -1193,6 +1411,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1476,7 +1695,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -1501,6 +1724,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
|
||||
adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
|
||||
additional method before loading the adapter:
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -1569,6 +1812,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
if len(transformer_norm_state_dict) > 0:
|
||||
@@ -1587,11 +1831,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls,
|
||||
state_dict,
|
||||
network_alphas,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -1613,6 +1865,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
@@ -1627,6 +1902,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1695,6 +1971,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1720,6 +1997,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -1731,6 +2031,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2141,7 +2442,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls,
|
||||
state_dict,
|
||||
network_alphas,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -2163,6 +2471,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
@@ -2177,6 +2508,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2191,6 +2523,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -2216,6 +2549,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -2227,6 +2583,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2443,7 +2800,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -2461,6 +2818,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -2475,6 +2855,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2750,7 +3131,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -2768,6 +3149,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -2782,6 +3186,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3059,7 +3464,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -3077,6 +3482,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3091,6 +3519,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3368,7 +3797,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -3386,6 +3815,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3400,6 +3852,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3680,7 +4133,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -3698,6 +4151,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3712,6 +4188,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3993,7 +4470,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -4011,6 +4488,29 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4025,6 +4525,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -4333,7 +4834,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -4351,6 +4852,29 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4365,6 +4889,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -4642,7 +5167,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
@@ -4660,6 +5185,29 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4674,6 +5222,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -16,7 +16,7 @@ import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -128,6 +128,8 @@ class PeftAdapterMixin:
|
||||
"""
|
||||
|
||||
_hf_peft_config_loaded = False
|
||||
# kwargs for prepare_model_for_compiled_hotswap, if required
|
||||
_prepare_lora_hotswap_kwargs: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
@@ -145,7 +147,9 @@ class PeftAdapterMixin:
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
def load_lora_adapter(
|
||||
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
|
||||
):
|
||||
r"""
|
||||
Loads a LoRA adapter into the underlying model.
|
||||
|
||||
@@ -189,6 +193,29 @@ class PeftAdapterMixin:
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
@@ -239,10 +266,15 @@ class PeftAdapterMixin:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
|
||||
)
|
||||
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
|
||||
raise ValueError(
|
||||
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
|
||||
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
|
||||
)
|
||||
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
@@ -302,11 +334,68 @@ class PeftAdapterMixin:
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
|
||||
if is_peft_version(">", "0.14.0"):
|
||||
from peft.utils.hotswap import (
|
||||
check_hotswap_configs_compatible,
|
||||
hotswap_adapter_from_state_dict,
|
||||
prepare_model_for_compiled_hotswap,
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
|
||||
"from source."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
if hotswap:
|
||||
|
||||
def map_state_dict_for_hotswap(sd):
|
||||
# For hotswapping, we need the adapter name to be present in the state dict keys
|
||||
new_sd = {}
|
||||
for k, v in sd.items():
|
||||
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
|
||||
k = k[: -len(".weight")] + f".{adapter_name}.weight"
|
||||
elif k.endswith("lora_B.bias"): # lora_bias=True option
|
||||
k = k[: -len(".bias")] + f".{adapter_name}.bias"
|
||||
new_sd[k] = v
|
||||
return new_sd
|
||||
|
||||
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
|
||||
# we should also delete the `peft_config` associated to the `adapter_name`.
|
||||
try:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
if hotswap:
|
||||
state_dict = map_state_dict_for_hotswap(state_dict)
|
||||
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
|
||||
try:
|
||||
hotswap_adapter_from_state_dict(
|
||||
model=self,
|
||||
state_dict=state_dict,
|
||||
adapter_name=adapter_name,
|
||||
config=lora_config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
|
||||
raise
|
||||
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
|
||||
# it to None
|
||||
incompatible_keys = None
|
||||
else:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
if self._prepare_lora_hotswap_kwargs is not None:
|
||||
# For hotswapping of compiled models or adapters with different ranks.
|
||||
# If the user called enable_lora_hotswap, we need to ensure it is called:
|
||||
# - after the first adapter was loaded
|
||||
# - before the model is compiled and the 2nd adapter is being hotswapped in
|
||||
# Therefore, it needs to be called here
|
||||
prepare_model_for_compiled_hotswap(
|
||||
self, config=lora_config, **self._prepare_lora_hotswap_kwargs
|
||||
)
|
||||
# We only want to call prepare_model_for_compiled_hotswap once
|
||||
self._prepare_lora_hotswap_kwargs = None
|
||||
|
||||
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
|
||||
if not self._hf_peft_config_loaded:
|
||||
self._hf_peft_config_loaded = True
|
||||
@@ -769,3 +858,36 @@ class PeftAdapterMixin:
|
||||
# Pop also the corresponding adapter from the config
|
||||
if hasattr(self, "peft_config"):
|
||||
self.peft_config.pop(adapter_name, None)
|
||||
|
||||
def enable_lora_hotswap(
|
||||
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
|
||||
) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`, *optional*, defaults to `128`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
if getattr(self, "peft_config", {}):
|
||||
if check_compiled == "error":
|
||||
raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
elif check_compiled == "warn":
|
||||
logger.warning(
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
elif check_compiled != "ignore":
|
||||
raise ValueError(
|
||||
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
|
||||
)
|
||||
|
||||
self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
|
||||
|
||||
@@ -210,7 +210,7 @@ class MochiDownBlock3D(nn.Module):
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
hidden_states,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -306,7 +306,7 @@ class MochiMidBlock3D(nn.Module):
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
resnet, hidden_states, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -382,7 +382,7 @@ class MochiUpBlock3D(nn.Module):
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
resnet,
|
||||
hidden_states,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -497,6 +497,8 @@ class MochiEncoder3D(nn.Module):
|
||||
self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
|
||||
self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
@@ -513,13 +515,13 @@ class MochiEncoder3D(nn.Module):
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
|
||||
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
self.block_in, hidden_states, conv_cache.get("block_in")
|
||||
)
|
||||
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
down_block, hidden_states, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache["block_in"] = self.block_in(
|
||||
@@ -623,13 +625,13 @@ class MochiDecoder3D(nn.Module):
|
||||
# 1. Mid
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
|
||||
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
|
||||
self.block_in, hidden_states, conv_cache.get("block_in")
|
||||
)
|
||||
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
|
||||
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
|
||||
up_block, hidden_states, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache["block_in"] = self.block_in(
|
||||
|
||||
@@ -868,7 +868,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
else:
|
||||
|
||||
@@ -321,9 +321,19 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
if image is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
if image is None and image_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
|
||||
)
|
||||
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
@@ -463,6 +473,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -512,6 +523,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `negative_prompt` input argument.
|
||||
image_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
|
||||
image embeddings are generated from the `image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -556,6 +573,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
@@ -599,7 +617,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
image_embeds = self.encode_image(image, device)
|
||||
if image_embeds is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
|
||||
@@ -14,10 +14,11 @@ import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from collections import UserDict
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -48,6 +49,17 @@ from .import_utils import (
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
IS_ROCM_SYSTEM = torch.version.hip is not None
|
||||
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
||||
IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
|
||||
else:
|
||||
IS_ROCM_SYSTEM = False
|
||||
IS_CUDA_SYSTEM = False
|
||||
IS_XPU_SYSTEM = False
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -1275,3 +1287,178 @@ if is_torch_available():
|
||||
update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
|
||||
update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
|
||||
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
|
||||
|
||||
|
||||
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
|
||||
|
||||
# Type definition of key used in `Expectations` class.
|
||||
DeviceProperties = Tuple[Union[str, None], Union[int, None]]
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_device_properties() -> DeviceProperties:
|
||||
"""
|
||||
Get environment device properties.
|
||||
"""
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
import torch
|
||||
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if IS_ROCM_SYSTEM:
|
||||
return ("rocm", major)
|
||||
else:
|
||||
return ("cuda", major)
|
||||
elif IS_XPU_SYSTEM:
|
||||
import torch
|
||||
|
||||
# To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
|
||||
arch = torch.xpu.get_device_capability()["architecture"]
|
||||
gen_mask = 0x000000FF00000000
|
||||
gen = (arch & gen_mask) >> 32
|
||||
return ("xpu", gen)
|
||||
else:
|
||||
return (torch_device, None)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
|
||||
else:
|
||||
DevicePropertiesUserDict = UserDict
|
||||
|
||||
|
||||
class Expectations(DevicePropertiesUserDict):
|
||||
def get_expectation(self) -> Any:
|
||||
"""
|
||||
Find best matching expectation based on environment device properties.
|
||||
"""
|
||||
return self.find_expectation(get_device_properties())
|
||||
|
||||
@staticmethod
|
||||
def is_default(key: DeviceProperties) -> bool:
|
||||
return all(p is None for p in key)
|
||||
|
||||
@staticmethod
|
||||
def score(key: DeviceProperties, other: DeviceProperties) -> int:
|
||||
"""
|
||||
Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using
|
||||
bits, but documented as int. Rules are as follows:
|
||||
* Matching `type` gives 8 points.
|
||||
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
|
||||
* Matching `major` (compute capability major version) gives 2 points.
|
||||
* Default expectation (if present) gives 1 points.
|
||||
"""
|
||||
(device_type, major) = key
|
||||
(other_device_type, other_major) = other
|
||||
|
||||
score = 0b0
|
||||
if device_type == other_device_type:
|
||||
score |= 0b1000
|
||||
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
|
||||
score |= 0b100
|
||||
|
||||
if major == other_major and other_major is not None:
|
||||
score |= 0b10
|
||||
|
||||
if Expectations.is_default(other):
|
||||
score |= 0b1
|
||||
|
||||
return int(score)
|
||||
|
||||
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
|
||||
"""
|
||||
Find best matching expectation based on provided device properties.
|
||||
"""
|
||||
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
|
||||
|
||||
if Expectations.score(key, result_key) == 0:
|
||||
raise ValueError(f"No matching expectation found for {key}")
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.data}"
|
||||
|
||||
|
||||
def dynamic_slice_test(func):
|
||||
"""
|
||||
Decorator that injects an expected_slice parameter into a test function.
|
||||
|
||||
On the first run, it will capture the actual slice output and cache it.
|
||||
On subsequent runs, it provides the cached slice as the expected slice.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@dynamic_slice_test
|
||||
def test_stable_diffusion_ddim(self, expected_slice=None):
|
||||
# Run the pipeline
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
inputs = self.get_dummy_inputs("cpu")
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# If expected_slice is provided (from cache), assert against it
|
||||
if expected_slice is not None:
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
# Always return the current slice for caching
|
||||
return image_slice
|
||||
```
|
||||
"""
|
||||
# Check if the function has the expected_slice parameter
|
||||
sig = inspect.signature(func)
|
||||
if "expected_slice" not in sig.parameters:
|
||||
raise ValueError("The decorated function must have an 'expected_slice' parameter")
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the test name from pytest
|
||||
# pytest sets this environment variable to the current test
|
||||
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
|
||||
if test_name:
|
||||
# Format is: test_file.py::TestClass::test_method (call)
|
||||
test_name = test_name.split(" ")[0]
|
||||
else:
|
||||
# Fallback if not running in pytest
|
||||
test_name = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
# Create a unique filename based on hardware details
|
||||
device_props = get_device_properties()
|
||||
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
|
||||
|
||||
# Setup cache directory
|
||||
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
|
||||
|
||||
# Check for cached expected slice
|
||||
cached_slice = None
|
||||
if os.path.exists(cache_path):
|
||||
try:
|
||||
cached_slice = np.load(cache_path)
|
||||
print(f"Using cached slice from {cache_path}")
|
||||
except Exception as e:
|
||||
print(f"Error loading cached slice: {e}")
|
||||
|
||||
# Run the test function with the expected slice injected
|
||||
kwargs["expected_slice"] = cached_slice
|
||||
actual_slice = func(*args, **kwargs)
|
||||
|
||||
# If the function returned a slice and there's no cached slice yet, cache it
|
||||
if actual_slice is not None and cached_slice is None:
|
||||
# Convert torch tensor to numpy if needed
|
||||
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
|
||||
actual_slice_np = actual_slice.detach().cpu().numpy()
|
||||
else:
|
||||
actual_slice_np = actual_slice
|
||||
|
||||
# Save the slice
|
||||
try:
|
||||
np.save(cache_path, actual_slice_np)
|
||||
print(f"Saved slice to cache: {cache_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving slice to cache: {e}")
|
||||
|
||||
return actual_slice
|
||||
|
||||
return wrapper
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLMochi
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLMochi
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_mochi_config(self):
|
||||
return {
|
||||
"in_channels": 15,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"encoder_block_out_channels": (32, 32, 32, 32),
|
||||
"decoder_block_out_channels": (32, 32, 32, 32),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"act_fn": "silu",
|
||||
"scaling_factor": 1,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 7
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 7, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 7, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_mochi_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"MochiDecoder3D",
|
||||
"MochiDownBlock3D",
|
||||
"MochiEncoder3D",
|
||||
"MochiMidBlock3D",
|
||||
"MochiUpBlock3D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
"""
|
||||
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
|
||||
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
|
||||
"""
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_model_parallelism(self):
|
||||
"""
|
||||
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
|
||||
RuntimeError: values expected sparse tensor layout but got Strided
|
||||
"""
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
"""
|
||||
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
|
||||
RuntimeError: values expected sparse tensor layout but got Strided
|
||||
"""
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_sharded_checkpoints_device_map(self):
|
||||
"""
|
||||
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map -
|
||||
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5!
|
||||
"""
|
||||
@@ -24,6 +24,7 @@ import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -56,15 +57,20 @@ from diffusers.utils import (
|
||||
from diffusers.utils.hub_utils import _add_variant
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
get_python_version,
|
||||
is_torch_compile,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
@@ -1659,3 +1665,234 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_2
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@is_torch_compile
|
||||
class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
"""Test that hotswapping does not result in recompilation on the model directly.
|
||||
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
|
||||
tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
|
||||
recompilation.
|
||||
|
||||
See
|
||||
https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||
for the analogous PEFT test.
|
||||
|
||||
"""
|
||||
|
||||
def tearDown(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
super().tearDown()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_small_unet(self):
|
||||
# from diffusers UNet2DConditionModelTests
|
||||
torch.manual_seed(0)
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
|
||||
"cross_attention_dim": 8,
|
||||
"attention_head_dim": 2,
|
||||
"out_channels": 4,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
model = UNet2DConditionModel(**init_dict)
|
||||
return model.to(torch_device)
|
||||
|
||||
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
# from diffusers test_models_unet_2d_condition.py
|
||||
from peft import LoraConfig
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return unet_lora_config
|
||||
|
||||
def get_dummy_input(self):
|
||||
# from UNet2DConditionModelTests
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
|
||||
"""
|
||||
Check that hotswapping works on a small unet.
|
||||
|
||||
Steps:
|
||||
- create 2 LoRA adapters and save them
|
||||
- load the first adapter
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
# create 2 adapters with different ranks and alphas
|
||||
dummy_input = self.get_dummy_input()
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
with torch.inference_mode():
|
||||
output0_before = unet(**dummy_input)["sample"]
|
||||
|
||||
unet.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
unet.set_adapter("adapter1")
|
||||
with torch.inference_mode():
|
||||
output1_before = unet(**dummy_input)["sample"]
|
||||
|
||||
# sanity checks:
|
||||
tol = 5e-3
|
||||
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
|
||||
assert not (output0_before == 0).all()
|
||||
assert not (output1_before == 0).all()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
# save the adapter checkpoints
|
||||
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del unet
|
||||
|
||||
# load the first adapter
|
||||
unet = self.get_small_unet()
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
unet.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
|
||||
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
unet = torch.compile(unet, mode="reduce-overhead")
|
||||
|
||||
with torch.inference_mode():
|
||||
output0_after = unet(**dummy_input)["sample"]
|
||||
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
output1_after = unet(**dummy_input)["sample"]
|
||||
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_model(self, rank0, rank1):
|
||||
self.check_model_hotswap(
|
||||
do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_linear(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
unet.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with self.assertLogs(logger=logger, level="WARNING") as cm:
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in log for log in cm.output)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
|
||||
# check possibility to ignore the error/warning
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always") # Capture all warnings
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self):
|
||||
# check the error and log
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with self.assertRaises(RuntimeError): # peft raises RuntimeError
|
||||
with self.assertLogs(logger=logger, level="ERROR") as cm:
|
||||
self.check_model_hotswap(
|
||||
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
|
||||
|
||||
@@ -20,7 +20,7 @@ import pytest
|
||||
|
||||
from diffusers import __version__
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import str_to_bool
|
||||
from diffusers.utils.testing_utils import Expectations, str_to_bool
|
||||
|
||||
|
||||
# Used to test the hub
|
||||
@@ -182,6 +182,38 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
def test_expectations(self):
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): 1,
|
||||
("cuda", 8): 2,
|
||||
("cuda", 7): 3,
|
||||
("rocm", 8): 4,
|
||||
("rocm", None): 5,
|
||||
("cpu", None): 6,
|
||||
("xpu", 3): 7,
|
||||
}
|
||||
)
|
||||
|
||||
def check(value, key):
|
||||
assert expectations.find_expectation(key) == value
|
||||
|
||||
# npu has no matches so should find default expectation
|
||||
check(1, ("npu", None))
|
||||
check(7, ("xpu", 3))
|
||||
check(2, ("cuda", 8))
|
||||
check(3, ("cuda", 7))
|
||||
check(4, ("rocm", 9))
|
||||
check(4, ("rocm", None))
|
||||
check(2, ("cuda", 2))
|
||||
|
||||
expectations = Expectations({("cuda", 8): 1})
|
||||
with self.assertRaises(ValueError):
|
||||
expectations.find_expectation(("xpu", None))
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
try:
|
||||
value = os.environ[key]
|
||||
|
||||
@@ -15,6 +15,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
@@ -208,41 +209,115 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5435,
|
||||
0.4673,
|
||||
0.5732,
|
||||
0.4438,
|
||||
0.3557,
|
||||
0.4912,
|
||||
0.4331,
|
||||
0.3491,
|
||||
0.4915,
|
||||
0.4287,
|
||||
0.3477,
|
||||
0.4849,
|
||||
0.4355,
|
||||
0.3469,
|
||||
0.4871,
|
||||
0.4431,
|
||||
0.3538,
|
||||
0.4912,
|
||||
0.4521,
|
||||
0.3643,
|
||||
0.5059,
|
||||
0.4587,
|
||||
0.3730,
|
||||
0.5166,
|
||||
0.4685,
|
||||
0.3845,
|
||||
0.5264,
|
||||
0.4746,
|
||||
0.3914,
|
||||
0.5342,
|
||||
]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): np.array(
|
||||
[
|
||||
0.5117,
|
||||
0.4421,
|
||||
0.3852,
|
||||
0.5044,
|
||||
0.4219,
|
||||
0.3262,
|
||||
0.5024,
|
||||
0.4329,
|
||||
0.3276,
|
||||
0.4978,
|
||||
0.4412,
|
||||
0.3355,
|
||||
0.4983,
|
||||
0.4338,
|
||||
0.3279,
|
||||
0.4893,
|
||||
0.4241,
|
||||
0.3129,
|
||||
0.4875,
|
||||
0.4253,
|
||||
0.3030,
|
||||
0.4961,
|
||||
0.4267,
|
||||
0.2988,
|
||||
0.5029,
|
||||
0.4255,
|
||||
0.3054,
|
||||
0.5132,
|
||||
0.4248,
|
||||
0.3222,
|
||||
]
|
||||
),
|
||||
("cuda", 7): np.array(
|
||||
[
|
||||
0.5435,
|
||||
0.4673,
|
||||
0.5732,
|
||||
0.4438,
|
||||
0.3557,
|
||||
0.4912,
|
||||
0.4331,
|
||||
0.3491,
|
||||
0.4915,
|
||||
0.4287,
|
||||
0.347,
|
||||
0.4849,
|
||||
0.4355,
|
||||
0.3469,
|
||||
0.4871,
|
||||
0.4431,
|
||||
0.3538,
|
||||
0.4912,
|
||||
0.4521,
|
||||
0.3643,
|
||||
0.5059,
|
||||
0.4587,
|
||||
0.373,
|
||||
0.5166,
|
||||
0.4685,
|
||||
0.3845,
|
||||
0.5264,
|
||||
0.4746,
|
||||
0.3914,
|
||||
0.5342,
|
||||
]
|
||||
),
|
||||
("cuda", 8): np.array(
|
||||
[
|
||||
0.5146,
|
||||
0.4385,
|
||||
0.3826,
|
||||
0.5098,
|
||||
0.4150,
|
||||
0.3218,
|
||||
0.5142,
|
||||
0.4312,
|
||||
0.3298,
|
||||
0.5127,
|
||||
0.4431,
|
||||
0.3411,
|
||||
0.5171,
|
||||
0.4424,
|
||||
0.3374,
|
||||
0.5088,
|
||||
0.4348,
|
||||
0.3242,
|
||||
0.5073,
|
||||
0.4380,
|
||||
0.3174,
|
||||
0.5132,
|
||||
0.4397,
|
||||
0.3115,
|
||||
0.5132,
|
||||
0.4343,
|
||||
0.3118,
|
||||
0.5219,
|
||||
0.4328,
|
||||
0.3256,
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
|
||||
|
||||
@@ -17,12 +17,14 @@ import gc
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -78,6 +80,8 @@ from diffusers.utils.testing_utils import (
|
||||
require_flax,
|
||||
require_hf_hub_version_greater,
|
||||
require_onnxruntime,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
@@ -2175,3 +2179,264 @@ class PipelineNightlyTests(unittest.TestCase):
|
||||
|
||||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_2
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@is_torch_compile
|
||||
class TestLoraHotSwappingForPipeline(unittest.TestCase):
|
||||
"""Test that hotswapping does not result in recompilation in a pipeline.
|
||||
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
|
||||
tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
|
||||
recompilation.
|
||||
|
||||
See
|
||||
https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||
for the analogous PEFT test.
|
||||
|
||||
"""
|
||||
|
||||
def tearDown(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
super().tearDown()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
# from diffusers test_models_unet_2d_condition.py
|
||||
from peft import LoraConfig
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return unet_lora_config
|
||||
|
||||
def get_lora_state_dicts(self, modules_to_save, adapter_name):
|
||||
from peft import get_peft_model_state_dict
|
||||
|
||||
state_dicts = {}
|
||||
for module_name, module in modules_to_save.items():
|
||||
if module is not None:
|
||||
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(
|
||||
module, adapter_name=adapter_name
|
||||
)
|
||||
return state_dicts
|
||||
|
||||
def get_dummy_input(self):
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 5,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"return_dict": False,
|
||||
}
|
||||
return pipeline_inputs
|
||||
|
||||
def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
|
||||
"""
|
||||
Check that hotswapping works on a pipeline.
|
||||
|
||||
Steps:
|
||||
- create 2 LoRA adapters and save them
|
||||
- load the first adapter
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
# create 2 adapters with different ranks and alphas
|
||||
dummy_input = self.get_dummy_input()
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
torch.manual_seed(0)
|
||||
pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
|
||||
|
||||
torch.manual_seed(1)
|
||||
pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
pipeline.unet.set_adapter("adapter1")
|
||||
output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# sanity check
|
||||
tol = 1e-3
|
||||
assert not np.allclose(output0_before, output1_before, atol=tol, rtol=tol)
|
||||
assert not (output0_before == 0).all()
|
||||
assert not (output1_before == 0).all()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
# save the adapter checkpoints
|
||||
lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0")
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
|
||||
)
|
||||
lora1_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter1")
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
|
||||
)
|
||||
del pipeline
|
||||
|
||||
# load the first adapter
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
|
||||
|
||||
pipeline.load_lora_weights(file_name0)
|
||||
if do_compile:
|
||||
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
|
||||
|
||||
output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# sanity check: still same result
|
||||
assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
|
||||
output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# sanity check: since it's the same LoRA, the results should be identical
|
||||
assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_pipeline(self, rank0, rank1):
|
||||
self.check_pipeline_hotswap(
|
||||
do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
pipeline.unet.add_adapter(lora_config)
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
pipeline.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warns(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
pipeline.unet.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with self.assertLogs(logger=logger, level="WARNING") as cm:
|
||||
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in log for log in cm.output)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
|
||||
# check possibility to ignore the error/warning
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
pipeline.unet.add_adapter(lora_config)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always") # Capture all warnings
|
||||
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
pipeline.unet.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self):
|
||||
# check the error and log
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with self.assertRaises(RuntimeError): # peft raises RuntimeError
|
||||
with self.assertLogs(logger=logger, level="ERROR") as cm:
|
||||
self.check_pipeline_hotswap(
|
||||
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
|
||||
|
||||
def test_hotswap_component_not_supported_raises(self):
|
||||
# right now, not some components don't support hotswapping, e.g. the text_encoder
|
||||
from peft import LoraConfig
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
lora_config0 = LoraConfig(target_modules=["q_proj"])
|
||||
lora_config1 = LoraConfig(target_modules=["q_proj"])
|
||||
|
||||
pipeline.text_encoder.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
pipeline.text_encoder.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
# save the adapter checkpoints
|
||||
lora0_state_dicts = self.get_lora_state_dicts(
|
||||
{"text_encoder": pipeline.text_encoder}, adapter_name="adapter0"
|
||||
)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
|
||||
)
|
||||
lora1_state_dicts = self.get_lora_state_dicts(
|
||||
{"text_encoder": pipeline.text_encoder}, adapter_name="adapter1"
|
||||
)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
|
||||
)
|
||||
del pipeline
|
||||
|
||||
# load the first adapter
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
|
||||
|
||||
pipeline.load_lora_weights(file_name0)
|
||||
msg = re.escape(
|
||||
"At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`"
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
|
||||
|
||||
@@ -379,7 +379,7 @@ class BnB8bitTrainingTests(Base8bitTests):
|
||||
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
|
||||
|
||||
# Step 4: Check if the gradient is not None
|
||||
with torch.amp.autocast("cuda", dtype=torch.float16):
|
||||
with torch.amp.autocast(torch_device, dtype=torch.float16):
|
||||
out = self.model_8bit(**model_inputs)[0]
|
||||
out.norm().backward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user