Fix flash_attn3 in the attention-backend dispatch by using _flash_attn_forward, replacing the original implementation; add a docstring in the pipeline documenting this.
This commit is contained in:
parent
71e8049a84
commit
fbf26b7ed1
@ -79,9 +79,11 @@ else:
|
|||||||
if _CAN_USE_FLASH_ATTN_3:
|
if _CAN_USE_FLASH_ATTN_3:
|
||||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||||
|
from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
|
||||||
else:
|
else:
|
||||||
flash_attn_3_func = None
|
flash_attn_3_func = None
|
||||||
flash_attn_3_varlen_func = None
|
flash_attn_3_varlen_func = None
|
||||||
|
flash_attn_3_forward = None
|
||||||
|
|
||||||
if _CAN_USE_AITER_ATTN:
|
if _CAN_USE_AITER_ATTN:
|
||||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||||
@ -621,22 +623,42 @@ def _wrapped_flash_attn_3(
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||||
window_size = (-1, -1)
|
window_size = (-1, -1)
|
||||||
out, lse, *_ = flash_attn_3_func(
|
max_seqlen_q = q.shape[2]
|
||||||
|
max_seqlen_k = k.shape[2]
|
||||||
|
|
||||||
|
out, lse, *_ = flash_attn_3_forward(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
softmax_scale=softmax_scale,
|
k_new=None,
|
||||||
causal=causal,
|
v_new=None,
|
||||||
qv=qv,
|
qv=qv,
|
||||||
|
out=None,
|
||||||
|
cu_seqlens_q=None,
|
||||||
|
cu_seqlens_k=None,
|
||||||
|
cu_seqlens_k_new=None,
|
||||||
|
seqused_q=None,
|
||||||
|
seqused_k=None,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
page_table=None,
|
||||||
|
kv_batch_idx=None,
|
||||||
|
leftpad_k=None,
|
||||||
|
rotary_cos=None,
|
||||||
|
rotary_sin=None,
|
||||||
|
seqlens_rotary=None,
|
||||||
q_descale=q_descale,
|
q_descale=q_descale,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
attention_chunk=attention_chunk,
|
attention_chunk=attention_chunk,
|
||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
|
rotary_interleaved=True,
|
||||||
|
scheduler_metadata=None,
|
||||||
num_splits=num_splits,
|
num_splits=num_splits,
|
||||||
pack_gqa=pack_gqa,
|
pack_gqa=pack_gqa,
|
||||||
deterministic=deterministic,
|
|
||||||
sm_margin=sm_margin,
|
sm_margin=sm_margin,
|
||||||
)
|
)
|
||||||
lse = lse.permute(0, 2, 1)
|
lse = lse.permute(0, 2, 1)
|
||||||
|
|||||||
@ -39,9 +39,14 @@ EXAMPLE_DOC_STRING = """
|
|||||||
|
|
||||||
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||||
>>> pipe.to("cuda")
|
>>> pipe.to("cuda")
|
||||||
|
|
||||||
|
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||||
|
>>> # (1) Use flash attention 2
|
||||||
|
>>> # pipe.transformer.set_attention_backend("flash")
|
||||||
|
>>> # (2) Use flash attention 3
|
||||||
|
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||||
|
|
||||||
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
||||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
|
||||||
>>> # Refer to the pipeline documentation for more details.
|
|
||||||
>>> image = pipe(
|
>>> image = pipe(
|
||||||
... prompt,
|
... prompt,
|
||||||
... height=1024,
|
... height=1024,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user