Compare commits

..

5 Commits

Author SHA1 Message Date
sayakpaul b71269675e Release: v0.35.2-patch 2025-10-15 09:23:57 +05:30
Vladimir Mandic 36059182f1 fix scale_shift_factor being on cpu for wan and ltx (#12347)
* wan fix scale_shift_factor being on cpu

* apply device cast to ltx transformer

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-15 09:19:40 +05:30
Aishwarya Badlani 9169e81609 Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.… (#12206)
* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195

* Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes

* Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Move version check to top of file and use private naming as requested

* Apply style fixes

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-15 09:14:51 +05:30
Dhruv Nair 8160289373 [CI] Fix TRANSFORMERS_FLAX_WEIGHTS_NAME import issue (#12354)
update
2025-10-15 09:09:25 +05:30
Sayak Paul 08782bf3bf handle offload_state_dict when initing transformers models (#12438) 2025-10-15 09:08:31 +05:30
7 changed files with 49 additions and 16 deletions
+1 -1
View File
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.35.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.35.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.35.1"
__version__ = "0.35.2"
from typing import TYPE_CHECKING
+25 -5
View File
@@ -110,6 +110,27 @@ if _CAN_USE_XFORMERS_ATTN:
else:
xops = None
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
_register_fake = torch.library.register_fake
else:
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
def wrap(func):
return func
return wrap if fn is None else fn
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
def wrap(func):
return func
return wrap if fn is None else fn
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
@_register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -665,12 +665,12 @@ class WanTransformer3DModel(
# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
self.scale_shift_table.to(temb.device) + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
hidden_states = hidden_states + control_hint * scale
# 6. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
if is_transformers_version("<=", "4.56.2"):
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -830,6 +838,9 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
loading_kwargs.pop("offload_state_dict")
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)