# TensorRT-LLM: cpp/kernels/fmha_v2/fmha_test.py
import subprocess
import pytest

# Prefer the newer cuda.bindings module layout; fall back to the legacy
# cuda-python layout if it is not available.
try:
    from cuda.bindings import driver as cuda
    from cuda.bindings import nvrtc
except ImportError:
    from cuda import cuda, nvrtc


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('Nvrtc Error: {}'.format(err))
    else:
        raise RuntimeError('Unknown error type: {}'.format(err))
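
# Note: the cuda-python driver bindings return a tuple whose first element is the
# error code, hence the `err, = ...` / `err, value = ...` unpacking pattern below.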
def getSMVersion():
    # Init
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)
    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)
    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    ASSERT_DRV(err)
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    ASSERT_DRV(err)
    return sm_major * 10 + sm_minor
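
# Quick reference for the SM versions checked below: 70 = Volta, 80/86 = Ampere,
# 89 = Ada, 90 = Hopper, 120 = Blackwell (SM120).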

# The default test cases for flash attention fmha that will be used in TRTLLM.
@pytest.mark.parametrize('d', [32, 40, 64, 72, 80, 96, 104, 128, 160, 192, 256],
                         ids=[
                             "head-size-32", "head-size-40", "head-size-64",
                             "head-size-72", "head-size-80", "head-size-96",
                             "head-size-104", "head-size-128", "head-size-160",
                             "head-size-192", "head-size-256"
                         ])
@pytest.mark.parametrize('s', [1024], ids=["seqlen-1024"])
@pytest.mark.parametrize('dtype', ["-fp16", "-bf16", "-fp16-fp32", "-e4m3"],
                         ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
    "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
    "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
    verbose = 0
    sm_version = getSMVersion()
    if flag == "-use-attention-sinks" and sm_version != 90:
        pytest.skip("use-attention-sinks is only supported on sm90 currently.")
    if sm_version == 90 and tiled_kernel == "-force-non-tiled":
        pytest.skip(
            "Tiled/non-tiled flags only make a difference to ampere-style kernels."
        )
    if sm_version == 70 and dtype != "-fp16":
        pytest.skip("Volta fmha only supports the fp16 data type.")
    # cublas doesn't support non-multiple-of-16 head sizes.
    if dtype == '-e4m3' and d in [40, 72, 104]:
        pytest.skip("cublas doesn't support non-multiple-of-16 head sizes.")
    # only ada/hopper support fp8 fmha currently.
    if dtype == '-e4m3' and sm_version not in [89, 90]:
        pytest.skip("only ada/hopper support fp8 fmha currently.")
    # ada fp8 fmha only supports non-tiled kernels currently.
    if dtype == '-e4m3' and sm_version == 89 and tiled_kernel == "":
        pytest.skip("ada fp8 fmha only supports non-tiled kernels currently.")
    # Skip the dense-mask run for configurations with a known accuracy issue.
    skip_dense_mask_test = False
    if d == 64 and dtype in ['-fp16-fp32', '-bf16'] and tiled_kernel == "":
        skip_dense_mask_test = True
    # Use a higher error tolerance for bf16, e4m3, and fp16 with softcapping.
    epsilon = ''
    if dtype == '-bf16':
        epsilon += ' -epsilon 0.03'
    elif dtype == '-fp16' and '-softcapping-scale-bmm1' in flag:
        epsilon += ' -epsilon 0.03'
    elif dtype == '-e4m3':
        epsilon += ' -epsilon 0.2'
    else:
        epsilon += ' -epsilon 0.02'
    # Only d = 128 kernels are generated with softcapping-scale-bmm1 support.
    if d != 128 and '-softcapping-scale-bmm1' in flag:
        pytest.skip(
            "Only d = 128 + softcapping-scale-bmm1 kernels are generated by default."
        )
    # Force non-tiled kernels for d = 64 + contiguous-q-kv on pre-hopper.
    if d == 64 and flag == '-contiguous-q-kv' and sm_version < 90:
        flag += ' -force-non-tiled'
    # The sm89 e4m3 kernel has a known bug with -s-q < 128; it is tracked separately.
    if sm_version == 89 and dtype == "-e4m3":
        if "-s-q 63" in flag:
            pytest.skip("skipping chunk size 63 for sm89 e4m3 fmha.")
        if "softcapping-scale-bmm1" in flag:
            pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")
    # Dense/padding mask.
    if not skip_dense_mask_test:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # Causal mask.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    # Causal mask + GQA.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    # Custom mask + GQA (only for the contiguous-q-kv / paged-kv layouts).
    if flag == '-contiguous-q-kv' or flag == '-paged-kv':
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # Alibi doesn't work with softcapping-scale-bmm1 or use-attention-sinks.
    if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # Causal mask + MQA + sliding window.
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -multi-query-attention -sliding-window-size 54 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
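
# Example invocation (assumes bin/fmha.exe has been built and the tests are run from
# the fmha_v2 directory, so the relative path resolves):
#   pytest fmha_test.py -k "head-size-128 and bf16"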


# The test cases for sage attention.
@pytest.mark.parametrize('d', [80, 128], ids=["head-size-80", "head-size-128"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_sage_attention_fmha(d, s):
    sm_version = getSMVersion()
    if sm_version != 89 and sm_version != 90:
        pytest.skip("Sage attention only supports sm89 and sm90 currently.")
    # Ada.
    if sm_version == 89:
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 16 -h 8 -d {d} -bf16 \
              -sage-block-q 64 -sage-block-k 32 -sage-block-v 32 -force-non-tiled",
            shell=True,
            check=True)


# The test cases for mla attention.
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_context_mla_attention_fmha(dtype, s):
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
    # Use a higher error tolerance for bf16 with s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'
    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in [90, 120]:
        pytest.skip("FP8 MLAs are only supported on sm90 and sm120 currently.")
    # Context phase kernels always use the separate-q-k-v layout.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
        f"-causal-mask {epsilon} -separate-q-k-v",
        shell=True,
        check=True)
    # Chunked prefill needs -save-softmax enabled (dtype: bf16/e4m3, layout: separate-q-k-v).
    if dtype in ["-bf16", "-e4m3"]:
        # Padding mask.
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
            f"{epsilon} -separate-q-k-v -save-softmax",
            shell=True,
            check=True)
        # Causal mask.
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
            f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
            shell=True,
            check=True)
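
# Head-size note (assuming DeepSeek-style MLA): the context kernels above use
# d = 192 for Q/K (a 128-dim latent part plus a 64-dim RoPE part) and dv = 128 for V.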


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize('num_grouped_heads', [16, 32, 64, 128],
                         ids=[
                             "num-grouped-heads-16", "num-grouped-heads-32",
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
    # Use a higher error tolerance for bf16 with s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'
    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
        pytest.skip("FP8 generation-phase MLAs are only supported on sm120 currently.")
    # Generation phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s-q 128 -min-s 1024 -s {s} -b 8 -h 1 -d 576 -dv 512 {dtype} \
          -paged-kv -num-grouped-heads {num_grouped_heads} -force-non-warp-specialization {epsilon}",
        shell=True,
        check=True)
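
# Head-size note (assuming the absorbed MLA generation layout): d = 576 corresponds to
# the 512-dim compressed KV latent plus a 64-dim RoPE part, with dv = 512; the
# -num-grouped-heads flag groups the query heads that share the single KV head (h = 1).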


# The test cases for saving softmax.
@pytest.mark.parametrize('mask', ["-causal-mask", ""],
                         ids=["causal-mask", "padding-mask"])
@pytest.mark.parametrize(
    's', [128, 256, 384, 512],
    ids=["seqlen-128", "seqlen-256", "seqlen-384", "seqlen-512"])
def test_trtllm_save_softmax(mask, s):
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s {s} -d 64 -min-s 1 -b 1 -h 4 -fp16 \
          {mask} -contiguous-q-kv -save-softmax",
        shell=True,
        check=True)


# The test cases for chunked attention.
@pytest.mark.parametrize('chunked_attention_size', [128, 256, 512, 1024],
                         ids=[
                             "chunked-attention-size-128",
                             "chunked-attention-size-256",
                             "chunked-attention-size-512",
                             "chunked-attention-size-1024"
                         ])
@pytest.mark.parametrize('input_layout', ["", "-paged-kv"],
                         ids=["packed-qkv", "paged-kv"])
def test_trtllm_chunked_attention(chunked_attention_size, input_layout):
    # Only supported on hopper currently.
    if getSMVersion() != 90:
        pytest.skip("Chunked attention is only supported on hopper currently.")
    subprocess.run(
        f"bin/fmha.exe -d 128 -b 4 -h 5 -fp16 -s 8192 -min-s 4096 \
          -chunked-attention-size {chunked_attention_size} {input_layout}",
        shell=True,
        check=True)
    # Chunked context works with chunked attention.
    if input_layout == "-paged-kv":
        subprocess.run(
            f"bin/fmha.exe -d 128 -b 8 -h 5 -s-q 256 -s 8192 -min-s 4096 -fp16 \
              -chunked-attention-size {chunked_attention_size} -paged-kv",
            shell=True,
            check=True)
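

# Convenience entry point (a minimal sketch, not part of the original harness): lets the
# file be run directly with `python fmha_test.py`; assumes pytest is installed and
# bin/fmha.exe is present in the working directory.
if __name__ == '__main__':
    import sys
    sys.exit(pytest.main([__file__, '-v']))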