Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.
This commit is contained in:
parent
71e8049a84
commit
fbf26b7ed1
@ -79,9 +79,11 @@ else:
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
flash_attn_3_forward = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
@ -621,22 +623,42 @@ def _wrapped_flash_attn_3(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
max_seqlen_q = q.shape[2]
|
||||
max_seqlen_k = k.shape[2]
|
||||
|
||||
out, lse, *_ = flash_attn_3_forward(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
k_new=None,
|
||||
v_new=None,
|
||||
qv=qv,
|
||||
out=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
cu_seqlens_k_new=None,
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
page_table=None,
|
||||
kv_batch_idx=None,
|
||||
leftpad_k=None,
|
||||
rotary_cos=None,
|
||||
rotary_sin=None,
|
||||
seqlens_rotary=None,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
attention_chunk=attention_chunk,
|
||||
softcap=softcap,
|
||||
rotary_interleaved=True,
|
||||
scheduler_metadata=None,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1)
|
||||
|
||||
@ -39,9 +39,14 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... height=1024,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user