update setup.py for special cases (#5227)

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
qsang-nv 2025-06-17 16:41:07 +08:00 committed by GitHub
parent 6a6b9d2594
commit faca19c2f0
3 changed files with 117 additions and 10 deletions


@@ -189,7 +189,8 @@ namespace kernels
ns_close = r"""
// clang-format on
} // namespace kernels
} // namespace tensorrt_llm""" if generate_cu_trtllm else ""
} // namespace tensorrt_llm
""" if generate_cu_trtllm else ""
copyright = '''\
/***************************************************************************************************
@@ -3403,6 +3404,111 @@ static const struct TestMetaV2
    return code


def modify_cubin_header(cubin_header):
    # for paged context fmha cases
    target = "#ifndef EXCLUDE_SM_90"
    first_addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];"""
    second_addition = """extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;"""
    third_addition = """{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},"""
    result = cubin_header
    offset = 0
    pos = -1

    def add_kernel_line(result, target, addition, pos, offset):
        if pos == -1:
            pos = result.find(target)
        else:
            pos = result.find(target, pos + len(target) + offset)
        if pos != -1:
            end_pos = result.find('\n', pos)
            if end_pos == -1:
                end_pos = len(result)
            result = result[:end_pos + 1] + addition + result[end_pos:]
            offset += len(addition)
        return result, offset, pos

    result, offset, pos = add_kernel_line(result, target, first_addition, pos,
                                          offset)
    result, offset, pos = add_kernel_line(result, target, second_addition, pos,
                                          offset)
    result, offset, pos = add_kernel_line(result, target, third_addition, pos,
                                          offset)

    # for CI cases
    def add_kernel_line(result, target, addition):
        pos = result.find(target)
        if pos != -1:
            end_pos = result.find('\n', pos)
            if end_pos == -1:
                end_pos = len(result)
            result = result[:end_pos + 1] + addition + result[end_pos:]
        return result

    target = "#ifndef EXCLUDE_SM_89"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin_len;
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    target = "#ifndef EXCLUDE_SM_86"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    target = "#ifndef EXCLUDE_SM_80"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin_len;
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    def modify_kernel_line(result, target, new_line):
        lines = result.split('\n')
        for i, line in enumerate(lines):
            if target in line:
                lines[i] = new_line
                break
        return '\n'.join(lines)

    target = "fmha_v2_flash_attention_bf16_64_32_S_qkv_128_causal_sm89_kernel_nl"
    new_line = '{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_qkv_128_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_causal_sm80_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 64, 64, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_causal_sm86_kernel_nl"
    new_line = '{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 64, 64, 0, 0, 0, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_causal_sm86_kernel_nl", 16384, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    # make sure only one empty line at the end
    lines = result.split('\n')
    while lines and not lines[-1].strip():
        lines.pop()
    lines.append('')
    return '\n'.join(lines)


def generate_files(specs_names):
    kfiles = []
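A standalone sketch (not part of the commit) of what the simpler add_kernel_line helper above does: it locates the line containing target and splices addition in directly below it. The stub header here is invented purely for illustration.

# Sketch only: same splice logic as the helper above, run on a fake header.
def add_kernel_line(result, target, addition):
    pos = result.find(target)
    if pos != -1:
        end_pos = result.find('\n', pos)
        if end_pos == -1:
            end_pos = len(result)
        # result[:end_pos + 1] keeps the target line plus its newline; result[end_pos:]
        # re-supplies that newline, so addition lands on its own line(s).
        result = result[:end_pos + 1] + addition + result[end_pos:]
    return result

stub_header = "#ifndef EXCLUDE_SM_90\nextern unsigned char existing_cubin[];\n#endif\n"
print(add_kernel_line(stub_header, "#ifndef EXCLUDE_SM_90",
                      "extern unsigned char new_cubin[];"))
# #ifndef EXCLUDE_SM_90
# extern unsigned char new_cubin[];
# extern unsigned char existing_cubin[];
# #endif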
@@ -3443,7 +3549,6 @@ def generate_files(specs_names):
        f.write(print_kernel_traits_code)
    # Make sure we have a bin directory.
    # TEMP disable this until sm90 is in a good shape [Timmy, do not MR.]
    if not os.path.exists('bin'):
        os.mkdir('bin')
    cmd = 'nvcc -I src -Xcompiler -Wno-enum-compare --std=c++17 -o bin/print_traits.exe generated/print_kernel_traits.cu'.split(
@@ -3465,6 +3570,8 @@ def generate_files(specs_names):
    # this gives: kname, smem bytes, threads_per_cta, loop_step
    kernel_traits = [traits.split() for traits in output.splitlines()]
    cubin_header = get_cubin_header(kernel_traits, valid_specs_names)
    if generate_cu_trtllm:
        cubin_header = modify_cubin_header(cubin_header)
    with open('./generated/fmha_cubin.h', 'w') as f:
        f.write(cubin_header)
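The hunk above hooks modify_cubin_header into generate_files only when generate_cu_trtllm is set. As a hedged illustration of the traits parsing two lines earlier: each stdout line from bin/print_traits.exe is split into a [kname, smem bytes, threads_per_cta, loop_step] record (field names taken from the comment; the sample output below is an assumption, with numbers borrowed from the metadata table elsewhere in this diff).

# Illustration only: assumed print_traits.exe output, one kernel per line.
output = ("fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel 213248 384 64\n"
          "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled 81920 128 64\n")
kernel_traits = [traits.split() for traits in output.splitlines()]
print(kernel_traits[0])
# ['fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel', '213248', '384', '64']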


@@ -26,6 +26,8 @@ namespace kernels
#ifndef EXCLUDE_SM_90
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin[];
@@ -380,8 +382,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_so
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];
#endif
#ifndef EXCLUDE_SM_89
@@ -1318,6 +1318,8 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcap
#ifndef EXCLUDE_SM_90
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len;
@@ -1672,8 +1674,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_softcap
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;
#endif
@@ -1709,6 +1709,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
void (*launcher)(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
} sMhaKernelMetaInfosV2[] = {
#ifndef EXCLUDE_SM_90
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, 0, 0, false, false, false, false, true, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_or_chunked_causal_ldgsts_sm90_kernel", 17408, 128, 0, 2, 0, false, false, false, false, true, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, 1, 0, false, false, false, false, true, false, false, false, nullptr},
@@ -2663,8 +2665,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 1, 0, false, true, false, true, true, false, true, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_sliding_or_chunked_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 2, 0, false, true, false, true, true, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},
#endif
#ifndef EXCLUDE_SM_89


@@ -267,8 +267,8 @@ def generate_fmha_cu(project_dir, venv_python):
    build_run("python3 setup.py", env=env)
    # Copy generated header file when cu path is active and cubins are deleted.
    # cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
    # build_run(f"mv generated/fmha_cubin.h {cubin_dir}")
    cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
    build_run(f"mv generated/fmha_cubin.h {cubin_dir}")
    for cu_file in (fmha_v2_dir / "generated").glob("*sm*.cu"):
        build_run(f"mv {cu_file} {fmha_v2_cu_dir}")
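For context, a hedged sketch of what the now-active copy step amounts to, written with shutil instead of the script's build_run("mv ...") wrapper. Assumptions: project_dir, fmha_v2_dir and fmha_v2_cu_dir are pathlib.Path values as in the surrounding script, and the generated files live under fmha_v2_dir / "generated".

# Sketch only: standard-library equivalent of the mv commands above.
import shutil
from pathlib import Path

def copy_generated_outputs(project_dir: Path, fmha_v2_dir: Path, fmha_v2_cu_dir: Path) -> None:
    cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
    # Move the regenerated header next to the checked-in cubins.
    shutil.move(str(fmha_v2_dir / "generated" / "fmha_cubin.h"), str(cubin_dir / "fmha_cubin.h"))
    # Move the generated per-SM .cu sources into the cu directory.
    for cu_file in (fmha_v2_dir / "generated").glob("*sm*.cu"):
        shutil.move(str(cu_file), str(fmha_v2_cu_dir / cu_file.name))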