mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
update setup.py for special cases (#5227)
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
parent 6a6b9d2594
commit faca19c2f0
@@ -189,7 +189,8 @@ namespace kernels
ns_close = r"""
// clang-format on
} // namespace kernels
} // namespace tensorrt_llm""" if generate_cu_trtllm else ""
} // namespace tensorrt_llm
""" if generate_cu_trtllm else ""

copyright = '''\
/***************************************************************************************************
@@ -3403,6 +3404,111 @@ static const struct TestMetaV2
    return code


def modify_cubin_header(cubin_header):
    # for paged context fmha cases
    target = "#ifndef EXCLUDE_SM_90"

    first_addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];"""

    second_addition = """extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;"""

    third_addition = """{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},"""

    result = cubin_header
    offset = 0
    pos = -1

    def add_kernel_line(result, target, addition, pos, offset):
        if pos == -1:
            pos = result.find(target)
        else:
            pos = result.find(target, pos + len(target) + offset)
        if pos != -1:
            end_pos = result.find('\n', pos)
            if end_pos == -1:
                end_pos = len(result)
            result = result[:end_pos + 1] + addition + result[end_pos:]
            offset += len(addition)
        return result, offset, pos

    result, offset, pos = add_kernel_line(result, target, first_addition, pos,
                                          offset)

    result, offset, pos = add_kernel_line(result, target, second_addition, pos,
                                          offset)

    result, offset, pos = add_kernel_line(result, target, third_addition, pos,
                                          offset)

    # for CI cases
    def add_kernel_line(result, target, addition):
        pos = result.find(target)
        if pos != -1:
            end_pos = result.find('\n', pos)
            if end_pos == -1:
                end_pos = len(result)
            result = result[:end_pos + 1] + addition + result[end_pos:]
        return result

    target = "#ifndef EXCLUDE_SM_89"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin_len;
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    target = "#ifndef EXCLUDE_SM_86"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    target = "#ifndef EXCLUDE_SM_80"
    addition = """extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin_len;
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len;"""
    result = add_kernel_line(result, target, addition)

    def modify_kernel_line(result, target, new_line):
        lines = result.split('\n')
        for i, line in enumerate(lines):
            if target in line:
                lines[i] = new_line
                break
        return '\n'.join(lines)

    target = "fmha_v2_flash_attention_bf16_64_32_S_qkv_128_causal_sm89_kernel_nl"
    new_line = '{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_qkv_128_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_causal_sm80_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 64, 64, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    target = "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_causal_sm86_kernel_nl"
    new_line = '{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 64, 64, 0, 0, 0, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_causal_sm86_kernel_nl", 16384, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},'
    result = modify_kernel_line(result, target, new_line)

    # make sure only one empty line at the end
    lines = result.split('\n')
    while lines and not lines[-1].strip():
        lines.pop()
    lines.append('')

    return '\n'.join(lines)


def generate_files(specs_names):

    kfiles = []
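
The hunk above adds modify_cubin_header, which post-processes the generated fmha_cubin.h text: it splices extra extern declarations in right after the "#ifndef EXCLUDE_SM_xx" guard lines and swaps out individual kernel metadata rows by name. The short standalone sketch below (not part of the commit; TOY_HEADER, insert_after_marker and the symbol names are placeholders) illustrates the same find-and-splice idea in isolation, handling the trailing newline a little differently than the committed helper does.

# Minimal sketch of the marker-based splice used by modify_cubin_header above.
# All names here are toy values, not project symbols.
TOY_HEADER = """#ifndef EXCLUDE_SM_90
extern unsigned char existing_kernel_cubin[];
#endif
"""


def insert_after_marker(text, marker, addition):
    # Locate the marker line; if found, insert `addition` right after it.
    pos = text.find(marker)
    if pos == -1:
        return text
    end_pos = text.find('\n', pos)
    if end_pos == -1:
        end_pos = len(text)
    return text[:end_pos + 1] + addition + '\n' + text[end_pos + 1:]


if __name__ == "__main__":
    patched = insert_after_marker(
        TOY_HEADER, "#ifndef EXCLUDE_SM_90",
        "extern unsigned char extra_kernel_cubin[];")
    print(patched)
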
@@ -3443,7 +3549,6 @@ def generate_files(specs_names):
        f.write(print_kernel_traits_code)

    # Make sure we have a bin directory.
    # TEMP disable this until sm90 is in a good shape [Timmy, do not MR.]
    if not os.path.exists('bin'):
        os.mkdir('bin')
    cmd = 'nvcc -I src -Xcompiler -Wno-enum-compare --std=c++17 -o bin/print_traits.exe generated/print_kernel_traits.cu'.split(
@@ -3465,6 +3570,8 @@ def generate_files(specs_names):
    # this gives: kname, smem bytes, threads_per_cta, loop_step
    kernel_traits = [traits.split() for traits in output.splitlines()]
    cubin_header = get_cubin_header(kernel_traits, valid_specs_names)
    if generate_cu_trtllm:
        cubin_header = modify_cubin_header(cubin_header)

    with open('./generated/fmha_cubin.h', 'w') as f:
        f.write(cubin_header)

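This hunk wires the new helper into generate_files: the header text is built as before and, only when the TRT-LLM .cu path is enabled, patched before being written out. The self-contained sketch below mirrors that flow under assumed names (toy_get_cubin_header and toy_modify_cubin_header are stand-ins, and the flag value is hard-coded); it is an illustration, not the project's code.

# Simplified sketch of the conditional post-processing step added above.
import os


def toy_get_cubin_header():
    # Stand-in for get_cubin_header(kernel_traits, valid_specs_names).
    return "#ifndef EXCLUDE_SM_90\n#endif\n"


def toy_modify_cubin_header(text):
    # Stand-in for modify_cubin_header(cubin_header).
    return text.replace("#ifndef EXCLUDE_SM_90\n",
                        "#ifndef EXCLUDE_SM_90\n// patched entries go here\n")


generate_cu_trtllm = True  # assumed flag value for the sketch
header = toy_get_cubin_header()
if generate_cu_trtllm:
    header = toy_modify_cubin_header(header)
os.makedirs("generated", exist_ok=True)
with open("generated/fmha_cubin.h", "w") as f:
    f.write(header)
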
@@ -26,6 +26,8 @@ namespace kernels


#ifndef EXCLUDE_SM_90
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin[];
@@ -380,8 +382,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_so
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];
#endif

#ifndef EXCLUDE_SM_89
@@ -1318,6 +1318,8 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcap


#ifndef EXCLUDE_SM_90
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len;
@@ -1672,8 +1674,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_softcap
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;
#endif

@@ -1709,6 +1709,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
void (*launcher)(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
} sMhaKernelMetaInfosV2[] = {
#ifndef EXCLUDE_SM_90
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, 0, 0, false, false, false, false, true, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_or_chunked_causal_ldgsts_sm90_kernel", 17408, 128, 0, 2, 0, false, false, false, false, true, false, false, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 64, 64, 32, 32, 0, 0, 0, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, 1, 0, false, false, false, false, true, false, false, false, nullptr},
@@ -2663,8 +2665,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 1, 0, false, true, false, true, true, false, true, false, nullptr},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_sliding_or_chunked_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 2, 0, false, true, false, true, true, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, nullptr},
#endif

#ifndef EXCLUDE_SM_89

@@ -267,8 +267,8 @@ def generate_fmha_cu(project_dir, venv_python):
    build_run("python3 setup.py", env=env)

    # Copy generated header file when cu path is active and cubins are deleted.
    # cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
    # build_run(f"mv generated/fmha_cubin.h {cubin_dir}")
    cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
    build_run(f"mv generated/fmha_cubin.h {cubin_dir}")

    for cu_file in (fmha_v2_dir / "generated").glob("*sm*.cu"):
        build_run(f"mv {cu_file} {fmha_v2_cu_dir}")
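
Here the previously commented-out copy of the generated header into the cubin directory is enabled unconditionally. Assuming build_run simply shells out, the mv amounts to the standard-library move sketched below; this is an illustration only (project_dir is a placeholder for the real project root), not the build script's code.

# Rough standard-library equivalent of the enabled `mv` above (sketch only).
import shutil
from pathlib import Path

project_dir = Path(".")  # placeholder for the real project root
cubin_dir = project_dir / "cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin"
cubin_dir.mkdir(parents=True, exist_ok=True)
generated_header = Path("generated/fmha_cubin.h")
if generated_header.exists():
    shutil.move(str(generated_header), str(cubin_dir / generated_header.name))
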