Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb69798b3d | |||
| 0229976ab5 | |||
| 8f1b207ffd | |||
| ccdd96ca52 | |||
| 4c723d8ec3 | |||
| bec2d8eaea | |||
| a0a51eb098 | |||
| 99308efb55 | |||
| 5015ce4fc7 | |||
| 5ed984cc47 |
@@ -138,11 +138,10 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
|
||||
@@ -83,15 +83,12 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
raise ImportError(
|
||||
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
|
||||
)
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub
|
||||
|
||||
fa3_interface_hub = _get_fa3_from_hub()
|
||||
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
|
||||
fa_interface_hub = _get_fa_from_hub()
|
||||
flash_attn_func_hub = fa_interface_hub.flash_attn_func
|
||||
flash_attn_interface_hub = _get_fa3_from_hub()
|
||||
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
|
||||
else:
|
||||
flash_attn_3_func_hub = None
|
||||
flash_attn_func_hub = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
@@ -176,8 +173,6 @@ class AttentionBackendName(str, Enum):
|
||||
# `flash-attn`
|
||||
FLASH = "flash"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_HUB = "flash_hub"
|
||||
# FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -408,15 +403,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
|
||||
# TODO: add support Hub variant of FA and FA3 varlen later
|
||||
elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
|
||||
)
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
@@ -1233,35 +1228,6 @@ def _flash_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
out = flash_attn_func_hub(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# First residual block
|
||||
x = self.resnets[0](x, feat_cache, feat_idx)
|
||||
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
# Process through attention and residual blocks
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
x = attn(x)
|
||||
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
return x
|
||||
|
||||
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
if self.downsampler is not None:
|
||||
x = self.downsampler(x, feat_cache, feat_idx)
|
||||
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.down_blocks:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsampler is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsampler(x, feat_cache, feat_idx)
|
||||
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = self.upsampler(x)
|
||||
|
||||
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
|
||||
"""
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
||||
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = self.upsamplers[0](x)
|
||||
return x
|
||||
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
|
||||
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
|
||||
# these are shared mutable state modified in-place
|
||||
_skip_keys = ["feat_cache", "feat_idx"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_repeated_blocks = []
|
||||
_parallel_config = None
|
||||
_cp_plan = None
|
||||
_skip_keys = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -744,11 +744,13 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
)
|
||||
|
||||
if negative_prompt_embeds_qwen is None:
|
||||
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
|
||||
self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
@@ -780,8 +782,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
|
||||
|
||||
negative_text_rope_pos = (
|
||||
torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
|
||||
if negative_cu_seqlens is not None
|
||||
torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
|
||||
if negative_prompt_cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -866,6 +866,9 @@ def load_sub_model(
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
skip_keys = None
|
||||
if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
|
||||
skip_keys = loaded_sub_model._skip_keys
|
||||
|
||||
if needs_offloading_to_cpu:
|
||||
dispatch_model(
|
||||
@@ -874,9 +877,10 @@ def load_sub_model(
|
||||
device_map=device_map,
|
||||
force_hooks=True,
|
||||
main_device=0,
|
||||
skip_keys=skip_keys,
|
||||
)
|
||||
else:
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanVACETransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
transformer ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
|
||||
`transformer` or `transformer_2` must be provided.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
|
||||
`transformer` or `transformer_2` must be provided.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: WanVACETransformer3DModel = None,
|
||||
transformer_2: WanVACETransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
reference_images=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
elif self.transformer_2 is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
|
||||
)
|
||||
|
||||
if height % base != 0 or width % base != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
|
||||
|
||||
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if video is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
base = self.vae_scale_factor_spatial * (
|
||||
self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size[1]
|
||||
)
|
||||
video_height, video_width = self.video_processor.get_default_height_width(video[0])
|
||||
|
||||
if video_height * video_width > height * width:
|
||||
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
transformer_patch_size = self.transformer.config.patch_size[1]
|
||||
transformer_patch_size = (
|
||||
self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size[1]
|
||||
)
|
||||
|
||||
mask_list = []
|
||||
for mask_, reference_images_batch in zip(mask, reference_images):
|
||||
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
|
||||
vace_layers = (
|
||||
self.transformer.config.vace_layers
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.vace_layers
|
||||
)
|
||||
if isinstance(conditioning_scale, (int, float)):
|
||||
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
|
||||
conditioning_scale = [conditioning_scale] * len(vace_layers)
|
||||
if isinstance(conditioning_scale, list):
|
||||
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
|
||||
if len(conditioning_scale) != len(vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
|
||||
)
|
||||
conditioning_scale = torch.tensor(conditioning_scale)
|
||||
if isinstance(conditioning_scale, torch.Tensor):
|
||||
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
|
||||
if conditioning_scale.size(0) != len(vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
|
||||
)
|
||||
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
|
||||
|
||||
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
|
||||
conditioning_latents = conditioning_latents.to(transformer_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_channels_latents = (
|
||||
self.transformer.config.in_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.in_channels
|
||||
)
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -2,32 +2,22 @@ from ..utils import get_logger
|
||||
from .import_utils import is_kernels_available
|
||||
|
||||
|
||||
if is_kernels_available():
|
||||
from kernels import get_kernel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_DEFAULT_HUB_IDS = {
|
||||
"fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
|
||||
"fa": ("kernels-community/flash-attn", {}),
|
||||
}
|
||||
|
||||
|
||||
def _get_from_hub(key: str):
|
||||
if not is_kernels_available():
|
||||
return None
|
||||
|
||||
hub_id, kwargs = _DEFAULT_HUB_IDS[key]
|
||||
try:
|
||||
return get_kernel(hub_id, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
|
||||
raise
|
||||
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
|
||||
|
||||
|
||||
def _get_fa3_from_hub():
|
||||
return _get_from_hub("fa3")
|
||||
if not is_kernels_available():
|
||||
return None
|
||||
else:
|
||||
from kernels import get_kernel
|
||||
|
||||
|
||||
def _get_fa_from_hub():
|
||||
return _get_from_hub("fa")
|
||||
try:
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
|
||||
return flash_attn_3_hub
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
|
||||
|
||||
Once attention backends become more mature, we can consider including this in our CI.
|
||||
|
||||
To run this test suite:
|
||||
|
||||
```bash
|
||||
export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
|
||||
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
||||
)
|
||||
from diffusers import FluxPipeline # noqa: E402
|
||||
from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
|
||||
# fmt: off
|
||||
FORWARD_CASES = [
|
||||
("flash_hub", None),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
||||
),
|
||||
]
|
||||
|
||||
COMPILE_CASES = [
|
||||
("flash_hub", None, True),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
INFER_KW = {
|
||||
"prompt": "dance doggo dance",
|
||||
"height": 256,
|
||||
"width": 256,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.5,
|
||||
"max_sequence_length": 128,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
|
||||
def _backend_is_probably_supported(pipe, name: str):
|
||||
try:
|
||||
pipe.transformer.set_attention_backend(name)
|
||||
return pipe, True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_if_slices_match(output, expected_slice):
|
||||
img = output.images.detach().cpu()
|
||||
generated_slice = img.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def device():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for these tests.")
|
||||
return torch.device("cuda:0")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pipe(device):
|
||||
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
return pipe
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
|
||||
def test_forward(pipe, backend_name, expected_slice):
|
||||
out = _backend_is_probably_supported(pipe, backend_name)
|
||||
if isinstance(out, bool):
|
||||
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
|
||||
|
||||
modified_pipe = out[0]
|
||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||
_check_if_slices_match(out, expected_slice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend_name,expected_slice,error_on_recompile",
|
||||
COMPILE_CASES,
|
||||
ids=[c[0] for c in COMPILE_CASES],
|
||||
)
|
||||
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
|
||||
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
|
||||
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
|
||||
|
||||
out = _backend_is_probably_supported(pipe, backend_name)
|
||||
if isinstance(out, bool):
|
||||
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
|
||||
|
||||
modified_pipe = out[0]
|
||||
modified_pipe.transformer.compile(fullgraph=True)
|
||||
|
||||
torch.compiler.reset()
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
|
||||
):
|
||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||
|
||||
_check_if_slices_match(out, expected_slice)
|
||||
@@ -17,11 +17,13 @@ import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.utils import is_transformers_version
|
||||
from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -154,6 +156,11 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
}
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.57.1"),
|
||||
reason="Test fails with the latest transformers version",
|
||||
strict=False,
|
||||
)
|
||||
def test_wuerstchen_prior(self):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -19,9 +20,15 @@ import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
WanVACEPipeline,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
def test_save_load_float16(self):
|
||||
pass
|
||||
|
||||
def test_inference_with_only_transformer(self):
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = None
|
||||
components["boundary_ratio"] = 0.0
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
video = pipe(**inputs).frames[0]
|
||||
assert video.shape == (17, 3, 16, 16)
|
||||
|
||||
def test_inference_with_only_transformer_2(self):
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = components["transformer"]
|
||||
components["transformer"] = None
|
||||
|
||||
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
|
||||
# because starting timestep t == 1000 == boundary_timestep
|
||||
components["scheduler"] = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
|
||||
)
|
||||
|
||||
components["boundary_ratio"] = 1.0
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
video = pipe(**inputs).frames[0]
|
||||
assert video.shape == (17, 3, 16, 16)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = ["transformer"]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = components["transformer"]
|
||||
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
|
||||
# because starting timestep t == 1000 == boundary_timestep
|
||||
components["scheduler"] = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
|
||||
)
|
||||
for component in optional_component:
|
||||
components[component] = None
|
||||
|
||||
components["boundary_ratio"] = 1.0
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for component in optional_component:
|
||||
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"
|
||||
|
||||
Reference in New Issue
Block a user