import subprocess

import pytest

try:
    from cuda.bindings import driver as cuda
    from cuda.bindings import nvrtc
except ImportError:
    from cuda import cuda, nvrtc


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('Nvrtc Error: {}'.format(err))
    else:
        raise RuntimeError('Unknown error type: {}'.format(err))


def getSMVersion():
    # Init
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)

    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)

    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    ASSERT_DRV(err)
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    ASSERT_DRV(err)

    return sm_major * 10 + sm_minor


# The default test cases for flash attention fmha that will be used in TRTLLM.
@pytest.mark.parametrize(
    'd', [32, 40, 64, 72, 80, 96, 104, 128, 160, 192, 256],
    ids=[
        "head-size-32", "head-size-40", "head-size-64", "head-size-72",
        "head-size-80", "head-size-96", "head-size-104", "head-size-128",
        "head-size-160", "head-size-192", "head-size-256"
    ])
@pytest.mark.parametrize('s', [1024], ids=["seqlen-1024"])
@pytest.mark.parametrize('dtype', ["-fp16", "-bf16", "-fp16-fp32", "-e4m3"],
                         ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
    "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
    "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
    verbose = 0
    sm_version = getSMVersion()

    if flag == "-use-attention-sinks" and sm_version != 90:
        pytest.skip("use-attention-sinks is only supported on sm90 currently.")

    if sm_version == 90 and tiled_kernel == "-force-non-tiled":
        pytest.skip(
            "Tiled/non-tiled flags only make a difference to ampere-style kernels."
        )

    if sm_version == 70 and dtype != "-fp16":
        pytest.skip("Volta fmha only supports fp16 data type.")

    # looks like cublas doesn't support non-multiple-of-16 head sizes.
    if dtype == '-e4m3' and d in [40, 72, 104]:
        pytest.skip("cublas doesn't support non-multiple-of-16 head sizes.")

    # only ada/hopper support fp8 fmha currently.
    if dtype == '-e4m3' and sm_version not in [89, 90]:
        pytest.skip("only ada/hopper support fp8 fmha currently.")

    # ada fp8 fmha only supports non-tiled kernels currently.
    if dtype == '-e4m3' and sm_version == 89 and tiled_kernel == "":
        pytest.skip("ada fp8 fmha only supports non-tiled kernels currently.")

    # Known accuracy issue in this case.
    skip_dense_mask_test = False
    if d == 64 and dtype in ['-fp16-fp32', '-bf16'] and tiled_kernel == "":
        skip_dense_mask_test = True

    # use higher error tolerance for bf16 and e4m3.
    epsilon = ''
    if dtype == '-bf16':
        epsilon += ' -epsilon 0.03'
    elif dtype == '-fp16' and '-softcapping-scale-bmm1' in flag:
        epsilon += ' -epsilon 0.03'
    elif dtype == '-e4m3':
        epsilon += ' -epsilon 0.2'
    else:
        epsilon += ' -epsilon 0.02'

    # only generate d = 128 kernels with softcapping-scale-bmm1 support.
    if d != 128 and '-softcapping-scale-bmm1' in flag:
        pytest.skip(
            "Only d = 128 + softcapping-scale-bmm1 kernels are generated by default."
        )

    # force using non-tiled kernels for d = 64 + contiguous-q-kv flag.
    if d == 64 and flag == '-contiguous-q-kv' and sm_version < 90:
        flag += ' -force-non-tiled'

    # The sm89 e4m3 kernel has a bug with -s-q < 128. This bug will be tracked in a separate issue.
    if sm_version == 89 and dtype == "-e4m3":
        if "-s-q 63" in flag:
            pytest.skip("skipping chunk size 63 for sm89 e4m3 fmha.")
        if "softcapping-scale-bmm1" in flag:
            pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")

    if not skip_dense_mask_test:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    if flag == '-contiguous-q-kv' or flag == '-paged-kv':
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks.
    if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -multi-query-attention -sliding-window-size 54 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)


# The test cases for sage attention.
@pytest.mark.parametrize('d', [80, 128], ids=["head-size-80", "head-size-128"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_sage_attention_fmha(d, s):
    sm_version = getSMVersion()
    if sm_version != 89 and sm_version != 90:
        pytest.skip("Sage attention only supports sm89 and sm90 currently.")

    # Ada.
    if sm_version == 89:
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 16 -h 8 -d {d} -bf16 \
                -sage-block-q 64 -sage-block-k 32 -sage-block-v 32 -force-non-tiled",
            shell=True,
            check=True)


# The test cases for mla attention.
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_context_mla_attention_fmha(dtype, s):
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")

    # use higher error tolerance for bf16 and s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in [90, 120]:
        pytest.skip("FP8 MLAs are only supported on sm90 and sm120 currently.")

    # Context phase kernels, always use separate-q-k-v layout.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
        f"-causal-mask {epsilon} -separate-q-k-v",
        shell=True,
        check=True)

    # For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
    if dtype in ["-bf16", "-e4m3"]:
        # padding mask
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
            f"{epsilon} -separate-q-k-v -save-softmax",
            shell=True,
            check=True)
        # causal mask
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
            f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
            shell=True,
            check=True)


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize('num_grouped_heads', [16, 32, 64, 128],
                         ids=[
                             "num-grouped-heads-16", "num-grouped-heads-32",
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")

    # use higher error tolerance for bf16 and s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
        pytest.skip("FP8 MLAs are only supported on sm120 currently.")

    # Generation phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s-q 128 -min-s 1024 -s {s} -b 8 -h 1 -d 576 -dv 512 {dtype} \
            -paged-kv -num-grouped-heads {num_grouped_heads} -force-non-warp-specialization {epsilon}",
        shell=True,
        check=True)


# The test cases for saving softmax.
@pytest.mark.parametrize('mask', ["-causal-mask", ""],
                         ids=["causal-mask", "padding-mask"])
@pytest.mark.parametrize(
    's', [128, 256, 384, 512],
    ids=["seqlen-128", "seqlen-256", "seqlen-384", "seqlen-512"])
def test_trtllm_save_softmax(mask, s):
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s {s} -d 64 -min-s 1 -b 1 -h 4 -fp16 \
            {mask} -contiguous-q-kv -save-softmax",
        shell=True,
        check=True)


# The test cases for chunked attention.
@pytest.mark.parametrize('chunked_attention_size', [128, 256, 512, 1024],
                         ids=[
                             "chunked-attention-size-128",
                             "chunked-attention-size-256",
                             "chunked-attention-size-512",
                             "chunked-attention-size-1024"
                         ])
@pytest.mark.parametrize('input_layout', ["", "-paged-kv"],
                         ids=["packed-qkv", "paged-kv"])
def test_trtllm_chunked_attention(chunked_attention_size, input_layout):
    # only supported on hopper currently.
    if getSMVersion() != 90:
        pytest.skip("Chunked attention is only supported on hopper currently.")

    subprocess.run(
        f"bin/fmha.exe -d 128 -b 4 -h 5 -fp16 -s 8192 -min-s 4096 \
            -chunked-attention-size {chunked_attention_size} {input_layout}",
        shell=True,
        check=True)

    # Chunked context works with chunked attention.
    if input_layout == "-paged-kv":
        subprocess.run(
            f"bin/fmha.exe -d 128 -b 8 -h 5 -s-q 256 -s 8192 -min-s 4096 -fp16 \
                -chunked-attention-size {chunked_attention_size} -paged-kv",
            shell=True,
            check=True)
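

# Usage sketch (an assumption, not part of the original suite): allow running
# this module directly and forward any extra command-line arguments to pytest.
# For example, passing  -k "fp16 and head-size-128"  selects only the fp16 /
# head-size-128 flash attention cases via the ids declared in the parametrize
# decorators above.
if __name__ == "__main__":
    import sys
    sys.exit(pytest.main([__file__] + sys.argv[1:]))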