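# Tests for the TRTLLM fmha (fused multi-head attention) kernels. Each test
# shells out to a prebuilt `bin/fmha.exe` harness (assumed to be built
# beforehand and reachable from the working directory), so a CUDA-capable GPU
# and the cuda-python bindings are required.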
import subprocess

import pytest
from cuda import cuda, nvrtc


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('Nvrtc Error: {}'.format(err))
    else:
        raise RuntimeError('Unknown error type: {}'.format(err))
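

# Returns the compute capability of GPU 0 as a two-digit integer,
# e.g. 70 (Volta), 89 (Ada), 90 (Hopper).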
def getSMVersion():
    # Init
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)

    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)

    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    ASSERT_DRV(err)
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    ASSERT_DRV(err)

    return sm_major * 10 + sm_minor
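

# A rough guide to the fmha.exe flags used below (inferred from the
# parametrization and the command lines, not an exhaustive reference):
#   -d: head size, -h: number of heads, -b: batch size, -s: max sequence
#   length, -min-s: min sequence length, -v: verbosity, -runs: number of runs.
# Data type flags (-fp16, -bf16, -fp16-fp32, -e4m3) and feature flags
# (-paged-kv, -contiguous-q-kv, -causal-mask, ...) are forwarded as-is.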
# The default test cases for flash attention fmha that will be used in TRTLLM.
@pytest.mark.parametrize('d', [32, 40, 64, 72, 80, 96, 104, 128, 160, 192, 256],
                         ids=[
                             "head-size-32", "head-size-40", "head-size-64",
                             "head-size-72", "head-size-80", "head-size-96",
                             "head-size-104", "head-size-128", "head-size-160",
                             "head-size-192", "head-size-256"
                         ])
@pytest.mark.parametrize('s', [1024], ids=["seqlen-1024"])
@pytest.mark.parametrize('dtype', ["-fp16", "-bf16", "-fp16-fp32", "-e4m3"],
                         ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
    "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
    "-softcapping-scale-bmm1 30", "-contiguous-q-kv"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
    verbose = 0
    sm_version = getSMVersion()
    if sm_version == 90 and tiled_kernel == "-force-non-tiled":
        pytest.skip(
            "Tiled/non-tiled flags only make a difference for ampere-style kernels."
        )
    if sm_version == 70 and dtype != "-fp16":
        pytest.skip("Volta fmha only supports the fp16 data type.")
    # looks like cublas doesn't support non-multiple-of-16 head sizes.
    if dtype == '-e4m3' and d in [40, 72, 104]:
        pytest.skip("cublas doesn't support non-multiple-of-16 head sizes.")
    # only ada/hopper support fp8 fmha currently.
    if dtype == '-e4m3' and sm_version not in [89, 90]:
        pytest.skip("only ada/hopper support fp8 fmha currently.")
    # ada fp8 fmha only supports non-tiled kernels currently.
    if dtype == '-e4m3' and sm_version == 89 and tiled_kernel == "":
        pytest.skip("ada fp8 fmha only supports non-tiled kernels currently.")

    # use higher error tolerance for bf16, e4m3, and fp16 with softcapping-scale-bmm1.
    epsilon = ''
    if dtype == '-bf16':
        epsilon += ' -epsilon 0.03'
    elif dtype == '-fp16' and '-softcapping-scale-bmm1' in flag:
        epsilon += ' -epsilon 0.03'
    elif dtype == '-e4m3':
        epsilon += ' -epsilon 0.2'
    else:
        epsilon += ' -epsilon 0.02'

    # only d = 128 kernels are generated with softcapping-scale-bmm1 support.
    if d != 128 and '-softcapping-scale-bmm1' in flag:
        pytest.skip(
            "Only d = 128 + softcapping-scale-bmm1 kernels are generated by default."
        )

    # force non-tiled kernels for d = 64 + the contiguous-q-kv flag.
    if d == 64 and flag == '-contiguous-q-kv' and sm_version < 90:
        flag += ' -force-non-tiled'

    # The sm89 e4m3 kernel has a bug with -s-q < 128; it is tracked in a separate issue.
    if sm_version == 89 and dtype == "-e4m3":
        if "-s-q 63" in flag:
            pytest.skip("skipping chunk size 63 for sm89 e4m3 fmha.")
        if "softcapping-scale-bmm1" in flag:
            pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")
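
    # run with the default padding mask, a causal mask, and a causal mask + GQA.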
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)
    if flag == '-contiguous-q-kv' or flag == '-paged-kv':
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    # alibi and softcapping-scale-bmm1 are mutually exclusive.
    if '-softcapping-scale-bmm1' not in flag:
        subprocess.run(
            f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
            shell=True,
            check=True)
    subprocess.run(
        f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -multi-query-attention -sliding-window-size 54 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
        shell=True,
        check=True)


# The test cases for sage attention.
@pytest.mark.parametrize('d', [80, 128], ids=["head-size-80", "head-size-128"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_sage_attention_fmha(d, s):
    sm_version = getSMVersion()
    if sm_version != 89 and sm_version != 90:
        pytest.skip("Sage attention only supports sm89 and sm90 currently.")

    # Ada.
    if sm_version == 89:
        subprocess.run(
            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 16 -h 8 -d {d} -bf16 \
            -sage-block-q 64 -sage-block-k 32 -sage-block-v 32 -force-non-tiled",
            shell=True,
            check=True)
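

# MLA kernels use different head sizes for Q/K and V: -d sets the Q/K head
# size and -dv (assumed) sets the V head size, i.e. 192/128 for the context
# phase and 576/512 for the generation phase below.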
# The test cases for mla attention.
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_context_mla_attention_fmha(dtype, s):
    # use higher error tolerance for bf16 and s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    sm_version = getSMVersion()
    if sm_version != 89:
        pytest.skip("FP8 MLAs only supported on sm89 currently.")

    # Context phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
        -force-non-warp-specialization -causal-mask {epsilon}",
        shell=True,
        check=True)


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize('num_grouped_heads', [16, 32, 64, 128],
                         ids=[
                             "num-grouped-heads-16", "num-grouped-heads-32",
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
    # use higher error tolerance for bf16 and s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon += ' -epsilon 0.03'

    sm_version = getSMVersion()
    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
        pytest.skip("FP8 MLAs only supported on sm89 currently.")

    # Generation phase kernels.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s-q 128 -min-s 1024 -s {s} -b 8 -h 1 -d 576 -dv 512 {dtype} \
        -paged-kv -num-grouped-heads {num_grouped_heads} -force-non-warp-specialization {epsilon}",
        shell=True,
        check=True)


# The test cases for saving softmax.
@pytest.mark.parametrize('mask', ["-causal-mask", ""],
                         ids=["causal-mask", "padding-mask"])
@pytest.mark.parametrize(
    's', [128, 256, 384, 512],
    ids=["seqlen-128", "seqlen-256", "seqlen-384", "seqlen-512"])
def test_trtllm_save_softmax(mask, s):
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -s {s} -d 64 -min-s 1 -b 1 -h 4 -fp16 \
        {mask} -contiguous-q-kv -save-softmax",
        shell=True,
        check=True)


# The test cases for chunked attention.
@pytest.mark.parametrize('chunked_attention_size', [128, 256, 512, 1024],
                         ids=[
                             "chunked-attention-size-128",
                             "chunked-attention-size-256",
                             "chunked-attention-size-512",
                             "chunked-attention-size-1024"
                         ])
@pytest.mark.parametrize('input_layout', ["", "-paged-kv"],
                         ids=["packed-qkv", "paged-kv"])
def test_trtllm_chunked_attention(chunked_attention_size, input_layout):
    # only supported on hopper currently.
    if getSMVersion() != 90:
        pytest.skip("Chunked attention only supported on hopper currently.")

    subprocess.run(
        f"bin/fmha.exe -d 128 -b 4 -h 5 -fp16 -s 8192 -min-s 4096 \
        -chunked-attention-size {chunked_attention_size} {input_layout}",
        shell=True,
        check=True)

    # Chunked context works with chunked attention.
    if input_layout == "-paged-kv":
        subprocess.run(
            f"bin/fmha.exe -d 128 -b 8 -h 5 -s-q 256 -s 8192 -min-s 4096 -fp16 \
            -chunked-attention-size {chunked_attention_size} -paged-kv",
            shell=True,
            check=True)
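

# Optional convenience entry point (a minimal sketch): lets this module be run
# directly, assuming pytest is installed and bin/fmha.exe has been built.
# Normally these tests are collected and run by pytest itself.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))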