Compare commits

..

12 Commits

Author SHA1 Message Date
DN6 532b395718 update 2025-07-17 21:56:48 +05:30
DN6 5c43924ac2 update 2025-07-17 19:57:45 +05:30
DN6 a633289e10 update 2025-07-16 19:41:48 +05:30
Aryan 06fd427797 [tests] Improve Flux tests (#11919)
update
2025-07-15 10:47:41 +05:30
dependabot[bot] 48a551251d Bump aiohttp from 3.10.10 to 3.12.14 in /examples/server (#11924)
Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.10.10 to 3.12.14.
- [Release notes](https://github.com/aio-libs/aiohttp/releases)
- [Changelog](https://github.com/aio-libs/aiohttp/blob/master/CHANGES.rst)
- [Commits](https://github.com/aio-libs/aiohttp/compare/v3.10.10...v3.12.14)

---
updated-dependencies:
- dependency-name: aiohttp
  dependency-version: 3.12.14
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-15 09:15:57 +05:30
Hengyue-Bi 6398fbc391 Fix: Align VAE processing in ControlNet SD3 training with inference (#11909)
Fix: Apply vae_shift_factor in ControlNet SD3 training
2025-07-14 14:54:38 -04:00
Colle 3c8b67b371 Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)
Flux: pass joint_attention_kwargs when gradient_checkpointing
2025-07-11 08:35:18 -10:00
Steven Liu 9feb946432 [docs] torch.compile blog post (#11837)
* add blog post

* feedback

* feedback
2025-07-11 10:29:40 -07:00
Aryan c90352754a Speedup model loading by 4-5x (#11904)
* update

* update

* update

* pin accelerate version

* add comment explanations

* update docstring

* make style

* non_blocking does not matter for dtype cast

* _empty_cache -> clear_cache

* update

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/diffusers/models/model_loading_utils.py

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-07-11 21:43:53 +05:30
Sayak Paul 7a935a0bbe [tests] Unify compilation + offloading tests in quantization (#11910)
* unify the quant compile + offloading tests.

* fix

* update
2025-07-11 17:02:29 +05:30
chenxiao 941b7fc084 Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)
* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2025-07-10 11:51:05 -10:00
Álvaro Somoza 76a62ac9cc [ControlnetUnion] Multiple Fixes (#11888)
fixes

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-07-10 14:35:28 -04:00
22 changed files with 260 additions and 206 deletions
+17 -14
View File
@@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 810x.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **810 ×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
# compile only the repeated transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
> [!TIP]
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py
# pip install -U accelerate
@@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
### Graph breaks
@@ -296,3 +293,9 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
+1 -1
View File
@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = controlnet_image * vae.config.scaling_factor
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
+9 -6
View File
@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.10.10
aiohttp==3.12.14
# via -r requirements.in
aiosignal==1.3.1
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
# huggingface-hub
# torch
# transformers
# triton
frozenlist==1.5.0
# via
# aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
# via yarl
# via
# aiohttp
# yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
# via torch
typing-extensions==4.12.2
# via
# aiosignal
# anyio
# exceptiongroup
# fastapi
# huggingface-hub
# multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
yarl==1.16.0
yarl==1.18.3
# via aiohttp
@@ -24,6 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,10 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -46,6 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_transformers_available():
@@ -1689,6 +1690,10 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)
+7 -5
View File
@@ -18,11 +18,8 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_accelerate_available():
@@ -84,6 +81,8 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -158,6 +157,9 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
device_synchronize()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+6
View File
@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
empty_device_cache()
device_synchronize()
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_proj
+6
View File
@@ -43,6 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -753,6 +754,8 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -850,6 +853,9 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
device_synchronize()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+64 -1
View File
@@ -16,9 +16,10 @@
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +39,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+27 -39
View File
@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -1469,11 +1473,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1482,18 +1481,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)
if low_cpu_mem_usage:
@@ -1554,9 +1538,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
@@ -490,6 +490,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -521,6 +522,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -323,6 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
config_name = "config.json"
model_name = None
@classmethod
def _get_signature_keys(cls, obj):
@@ -333,6 +334,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return expected_modules, optional_parameters
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@classmethod
def from_pretrained(
cls,
@@ -358,7 +367,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError("TODO")
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
@@ -367,7 +378,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path,
module_file=module_file,
class_name=class_name,
is_modular=True,
**hub_kwargs,
**kwargs,
)
@@ -93,7 +93,7 @@ class ComponentSpec:
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
+9
View File
@@ -184,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
if device_type in ["cpu"]:
return
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()
def device_synchronize(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
+30 -77
View File
@@ -155,7 +155,7 @@ class FluxPipelineFastTests(
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -187,14 +187,17 @@ class FluxPipelineFastTests(
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
"Fusion of QKV projections shouldn't affect the outputs."
self.assertTrue(
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
("Fusion of QKV projections shouldn't affect the outputs."),
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
self.assertTrue(
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
self.assertTrue(
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
("Original outputs should match when fused QKV projections are disabled."),
)
def test_flux_image_output_shape(self):
@@ -209,7 +212,11 @@ class FluxPipelineFastTests(
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
self.assertEqual(
(output_height, output_width),
(expected_height, expected_width),
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
)
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -220,7 +227,9 @@ class FluxPipelineFastTests(
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
assert not np.allclose(no_true_cfg_out, true_cfg_out)
self.assertFalse(
np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
)
@nightly
@@ -269,45 +278,17 @@ class FluxPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
expected_slice = np.array(
[
0.3242,
0.3203,
0.3164,
0.3164,
0.3125,
0.3125,
0.3281,
0.3242,
0.3203,
0.3301,
0.3262,
0.3242,
0.3281,
0.3242,
0.3203,
0.3262,
0.3262,
0.3164,
0.3262,
0.3281,
0.3184,
0.3281,
0.3281,
0.3203,
0.3281,
0.3281,
0.3164,
0.3320,
0.3320,
0.3203,
],
[0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
dtype=np.float32,
)
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
self.assertLess(
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
)
@slow
@@ -377,42 +358,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
expected_slice = np.array(
[
0.1855,
0.1680,
0.1406,
0.1953,
0.1699,
0.1465,
0.2012,
0.1738,
0.1484,
0.2051,
0.1797,
0.1523,
0.2012,
0.1719,
0.1445,
0.2070,
0.1777,
0.1465,
0.2090,
0.1836,
0.1484,
0.2129,
0.1875,
0.1523,
0.2090,
0.1816,
0.1484,
0.2110,
0.1836,
0.1543,
],
[0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
dtype=np.float32,
)
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
self.assertLess(
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
)
+4 -9
View File
@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests):
class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ class Bnb4BitCompileTests(QuantCompileTests):
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
super().test_torch_compile()
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, use_stream=True
)
super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
+4 -8
View File
@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests):
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ class Bnb8BitCompileTests(QuantCompileTests):
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
super()._test_torch_compile(torch_dtype=torch.float16)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
)
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
+1 -10
View File
@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
torch_dtype = torch.bfloat16
gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
def quantization_config(self):
return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
def _init_pipeline(self, *args, **kwargs):
transformer = FluxTransformer2DModel.from_single_file(
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+28 -21
View File
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import inspect
import torch
@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
class QuantCompileTests:
@property
def quantization_config(self):
raise NotImplementedError(
@@ -50,30 +50,26 @@ class QuantCompileTests(unittest.TestCase):
)
return pipe
def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
# import to ensure fullgraph True
def _test_torch_compile(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
# `fullgraph=True` ensures no graph breaks
pipe.transformer.compile(fullgraph=True)
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_group_offload_leaf(
self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
):
torch._dynamo.config.cache_size_limit = 10000
def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
torch._dynamo.config.cache_size_limit = 1000
pipe = self._init_pipeline(quantization_config, torch_dtype)
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ class QuantCompileTests(unittest.TestCase):
if torch.device(component.device).type == "cpu":
component.to("cuda")
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def test_torch_compile(self):
self._test_torch_compile()
def test_torch_compile_with_cpu_offload(self):
self._test_torch_compile_with_cpu_offload()
def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
for cls in inspect.getmro(self.__class__):
if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
return
self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
+5 -8
View File
@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ class TorchAoCompileTest(QuantCompileTests):
},
)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
super().test_torch_compile_with_cpu_offload()
@parameterized.expand([False, True])
@unittest.skip(
"""
For `use_stream=False`:
@@ -659,8 +657,7 @@ class TorchAoCompileTest(QuantCompileTests):
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
@parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
def test_torch_compile_with_group_offload_leaf(self, use_stream):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ class TorchAoCompileTest(QuantCompileTests):
# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners