Compare commits

..

9 Commits

Author SHA1 Message Date
sayakpaul b56112db6e use backend-agnostic cache and pass devide. 2025-04-09 11:48:26 +05:30
Sayak Paul f50de75b69 Merge branch 'main' into fix-sd3-controlnet-validation 2025-04-09 11:14:43 +05:30
sayakpaul 579bb5f418 fix: SD3 ControlNet validation so that it runs on a A100. 2025-04-09 11:13:43 +05:30
Sayak Paul 6bfacf0418 [LoRA] support more comyui loras for Flux 🚨 (#10985)
* support more comyui loras.

* fix

* fixes

* revert changes in LoRA base.

* no position_embedding

* 🚨 introduce a breaking change to let peft handle module ambiguity

* styling

* remove position embeddings.

* improvements.

* style

* make info instead of NotImplementedError

* Update src/diffusers/loaders/peft.py

Co-authored-by: hlky <hlky@hlky.ac>

* add example.

* robust checks

* updates

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-04-09 09:17:05 +05:30
Sayak Paul f685981ed0 [docs] minor updates to dtype map docs. (#11237)
minor updates to dtype map docs.
2025-04-09 08:38:17 +05:30
Sayak Paul b924251dd8 minor update to sana sprint docs. (#11236) 2025-04-09 08:17:45 +05:30
Sayak Paul 1a04812439 [bistandbytes] improve replacement warnings for bnb (#11132)
* improve replacement warnings for bnb

* updates to docs.
2025-04-08 21:18:34 +05:30
Sayak Paul 4b27c4a494 [feat] implement record_stream when using CUDA streams during group offloading (#11081)
* implement record_stream for better performance.

* fix

* style.

* merge #11097

* Update src/diffusers/hooks/group_offloading.py

Co-authored-by: Aryan <aryan@huggingface.co>

* fixes

* docstring.

* remaining todos in low_cpu_mem_usage

* tests

* updates to docs.

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-04-08 21:17:49 +05:30
hlky 5d49b3e83b Flux quantized with lora (#10990)
* Flux quantized with lora

* fix

* changes

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

* enable model cpu offload()

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky <hlky@hlky.ac>

* update

* Apply suggestions from code review

* update

* add peft as an additional dependency for gguf

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-04-08 21:17:03 +05:30
19 changed files with 552 additions and 164 deletions
+1 -1
View File
@@ -417,7 +417,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
additional_deps: ["peft"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
+1 -1
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaSprintPipeline
# SANA-Sprint
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+4
View File
@@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# Uncomment the following to also allow recording the current streams.
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
@@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
+1 -1
View File
@@ -105,7 +105,7 @@ import torch
pipe = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
+44 -3
View File
@@ -17,6 +17,7 @@ import argparse
import contextlib
import copy
import functools
import gc
import logging
import math
import os
@@ -52,6 +53,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.testing_utils import backend_empty_cache
from diffusers.utils.torch_utils import is_compiled_module
@@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
controlnet=None,
safety_checker=None,
transformer=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)
with torch.no_grad():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
validation_prompts,
prompt_2=None,
prompt_3=None,
)
del pipeline
gc.collect()
backend_empty_cache(accelerator.device.type)
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
safety_checker=None,
text_encoder=None,
text_encoder_2=None,
text_encoder_3=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.enable_model_cpu_offload(device=accelerator.device.type)
pipeline.set_progress_bar_config(disable=True)
image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
for i, validation_image in enumerate(validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_prompt = validation_prompts[i]
images = []
for _ in range(args.num_validation_images):
with inference_ctx:
image = pipeline(
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
prompt_embeds=prompt_embeds[i].unsqueeze(0),
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
control_image=validation_image,
num_inference_steps=20,
generator=generator,
).images[0]
images.append(image)
@@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
trust_remote_code=True,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+62 -4
View File
@@ -56,6 +56,7 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
@@ -68,11 +69,14 @@ class ModuleGroup:
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -112,6 +116,8 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -122,14 +128,22 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
else:
for group_module in self.modules:
@@ -143,11 +157,14 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
if not self.record_stream:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -331,6 +348,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -378,6 +396,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
@@ -417,11 +439,24 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
+218 -11
View File
@@ -13,15 +13,22 @@
# limitations under the License.
import re
from typing import List
import torch
from ..utils import is_peft_version, logging
from ..utils import is_peft_version, logging, state_dict_all_zero
logger = logging.get_logger(__name__)
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
return diffusers_name
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
# TODO: alphas.
def assign_remaining_weights(assignments, source):
for lora_key in ["lora_A", "lora_B"]:
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
for target_fmt, source_fmt, transform in assignments:
target_key = target_fmt.format(lora_key=lora_key)
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
value = source.pop(source_key)
if transform:
value = transform(value)
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("final_layer" in k for k in sds_sd):
# Notice the swap in processing for "final_layer".
assign_remaining_weights(
[
(
"norm_out.linear.{lora_key}.weight",
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
swap_scale_shift,
),
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
],
sds_sd,
)
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te") for k in remaining_keys):
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict
# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
# ComfyUI.
if not has_mixture:
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
has_position_embedding = any("position_embedding" in k for k in state_dict)
if has_position_embedding:
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
if zero_status_pe:
logger.info(
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"The state_dict has position_embedding LoRA params and we currently do not support them. "
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
if has_t5xxl:
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
if zero_status_t5:
logger.info(
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
logger.info(
"The `diff_b` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"`diff_b` keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
if has_norm_diff:
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
if zero_status_diff:
logger.info(
"The `diff` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"Normalization diff keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
limit_substrings = ["lora_down", "lora_up"]
if any("alpha" in k for k in state_dict):
limit_substrings.append("alpha")
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
}
if any("text_projection" in k for k in state_dict):
logger.info(
"`text_projection` keys found in the `state_dict` which are unexpected. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
@@ -798,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
return new_state_dict
def _custom_replace(key: str, substrings: List[str]) -> str:
# Replaces the "."s with "_"s upto the `substrings`.
# Example:
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
match = re.search(pattern, key)
if match:
start_sub = match.start()
if start_sub > 0 and key[start_sub - 1] == ".":
boundary = start_sub - 1
else:
boundary = start_sub
left = key[:boundary].replace(".", "_")
right = key[boundary:]
return left + right
else:
return key.replace(".", "_")
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
inner_dim = 3072
mlp_ratio = 4.0
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
converted_state_dict[
+58 -5
View File
@@ -22,6 +22,8 @@ from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight
if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)
weight_on_cpu = False
if not module.weight.is_cuda:
weight_on_cpu = True
if is_bnb_4bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.cuda() if weight_on_cpu else module.weight,
state=module.weight.quant_state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.cuda() if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data
if weight_on_cpu:
module_weight = module_weight.cpu()
return module_weight
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if tuple(module_weight_shape) == (out_features, in_features):
continue
# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
@@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)
+11 -44
View File
@@ -58,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
}
def _maybe_adjust_config(config):
"""
We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
method removes the ambiguity by following what is described here:
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
"""
# Track keys that have been explicitly removed to prevent re-adding them.
deleted_keys = set()
def _maybe_raise_error_for_ambiguity(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]
original_r = config["r"]
for key in list(rank_pattern.keys()):
key_rank = rank_pattern[key]
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
@@ -82,35 +70,12 @@ def _maybe_adjust_config(config):
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
ambiguous_key = key
if exact_matches and substring_matches:
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
config["r"] = key_rank
# remove the ambiguous key from `rank_pattern` and record it as deleted
del config["rank_pattern"][key]
deleted_keys.add(key)
# For substring matches, add them with the original rank only if they haven't been assigned already
for mod in substring_matches:
if mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Update the rest of the target modules with the original rank if not already set and not deleted
for mod in target_modules:
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Handle alphas to deal with cases like:
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
if has_different_ranks:
config["lora_alpha"] = config["r"]
alpha_pattern = {}
for module_name, rank in config["rank_pattern"].items():
alpha_pattern[module_name] = rank
config["alpha_pattern"] = alpha_pattern
return config
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)
class PeftAdapterMixin:
@@ -286,16 +251,18 @@ class PeftAdapterMixin:
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
rank[key] = val.shape[1]
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into some ambiguous configuration values when a model has module
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
# for example) and they have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
+2
View File
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
@@ -594,6 +595,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
+10 -6
View File
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
models by reducing the precision of the weights and activations, thus making models more efficient in terms
of both storage and computation.
"""
model, has_been_replaced = _replace_with_bnb_linear(
model, modules_to_not_convert, current_key_name, quantization_config
)
model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
has_been_replaced = any(
isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
modules_to_not_convert=None,
quantization_config=None,
):
model, has_been_replaced = _dequantize_and_replace(
model, _ = _dequantize_and_replace(
model,
dtype=model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
has_been_replaced = any(
isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
"Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
)
return model
+2
View File
@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
return self
+1
View File
@@ -126,6 +126,7 @@ from .state_dict_utils import (
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type
+14
View File
@@ -17,9 +17,14 @@ State dict utilities: utility methods for converting state dicts easily
import enum
from .import_utils import is_torch_available
from .logging import get_logger
if is_torch_available():
import torch
logger = get_logger(__name__)
@@ -333,3 +338,12 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
def state_dict_all_zero(state_dict, filter_str=None):
if filter_str is not None:
if isinstance(filter_str, str):
filter_str = [filter_str]
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
return all(torch.all(param == 0).item() for param in state_dict.values())
-85
View File
@@ -1377,88 +1377,3 @@ class Expectations(DevicePropertiesUserDict):
def __repr__(self):
return f"{self.data}"
def dynamic_slice_test(func):
"""
Decorator that injects an expected_slice parameter into a test function.
On the first run, it will capture the actual slice output and cache it.
On subsequent runs, it provides the cached slice as the expected slice.
Example:
```python
@dynamic_slice_test
def test_stable_diffusion_ddim(self, expected_slice=None):
# Run the pipeline
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
inputs = self.get_dummy_inputs("cpu")
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
# If expected_slice is provided (from cache), assert against it
if expected_slice is not None:
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# Always return the current slice for caching
return image_slice
```
"""
# Check if the function has the expected_slice parameter
sig = inspect.signature(func)
if "expected_slice" not in sig.parameters:
raise ValueError("The decorated function must have an 'expected_slice' parameter")
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Get the test name from pytest
# pytest sets this environment variable to the current test
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
if test_name:
# Format is: test_file.py::TestClass::test_method (call)
test_name = test_name.split(" ")[0]
else:
# Fallback if not running in pytest
test_name = f"{func.__module__}.{func.__qualname__}"
# Create a unique filename based on hardware details
device_props = get_device_properties()
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
# Setup cache directory
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
os.makedirs(cache_dir, exist_ok=True)
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
# Check for cached expected slice
cached_slice = None
if os.path.exists(cache_path):
try:
cached_slice = np.load(cache_path)
print(f"Using cached slice from {cache_path}")
except Exception as e:
print(f"Error loading cached slice: {e}")
# Run the test function with the expected slice injected
kwargs["expected_slice"] = cached_slice
actual_slice = func(*args, **kwargs)
# If the function returned a slice and there's no cached slice yet, cache it
if actual_slice is not None and cached_slice is None:
# Convert torch tensor to numpy if needed
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
actual_slice_np = actual_slice.detach().cpu().numpy()
else:
actual_slice_np = actual_slice
# Save the slice
try:
np.save(cache_path, actual_slice_np)
print(f"Saved slice to cache: {cache_path}")
except Exception as e:
print(f"Error saving slice to cache: {e}")
return actual_slice
return wrapper
+5 -2
View File
@@ -1525,8 +1525,9 @@ class ModelTesterMixin:
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
@parameterized.expand([False, True])
@require_torch_gpu
def test_group_offloading(self):
def test_group_offloading(self, record_stream):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
@@ -1566,7 +1567,9 @@ class ModelTesterMixin:
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
+58 -1
View File
@@ -21,8 +21,15 @@ import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download
from PIL import Image
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
FluxControlPipeline,
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -63,6 +70,8 @@ if is_torch_available():
if is_bitsandbytes_available():
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -364,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
assert key_to_target in str(err_context.exception)
def test_bnb_4bit_logs_warning_for_no_quantization(self):
model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
_ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
assert (
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
in cap_logger.out
)
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
@@ -696,6 +717,42 @@ class SlowBnb4BitFluxTests(Base4bitTests):
self.assertTrue(max_diff < 1e-3)
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()
def tearDown(self):
del self.pipeline_4bit
gc.collect()
torch.cuda.empty_cache()
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
output = self.pipeline_4bit(
prompt=self.prompt,
control_image=Image.new(mode="RGB", size=(256, 256)),
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):
+14
View File
@@ -68,6 +68,8 @@ if is_torch_available():
if is_bitsandbytes_available():
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
# Check that this does not throw an error
_ = self.model_fp16.to(torch_device)
def test_bnb_8bit_logs_warning_for_no_quantization(self):
model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
_ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
assert (
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
in cap_logger.out
)
class Bnb8bitDeviceTests(Base8bitTests):
def setUp(self) -> None:
+46
View File
@@ -8,12 +8,14 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
is_gguf_available,
nightly,
@@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import (
require_accelerate,
require_big_gpu_with_torch_cuda,
require_gguf_version_greater_or_equal,
require_peft_backend,
torch_device,
)
@@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
@require_peft_backend
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class FluxControlLoRAGGUFTests(unittest.TestCase):
def test_lora_loading(self):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
)
output = pipe(
prompt=prompt,
control_image=control_image,
height=256,
width=256,
num_inference_steps=10,
guidance_scale=30.0,
output_type="np",
generator=torch.manual_seed(0),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)