mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which consists of two parts: 1. NVRTC glue code 2. XQA kernel code. During the TensorRT-LLM build, XQA kernel code is embedded as C++ arrays via gen_cpp_header.py and passed to NVRTC for JIT compilation. Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
167 lines
5.5 KiB
Python
Executable File
167 lines
5.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
|
|
import argparse
|
|
import os
|
|
from collections import defaultdict, namedtuple
|
|
from pathlib import Path
|
|
|
|
# One record per embedded file.
#   input_fname      - path of the file whose text is embedded
#   output_fname     - generated header the record is written to
#   content_var_name - C++ variable holding the file's (escaped) content
#   fname_var_name   - C++ variable holding the file's base name
# Example (with the default --output): input_fname='mhaUtils.cuh',
# output_fname='generated/xqa_sources.h',
# content_var_name='mha_utils_cuh_content',
# fname_var_name='mha_utils_cuh_fname'
Entry = namedtuple(
    'Entry', 'input_fname output_fname content_var_name fname_var_name')
|
|
|
|
# Command-line interface. ArgumentDefaultsHelpFormatter appends each option's
# default value to its help text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Generate cpp headers from kernel source file for NVRTC.')

# Path of the generated header.
parser.add_argument('-o',
                    '--output',
                    default='generated/xqa_sources.h',
                    help='output header file name')

# Embedding the CUDA headers is on by default; --no-embed-cuda-headers stores
# False into the same destination (embed_cuda_headers) to switch it off.
parser.add_argument('--embed-cuda-headers',
                    default=True,
                    action='store_true',
                    help='embed cuda headers')
parser.add_argument('--no-embed-cuda-headers',
                    action='store_false',
                    dest='embed_cuda_headers')

# The CUDA headers are looked up under <cuda_root>/include.
parser.add_argument(
    '--cuda_root',
    default='/usr/local/cuda',
    help='CUDA Toolkit path (applicable when --embed-cuda-headers is ON)')

# Parsed at import time; the rest of the script reads options from `args`.
args = parser.parse_args()
|
|
|
|
|
|
def convert_to_raw_cpp_string(s: str):
    """Return *s* as a sequence of octal escapes for a C++ string literal.

    The text is encoded as UTF-8 and every resulting byte is written as a
    backslash followed by exactly three octal digits (e.g. ``'A'`` becomes
    ``\\101``). Despite the function's name, the result is meant to be placed
    inside an ordinary double-quoted literal, not a C++ raw string.

    Args:
        s: text to convert.

    Returns:
        The concatenated escapes; an empty string when *s* is empty.
    """
    # The fixed three-digit width keeps a following digit character from
    # being read as part of the preceding escape.
    escapes = ["\\" + "{:03o}".format(byte) for byte in s.encode('utf-8')]
    return ''.join(escapes)
|
|
|
|
|
|
def get_canonized_str(s: str):
    """Convert a file name into a lower-case, underscore-separated identifier.

    The string is split into words: a word starts at a letter or digit and
    continues over the lower-case letters and digits that follow it. Any other
    character separates words and is dropped, and an upper-case letter always
    begins a new word. The words are lower-cased and joined with ``'_'``, so
    ``'mhaUtils.cuh'`` becomes ``'mha_utils_cuh'``.

    Args:
        s: the string to canonize (callers pass a file base name).

    Returns:
        The canonized string; empty when *s* has no letters or digits.
    """

    def starts_word(ch):
        return ch.isalpha() or ch.isdigit()

    words = []
    current = ''
    for ch in s:
        if current and (ch.islower() or ch.isdigit()):
            # Still inside the current word.
            current += ch
        elif starts_word(ch):
            # A letter/digit that cannot extend the current word opens a new
            # one (e.g. the 'U' in 'mhaUtils').
            if current:
                words.append(current)
            current = ch
        elif current:
            # Separator character: close the word in progress.
            words.append(current)
            current = ''
    if current:
        words.append(current)
    return '_'.join(word.lower() for word in words)
|
|
|
|
|
|
def get_entry_from_input_fname(input_fname: str):
    """Build the Entry describing how *input_fname* is embedded.

    The C++ variable names are derived from the canonized base name of the
    file; the output header is always the one given by ``--output``.

    Args:
        input_fname: path of the file to embed.

    Returns:
        An Entry whose content/fname variable names are the canonized base
        name suffixed with ``_content`` and ``_fname`` respectively.
    """
    prefix = get_canonized_str(os.path.basename(input_fname))
    content_var_name = prefix + '_content'
    fname_var_name = prefix + "_fname"
    return Entry(input_fname, args.output, content_var_name, fname_var_name)
|
|
|
|
|
|
def is_header(entry: Entry):
    """Return True unless the entry's input file name ends with ``.cu``.

    Every embedded file that is not a ``.cu`` file is treated as a header.
    """
    return not entry.input_fname.endswith(".cu")
|
|
|
|
|
|
# XQA kernel files to embed. The names are used as paths exactly as written,
# so they are resolved relative to the current working directory.
SOURCE_FILES = [
    'cuda_hint.cuh', 'defines.h', 'ldgsts.cuh', 'mha.cu', 'mha_sm90.cu',
    'mha.h', 'mhaUtils.cuh', 'mma.cuh', 'platform.h', 'utils.cuh', 'utils.h',
    'mha_stdheaders.cuh', 'gmma.cuh', 'gmma_impl.cuh', 'barriers.cuh', 'tma.h',
    'pairedF32Op.cuh', 'specDec.h'
]

# CUDA Toolkit headers embedded in addition to the kernel files when
# --embed-cuda-headers is on; looked up under <cuda_root>/include.
CUDA_HEADERS = [
    'cuda_bf16.h', 'cuda_bf16.hpp', 'cuda_fp16.h', 'cuda_fp16.hpp',
    'cuda_fp8.h', 'cuda_fp8.hpp', 'vector_types.h', 'vector_functions.h',
    'device_types.h'
]

# Copy the list: extending an alias of SOURCE_FILES in place would silently
# add the CUDA header paths to the SOURCE_FILES constant itself.
all_files = list(SOURCE_FILES)
if args.embed_cuda_headers:
    all_files += [
        os.path.join(args.cuda_root, 'include', i) for i in CUDA_HEADERS
    ]

# Materialized as a list (rather than a single-use map iterator) so it can be
# iterated more than once.
all_entries = [get_entry_from_input_fname(fname) for fname in all_files]
|
|
|
|
# Text emitted once at the top of the generated header: license banner,
# include guard and the opening of the two enclosing namespaces. The trailing
# empty element makes the joined text end with a newline.
TEMPLATE_PROLOGUE = '\n'.join((
    '/*',
    '* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.',
    '* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement',
    '*',
    '* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual',
    '* property and proprietary rights in and to this material, related',
    '* documentation and any modifications thereto. Any use, reproduction,',
    '* disclosure or distribution of this material and related documentation',
    '* without an express license agreement from NVIDIA CORPORATION or',
    '* its affiliates is strictly prohibited.',
    '*/',
    '',
    '#pragma once',
    'namespace tensorrt_llm {',
    'namespace kernels {',
    '',
))

# Emitted once per embedded file via str.format: one constant with the file's
# escaped content and one with its name.
TEMPLATE_CONTENT = (
    'inline constexpr const char* {content_var_name} = "{content}";\n'
    'inline constexpr const char* {fname_var_name} = "{fname}";\n')

# Closes the two namespaces opened by TEMPLATE_PROLOGUE.
TEMPLATE_EPILOGUE = '}\n}\n'
|
|
|
|
# Group the entries by the header file they are written to, so that each
# output file is generated exactly once.
D = defaultdict(list)
for entry in all_entries:
    D[entry.output_fname].append(entry)

for output_fname, entries in D.items():
    # Collect the pieces of the header and join them once at the end instead
    # of repeatedly growing a string.
    parts = [TEMPLATE_PROLOGUE]

    # One content/name constant pair per embedded file.
    for entry in entries:
        # convert_to_raw_cpp_string re-encodes the text as UTF-8, so decode
        # it as UTF-8 too rather than with the locale's default encoding.
        with open(entry.input_fname, 'r', encoding='utf-8') as f:
            input_content = f.read()
        parts.append(
            TEMPLATE_CONTENT.format(
                content_var_name=entry.content_var_name,
                content=convert_to_raw_cpp_string(input_content),
                fname_var_name=entry.fname_var_name,
                fname=os.path.basename(entry.input_fname)))

    # The two arrays below list only the header entries (everything that is
    # not a .cu file), in the same order, so index i of one array matches
    # index i of the other.
    header_entries = [entry for entry in entries if is_header(entry)]

    parts.append("inline constexpr char const* xqa_headers_content[] = {\n")
    parts.extend(" " + entry.content_var_name + ",\n"
                 for entry in header_entries)
    parts.append("};\n")

    parts.append("inline constexpr char const* xqa_headers_name[] = {\n")
    parts.extend(" " + entry.fname_var_name + ",\n"
                 for entry in header_entries)
    parts.append("};\n")

    parts.append(TEMPLATE_EPILOGUE)

    # Use the group key for the destination instead of relying on the loop
    # variable left over from the loops above.
    output_dir = os.path.dirname(output_fname)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    with open(output_fname, 'w', encoding='utf-8') as f:
        f.write(''.join(parts))
|