Fix flash-attn 3 in the attention backend dispatch by calling _flash_attn_forward directly, replacing the original flash_attn_func-based implementation; add a docstring example in the pipeline for selecting the backend.

Jerry Qilong Wu 2025-11-24 18:49:45 +00:00
parent 71e8049a84
commit fbf26b7ed1
2 changed files with 33 additions and 6 deletions
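For orientation, a minimal, hedged sketch of how the fixed backend is exercised end to end. The model id and the backend string come from the docstring diff below; the prompt, the save path, and the availability of flash-attn 3 (flash_attn_interface) on a supported GPU are assumptions, not part of this commit:

import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# After this commit, the "_flash_3" backend routes attention through
# flash_attn_interface._flash_attn_forward instead of flash_attn_func.
pipe.transformer.set_attention_backend("_flash_3")

# Prompt is illustrative only; any text-to-image prompt works here.
image = pipe("a vintage steam train unzipping winter snow to reveal spring", height=1024).images[0]
image.save("z_image_turbo.png")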


@@ -79,9 +79,11 @@ else:
 if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
+    flash_attn_3_forward = None

 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
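As a usage note, a small sketch of how downstream code can mirror this guard at runtime. The try/except stands in for the module's _CAN_USE_FLASH_ATTN_3 check, whose definition sits outside this hunk, so treat the equivalence as an assumption:

# Fall back to None when flash-attn 3's interface module is missing,
# matching the guarded-import pattern in the hunk above.
try:
    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
except ImportError:
    flash_attn_3_forward = None

if flash_attn_3_forward is None:
    print("flash-attn 3 unavailable; the '_flash_3' backend cannot be used")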
@@ -621,22 +623,42 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
+    max_seqlen_q = q.shape[2]
+    max_seqlen_k = k.shape[2]
+    out, lse, *_ = flash_attn_3_forward(
         q=q,
         k=k,
         v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
+        k_new=None,
+        v_new=None,
         qv=qv,
+        out=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        cu_seqlens_k_new=None,
+        seqused_q=None,
+        seqused_k=None,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        page_table=None,
+        kv_batch_idx=None,
+        leftpad_k=None,
+        rotary_cos=None,
+        rotary_sin=None,
+        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        softmax_scale=softmax_scale,
+        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
+        rotary_interleaved=True,
+        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
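For readers of the wrapper, an illustrative-only sketch of what the final permute does to the returned log-sum-exp tensor. The (batch, heads, seq_len) layout is inferred from the permute itself and from max_seqlen_q = q.shape[2] above; the diff does not state it, so treat it as an assumption:

import torch

# Assumed layout of FA3's returned lse: (batch, heads, seq_len).
batch, heads, seq_len = 2, 8, 128
lse = torch.randn(batch, heads, seq_len)

# The wrapper swaps the last two axes:
# (batch, heads, seq_len) -> (batch, seq_len, heads).
lse = lse.permute(0, 2, 1)
print(lse.shape)  # torch.Size([2, 128, 8])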


@@ -39,9 +39,14 @@ EXAMPLE_DOC_STRING = """
         >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")

+        >>> # Optionally, set the attention backend to flash-attn 2 or 3; the default is PyTorch SDPA.
+        >>> # (1) Use flash attention 2
+        >>> # pipe.transformer.set_attention_backend("flash")
+        >>> # (2) Use flash attention 3
+        >>> # pipe.transformer.set_attention_backend("_flash_3")
         >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化一辆复古蒸汽小火车化身为巨大的拉链头正拉开厚厚的冬日积雪展露出一个生机盎然的春天。"
         >>> # Depending on the variant being used, the pipeline call will slightly vary.
         >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(
         ...     prompt,
         ...     height=1024,