TensorRT-LLMs/cpp/kernels/fmha_v2/fmha_test.py

import subprocess

import pytest
from cuda import cuda, nvrtc


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('Nvrtc Error: {}'.format(err))
    else:
        raise RuntimeError('Unknown error type: {}'.format(err))


def getSMVersion():
    # Init
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)
    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)
    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    ASSERT_DRV(err)
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    ASSERT_DRV(err)

    return sm_major * 10 + sm_minor
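

# SM versions referenced by the skip conditions below:
#   70 = Volta, 89 = Ada, 90 = Hopper.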
# The default test cases for the flash attention fmha kernels used in TRT-LLM.
@pytest.mark.parametrize('d', [32, 40, 64, 72, 80, 96, 104, 128, 160, 192, 256],
                         ids=[
                             "head-size-32", "head-size-40", "head-size-64",
                             "head-size-72", "head-size-80", "head-size-96",
                             "head-size-104", "head-size-128", "head-size-160",
                             "head-size-192", "head-size-256"
                         ])
@pytest.mark.parametrize('s', [1024], ids=["seqlen-1024"])
@pytest.mark.parametrize('dtype', ["-fp16", "-bf16", "-fp16-fp32", "-e4m3"],
                         ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
    "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
    "-softcapping-scale-bmm1 30", "-contiguous-q-kv"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
    verbose = 0
    sm_version = getSMVersion()

    if sm_version == 90 and tiled_kernel == "-force-non-tiled":
        pytest.skip(
            "Tiled/non-tiled flags only make a difference to ampere-style kernels."
        )
    if sm_version == 70 and dtype != "-fp16":
        pytest.skip("Volta fmha only supports the fp16 data type.")
    # It looks like cublas doesn't support non-multiple-of-16 head sizes.
    if dtype == '-e4m3' and d in [40, 72, 104]:
        pytest.skip("cublas doesn't support non-multiple-of-16 head sizes.")
    # Only Ada and Hopper support fp8 fmha currently.
    if dtype == '-e4m3' and sm_version not in [89, 90]:
        pytest.skip("Only Ada and Hopper support fp8 fmha currently.")
    # Ada fp8 fmha only supports non-tiled kernels currently.
    if dtype == '-e4m3' and sm_version == 89 and tiled_kernel == "":
        pytest.skip("Ada fp8 fmha only supports non-tiled kernels currently.")

    # Use a higher error tolerance for bf16 and e4m3.
    epsilon = ''
    if dtype == '-bf16':
        epsilon += ' -epsilon 0.03'
    elif dtype == '-fp16' and '-softcapping-scale-bmm1' in flag:
        epsilon += ' -epsilon 0.03'
    elif dtype == '-e4m3':
        epsilon += ' -epsilon 0.2'
    else:
        epsilon += ' -epsilon 0.02'

    # Only d = 128 kernels are generated with softcapping-scale-bmm1 support by default.
    if d != 128 and '-softcapping-scale-bmm1' in flag:
        pytest.skip(
            "Only d = 128 + softcapping-scale-bmm1 kernels are generated by default."
        )
    # Force non-tiled kernels for d = 64 with the contiguous-q-kv flag on pre-Hopper.
    if d == 64 and flag == '-contiguous-q-kv' and sm_version < 90:
        flag += ' -force-non-tiled'
    # The sm89 e4m3 kernel has a bug with -s-q < 128; it is tracked in an issue.
    if sm_version == 89 and dtype == "-e4m3":
        if "-s-q 63" in flag:
            pytest.skip("skipping chunk size 63 for sm89 e4m3 fmha.")
        if "softcapping-scale-bmm1" in flag:
            pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")

    # Padding mask.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    # Causal mask.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    # Causal mask + GQA.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    # Custom mask + GQA (only for the contiguous-q-kv and paged-kv layouts).
    if flag == '-contiguous-q-kv' or flag == '-paged-kv':
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # alibi and softcapping-scale-bmm1 are mutually exclusive.
    if '-softcapping-scale-bmm1' not in flag:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # Causal mask + MQA + sliding-window attention.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -multi-query-attention -sliding-window-size 54 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
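

# Example invocation (a sketch, assuming fmha.exe has been built into ./bin and the
# cuda-python package is installed). pytest's -k option matches substrings of the
# parametrize ids above, so a subset of cases can be selected, e.g.:
#   pytest fmha_test.py -v -k "bf16 and head-size-128"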


# The test cases for sage attention.
@pytest.mark.parametrize('d', [80, 128], ids=["head-size-80", "head-size-128"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_sage_attention_fmha(d, s):
    sm_version = getSMVersion()
    if sm_version != 89 and sm_version != 90:
        pytest.skip("Sage attention only supports sm89 and sm90 currently.")

    # Ada.
    if sm_version == 89:
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 16 -h 8 -d {d} -bf16 \
            -sage-block-q 64 -sage-block-k 32 -sage-block-v 32 -force-non-tiled",
            shell=True,
            check=True)


# The test cases for mla attention.
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize(
    'input_layout', ["", "-paged-kv", "-contiguous-q-kv", "-separate-q-k-v"],
    ids=["packed-qkv", "paged-kv", "q-contiguous-kv", "separate-q-k-v"])
def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
    # Use a higher error tolerance for bf16 with s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    sm_version = getSMVersion()
    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
        pytest.skip("FP8 MLA is only supported on sm89 currently.")

    # Context phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
        -force-non-warp-specialization -causal-mask {epsilon}",
        shell=True,
        check=True)
    if sm_version == 90:
        # Only the hopper-style kernels support the separate-q-k-v layout currently.
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
            -causal-mask {epsilon} {input_layout}",
            shell=True,
            check=True)
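
# Note: -d 192 / -dv 128 above correspond to the context-phase MLA head shape
# (e.g. DeepSeek-style MLA with 128 "nope" + 64 RoPE dimensions for Q/K and a
# 128-dim V head), which is why the K and V head sizes differ.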


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize('num_grouped_heads', [16, 32, 64, 128],
                         ids=[
                             "num-grouped-heads-16", "num-grouped-heads-32",
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
    # Use a higher error tolerance for bf16 with s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    sm_version = getSMVersion()
    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
        pytest.skip("FP8 MLA is only supported on sm89 currently.")

    # Generation phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s-q 128 -min-s 1024 -s {s} -b 8 -h 1 -d 576 -dv 512 {dtype} \
        -paged-kv -num-grouped-heads {num_grouped_heads} -force-non-warp-specialization {epsilon}",
        shell=True,
        check=True)
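
# Note: the generation-phase command uses a single latent KV head (-h 1) shared by
# -num-grouped-heads query heads; -d 576 / -dv 512 likely correspond to the absorbed
# MLA layout (512 latent dimensions plus 64 RoPE dimensions for Q/K, 512 for V).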


# The test cases for saving softmax.
@pytest.mark.parametrize('mask', ["-causal-mask", ""],
                         ids=["causal-mask", "padding-mask"])
@pytest.mark.parametrize(
    's', [128, 256, 384, 512],
    ids=["seqlen-128", "seqlen-256", "seqlen-384", "seqlen-512"])
def test_trtllm_save_softmax(mask, s):
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s {s} -d 64 -min-s 1 -b 1 -h 4 -fp16 \
        {mask} -contiguous-q-kv -save-softmax",
        shell=True,
        check=True)


# The test cases for chunked attention.
@pytest.mark.parametrize('chunked_attention_size', [128, 256, 512, 1024],
                         ids=[
                             "chunked-attention-size-128",
                             "chunked-attention-size-256",
                             "chunked-attention-size-512",
                             "chunked-attention-size-1024"
                         ])
@pytest.mark.parametrize('input_layout', ["", "-paged-kv"],
                         ids=["packed-qkv", "paged-kv"])
def test_trtllm_chunked_attention(chunked_attention_size, input_layout):
    # Chunked attention is only supported on Hopper currently.
    if getSMVersion() != 90:
        pytest.skip("Chunked attention only supported on hopper currently.")

    subprocess.run(
        f"bin/fmha.exe -d 128 -b 4 -h 5 -fp16 -s 8192 -min-s 4096 \
        -chunked-attention-size {chunked_attention_size} {input_layout}",
        shell=True,
        check=True)
    # Chunked context works together with chunked attention.
    if input_layout == "-paged-kv":
        subprocess.run(
            f"bin/fmha.exe -d 128 -b 8 -h 5 -s-q 256 -s 8192 -min-s 4096 -fp16 \
            -chunked-attention-size {chunked_attention_size} -paged-kv",
            shell=True,
            check=True)
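

# Optional convenience entry point (a minimal sketch; the file is normally collected
# by pytest directly). It forwards any extra command-line arguments to pytest, e.g.:
#   python fmha_test.py -k "head-size-128"
if __name__ == "__main__":
    import sys

    # pytest.main returns an exit code; propagate it to the shell.
    sys.exit(pytest.main([__file__, "-v"] + sys.argv[1:]))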