Compare commits

..

10 Commits

Author SHA1 Message Date
DN6 cb69798b3d update 2025-10-27 18:11:28 +05:30
DN6 0229976ab5 update 2025-10-23 16:08:35 +05:30
Dhruv Nair 8f1b207ffd Merge branch 'main' into vace-fix 2025-10-23 15:11:28 +05:30
Sayak Paul ccdd96ca52 [tests] Test attention backends (#12388)
* add a lightweight test suite for attention backends.

* up

* up

* Apply suggestions from code review

* formatting
2025-10-23 15:09:41 +05:30
Sayak Paul 4c723d8ec3 [CI] xfail the test_wuerstchen_prior test (#12530)
xfail the test_wuerstchen_prior test
2025-10-22 08:45:47 -10:00
YiYi Xu bec2d8eaea Fix: Add _skip_keys for AutoencoderKLWan (#12523)
add
2025-10-22 07:53:13 -10:00
Álvaro Somoza a0a51eb098 Kandinsky5 No cfg fix (#12527)
fix
2025-10-22 22:02:47 +05:30
DN6 99308efb55 update 2025-10-03 16:48:43 +05:30
DN6 5015ce4fc7 update 2025-10-03 16:44:23 +05:30
DN6 5ed984cc47 update 2025-10-03 14:42:58 +05:30
11 changed files with 338 additions and 110 deletions
@@ -138,11 +138,10 @@ Refer to the table below for a complete list of available attention backends and
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
+7 -41
View File
@@ -83,15 +83,12 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub
from ..utils.kernels_utils import _get_fa3_from_hub
fa3_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
fa_interface_hub = _get_fa_from_hub()
flash_attn_func_hub = fa_interface_hub.flash_attn_func
flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
else:
flash_attn_3_func_hub = None
flash_attn_func_hub = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -176,8 +173,6 @@ class AttentionBackendName(str, Enum):
# `flash-attn`
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
FLASH_HUB = "flash_hub"
# FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -408,15 +403,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)
# TODO: add support Hub variant of FA and FA3 varlen later
elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend in [
@@ -1233,35 +1228,6 @@ def _flash_attention(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
lse = None
out = flash_attn_func_hub(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_VARLEN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## head
x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
_supports_gradient_checkpointing = False
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
+1
View File
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_repeated_blocks = []
_parallel_config = None
_cp_plan = None
_skip_keys = None
def __init__(self):
super().__init__()
@@ -744,11 +744,13 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
)
if negative_prompt_embeds_qwen is None:
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt(
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
self.encode_prompt(
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
)
# 4. Prepare timesteps
@@ -780,8 +782,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
negative_text_rope_pos = (
torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
if negative_cu_seqlens is not None
torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
if negative_prompt_cu_seqlens is not None
else None
)
@@ -866,6 +866,9 @@ def load_sub_model(
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
skip_keys = None
if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
skip_keys = loaded_sub_model._skip_keys
if needs_offloading_to_cpu:
dispatch_model(
@@ -874,9 +877,10 @@ def load_sub_model(
device_map=device_map,
force_hooks=True,
main_device=0,
skip_keys=skip_keys,
)
else:
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
return loaded_sub_model
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
transformer ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: WanVACETransformer3DModel = None,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if self.transformer is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
elif self.transformer_2 is not None:
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
else:
raise ValueError(
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
)
if height % base != 0 or width % base != 0:
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
):
if video is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
base = self.vae_scale_factor_spatial * (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
video_height, video_width = self.video_processor.get_default_height_width(video[0])
if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
)
transformer_patch_size = self.transformer.config.patch_size[1]
transformer_patch_size = (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
mask_list = []
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
vace_layers = (
self.transformer.config.vace_layers
if self.transformer is not None
else self.transformer_2.config.vace_layers
)
if isinstance(conditioning_scale, (int, float)):
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
conditioning_scale = [conditioning_scale] * len(vace_layers)
if isinstance(conditioning_scale, list):
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
if len(conditioning_scale) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = torch.tensor(conditioning_scale)
if isinstance(conditioning_scale, torch.Tensor):
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
if conditioning_scale.size(0) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)
num_channels_latents = self.transformer.config.in_channels
num_channels_latents = (
self.transformer.config.in_channels
if self.transformer is not None
else self.transformer_2.config.in_channels
)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+12 -22
View File
@@ -2,32 +2,22 @@ from ..utils import get_logger
from .import_utils import is_kernels_available
if is_kernels_available():
from kernels import get_kernel
logger = get_logger(__name__)
_DEFAULT_HUB_IDS = {
"fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
"fa": ("kernels-community/flash-attn", {}),
}
def _get_from_hub(key: str):
if not is_kernels_available():
return None
hub_id, kwargs = _DEFAULT_HUB_IDS[key]
try:
return get_kernel(hub_id, **kwargs)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
raise
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
def _get_fa3_from_hub():
return _get_from_hub("fa3")
if not is_kernels_available():
return None
else:
from kernels import get_kernel
def _get_fa_from_hub():
return _get_from_hub("fa")
try:
# TODO: temporary revision for now. Remove when merged upstream into `main`.
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
return flash_attn_3_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise
+144
View File
@@ -0,0 +1,144 @@
"""
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
Once attention backends become more mature, we can consider including this in our CI.
To run this test suite:
```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
pytest tests/others/test_attention_backends.py
```
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
"""
import os
import pytest
import torch
pytestmark = pytest.mark.skipif(
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
)
from diffusers import FluxPipeline # noqa: E402
from diffusers.utils import is_torch_version # noqa: E402
# fmt: off
FORWARD_CASES = [
("flash_hub", None),
(
"_flash_3_hub",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
),
(
"native",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
),
(
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
]
COMPILE_CASES = [
("flash_hub", None, True),
(
"_flash_3_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True,
),
(
"native",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
True,
),
(
"_native_cudnn",
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
),
]
# fmt: on
INFER_KW = {
"prompt": "dance doggo dance",
"height": 256,
"width": 256,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"max_sequence_length": 128,
"output_type": "pt",
}
def _backend_is_probably_supported(pipe, name: str):
try:
pipe.transformer.set_attention_backend(name)
return pipe, True
except Exception:
return False
def _check_if_slices_match(output, expected_slice):
img = output.images.detach().cpu()
generated_slice = img.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
@pytest.fixture(scope="session")
def device():
if not torch.cuda.is_available():
pytest.skip("CUDA is required for these tests.")
return torch.device("cuda:0")
@pytest.fixture(scope="session")
def pipe(device):
repo_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
pipe.set_progress_bar_config(disable=True)
return pipe
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
@pytest.mark.parametrize(
"backend_name,expected_slice,error_on_recompile",
COMPILE_CASES,
ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
modified_pipe.transformer.compile(fullgraph=True)
torch.compiler.reset()
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
):
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
@@ -17,11 +17,13 @@ import gc
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.models import StableCascadeUNet
from diffusers.utils import is_transformers_version
from diffusers.utils.import_utils import is_peft_available
from ...testing_utils import (
@@ -154,6 +156,11 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
}
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.57.1"),
reason="Test fails with the latest transformers version",
strict=False,
)
def test_wuerstchen_prior(self):
device = "cpu"
+87 -2
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -19,9 +20,15 @@ import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
WanVACEPipeline,
WanVACETransformer3DModel,
)
from ...testing_utils import enable_full_determinism
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
def test_save_load_float16(self):
pass
def test_inference_with_only_transformer(self):
components = self.get_dummy_components()
components["transformer_2"] = None
components["boundary_ratio"] = 0.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_inference_with_only_transformer_2(self):
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
components["transformer"] = None
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
# because starting timestep t == 1000 == boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = ["transformer"]
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
# because starting timestep t == 1000 == boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
for component in optional_component:
components[component] = None
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for component in optional_component:
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"