From fbf26b7ed11d55146103c97740bad4a5f91744e0 Mon Sep 17 00:00:00 2001
From: Jerry Qilong Wu
Date: Mon, 24 Nov 2025 18:49:45 +0000
Subject: [PATCH] Fix flash-attn 3 in the attention dispatch backend by
 replacing the original flash_attn_func call with _flash_attn_forward;
 document attention backend selection in the Z-Image pipeline docstring.

---
 src/diffusers/models/attention_dispatch.py | 30 ++++++++++++++++---
 .../pipelines/z_image/pipeline_z_image.py  |  9 ++++--
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 8504504981..df4e0a0122 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -79,9 +79,11 @@ else:
 if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
+    flash_attn_3_forward = None
 
 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
@@ -621,22 +623,42 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
+    max_seqlen_q = q.shape[2]
+    max_seqlen_k = k.shape[2]
+
+    out, lse, *_ = flash_attn_3_forward(
         q=q,
         k=k,
         v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
+        k_new=None,
+        v_new=None,
         qv=qv,
+        out=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        cu_seqlens_k_new=None,
+        seqused_q=None,
+        seqused_k=None,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        page_table=None,
+        kv_batch_idx=None,
+        leftpad_k=None,
+        rotary_cos=None,
+        rotary_sin=None,
+        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        softmax_scale=softmax_scale,
+        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
+        rotary_interleaved=True,
+        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
-        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py
index d33d72d9c7..d4cd574a04 100644
--- a/src/diffusers/pipelines/z_image/pipeline_z_image.py
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py
@@ -39,9 +39,14 @@ EXAMPLE_DOC_STRING = """
         >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
+
+        >>> # Optionally, set the attention backend to flash-attn 2 or 3; the default is PyTorch SDPA.
+        >>> # (1) Use flash attention 2
+        >>> # pipe.transformer.set_attention_backend("flash")
+        >>> # (2) Use flash attention 3
+        >>> # pipe.transformer.set_attention_backend("_flash_3")
+
         >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化：一辆复古蒸汽小火车化身为巨大的拉链头，正拉开厚厚的冬日积雪，展露出一个生机盎然的春天。"
-        >>> # Depending on the variant being used, the pipeline call will slightly vary.
-        >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(
         ...     prompt,
         ...     height=1024,
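
Reviewer note: a minimal parity sketch for the core change, calling the old
flash_attn_func path and the new _flash_attn_forward path with the same keyword
list this patch passes. The (batch, seqlen, heads, head_dim) layout, the
None/default values for the extra keywords, and both symbols existing in
flash_attn_interface with these signatures are assumptions taken from the patch
itself, not from FA3 documentation; running it needs a Hopper-class GPU with
flash-attn 3 installed.

    import torch
    from flash_attn_interface import _flash_attn_forward, flash_attn_func

    # Assumed (batch, seqlen, heads, head_dim) layout; hypothetical sizes.
    q = torch.randn(2, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    scale = q.shape[-1] ** -0.5

    # Old path: the public functional API the wrapper used to call.
    out_old, lse_old, *_ = flash_attn_func(q=q, k=k, v=v, softmax_scale=scale, causal=False)

    # New path: the raw forward, with the same keyword list as the patch
    # (the None/default values here mirror the patch, not FA3 docs).
    out_new, lse_new, *_ = _flash_attn_forward(
        q=q, k=k, v=v, k_new=None, v_new=None, qv=None, out=None,
        cu_seqlens_q=None, cu_seqlens_k=None, cu_seqlens_k_new=None,
        seqused_q=None, seqused_k=None,
        max_seqlen_q=q.shape[1], max_seqlen_k=k.shape[1],
        page_table=None, kv_batch_idx=None, leftpad_k=None,
        rotary_cos=None, rotary_sin=None, seqlens_rotary=None,
        q_descale=None, k_descale=None, v_descale=None,
        softmax_scale=scale, causal=False, window_size=(-1, -1),
        attention_chunk=0, softcap=0.0, rotary_interleaved=True,
        scheduler_metadata=None, num_splits=1, pack_gqa=None, sm_margin=0,
    )

    # Both paths should dispatch to the same kernel, so outputs should agree.
    torch.testing.assert_close(out_old, out_new)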
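
And an end-to-end sketch of the usage the new docstring documents, lifted from
the docstring itself. The English prompt, the width argument, and the output
filename are placeholders; selecting "_flash_3" assumes the flash-attn 3
requirements above are met, otherwise leave the default SDPA backend in place.

    import torch
    from diffusers import ZImagePipeline

    pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Route attention through the wrapper fixed in this patch.
    pipe.transformer.set_attention_backend("_flash_3")

    # Placeholder prompt; the docstring's original Chinese prompt works the same way.
    prompt = "A creative poster for a project named Z-IMAGE-TURBO."
    image = pipe(prompt, height=1024, width=1024).images[0]
    image.save("z_image_turbo_fa3.png")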