mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
181 lines
5.9 KiB
Python
Executable File
181 lines
5.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import argparse
|
|
import os
|
|
from collections import defaultdict, namedtuple
|
|
from pathlib import Path
|
|
|
|
# Example (with the default --output): input_fname='mhaUtils.cuh', output_fname='generated/xqa_sources.h', content_var_name='mha_utils_cuh_content', fname_var_name='mha_utils_cuh_fname'
|
|
# Record describing one embedded file: the path it is read from, the header
# it is written to, and the C++ variable names that hold its content and its
# file name.
Entry = namedtuple(
    'Entry',
    ['input_fname', 'output_fname', 'content_var_name', 'fname_var_name'])
|
|
|
|
# Command-line interface. The arguments are parsed right here, at module
# level, and the result is stored in the module-level name `args`.
parser = argparse.ArgumentParser(
    description='Generate cpp headers from kernel source file for NVRTC.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('-o',
                    '--output',
                    help='output header file name',
                    default='generated/xqa_sources.h')
# Embedding defaults to ON, so passing --embed-cuda-headers changes nothing;
# --no-embed-cuda-headers writes False into the same destination.
parser.add_argument('--embed-cuda-headers',
                    action='store_true',
                    help='embed cuda headers',
                    default=True)
parser.add_argument('--no-embed-cuda-headers',
                    dest='embed_cuda_headers',
                    action='store_false')
parser.add_argument(
    '--cuda_root',
    help='CUDA Toolkit path (applicable when --embed-cuda-headers is ON)',
    default='/usr/local/cuda')

args = parser.parse_args()
|
|
|
|
|
|
def convert_to_raw_cpp_string(s: str) -> str:
    """Encode *s* as the body of a C/C++ string literal using octal escapes.

    The text is encoded as UTF-8 and every byte is written as a backslash
    followed by exactly three octal digits (``A`` becomes ``\\101``). Despite
    the function name, the result is not a C++ raw string literal: it is meant
    to be placed between ordinary double quotes.

    Args:
        s: Text to embed.

    Returns:
        The escaped text; an empty string for empty input.
    """
    # Iterating over a bytes object yields ints, which is what format() needs.
    # The fixed width of three digits keeps an escape from absorbing a
    # following digit character.
    return ''.join('\\' + format(byte, '03o') for byte in s.encode('utf-8'))
|
|
|
|
|
|
def get_canonized_str(s: str) -> str:
    """Convert a file name into a lower-case, underscore-separated identifier.

    The string is cut into tokens. A token starts at any letter or digit and
    extends over the lower-case letters and digits that follow, so an
    upper-case letter always opens a new token. Characters that are neither
    letters nor digits only separate tokens. Tokens are lower-cased and joined
    with ``_``.

    Examples: ``'mhaUtils.cuh'`` -> ``'mha_utils_cuh'``,
    ``'cuda_bf16.hpp'`` -> ``'cuda_bf16_hpp'``.

    Args:
        s: Input text, typically a file base name.

    Returns:
        The canonical name; an empty string if *s* has no letters or digits.
    """
    length = len(s)

    def skip_separators(pos: int) -> int:
        # Advance past characters that are neither letters nor digits.
        while pos < length and not (s[pos].isalpha() or s[pos].isdigit()):
            pos += 1
        return pos

    tokens = []
    start = skip_separators(0)
    while start < length:
        # s[start] is a letter or digit and is always taken, whatever its case.
        end = start + 1
        while end < length and (s[end].islower() or s[end].isdigit()):
            end += 1
        tokens.append(s[start:end].lower())
        start = skip_separators(end)
    return '_'.join(tokens)
|
|
|
|
|
|
def get_entry_from_input_fname(input_fname: str):
    """Build the Entry describing how *input_fname* is embedded.

    The generated variable names are derived from the base name only, so the
    directory part of *input_fname* does not influence them. Every entry is
    directed to the single output header selected with ``--output``.

    Args:
        input_fname: Path of the file to embed.

    Returns:
        An Entry with the input path, the output header path and the names of
        the content and file-name variables.
    """
    canonized_str = get_canonized_str(os.path.basename(input_fname))
    return Entry(input_fname=input_fname,
                 output_fname=args.output,
                 content_var_name=canonized_str + '_content',
                 fname_var_name=canonized_str + "_fname")
|
|
|
|
|
|
def is_header(entry: Entry):
    """Return True unless the entry's input file name ends with ``.cu``.

    Anything that is not a ``.cu`` file counts as a header, including
    ``.cuh``, ``.h`` and ``.hpp`` files.
    """
    return not entry.input_fname.endswith(".cu")
|
|
|
|
|
|
# Files that are always embedded. The names are opened exactly as written, so
# they are resolved relative to the current working directory.
SOURCE_FILES = [
    'cuda_hint.cuh', 'defines.h', 'ldgsts.cuh', 'mha.cu', 'mha_sm90.cu',
    'mla_sm120.cu', 'mha.h', 'mhaUtils.cuh', 'mma.cuh', 'platform.h',
    'utils.cuh', 'utils.h', 'mha_stdheaders.cuh', 'mha_components.cuh',
    'mla_sm120.cuh', 'gmma.cuh', 'gmma_impl.cuh', 'barriers.cuh', 'tma.h',
    'specDec.h'
]

# CUDA Toolkit headers that are embedded in addition when CUDA header
# embedding is enabled.
CUDA_HEADERS = [
    'cuda_bf16.h', 'cuda_bf16.hpp', 'cuda_fp16.h', 'cuda_fp16.hpp',
    'cuda_fp8.h', 'cuda_fp8.hpp', 'vector_types.h', 'vector_functions.h',
    'device_types.h'
]
|
|
|
|
# Work on a copy: extending the list in place would otherwise also append the
# CUDA header paths to SOURCE_FILES itself.
all_files = list(SOURCE_FILES)
if args.embed_cuda_headers:
    all_files.extend(
        os.path.join(args.cuda_root, 'include', header)
        for header in CUDA_HEADERS)

# A real list rather than a lazy map object, so it can be iterated more than
# once.
all_entries = [get_entry_from_input_fname(fname) for fname in all_files]
|
|
|
|
# Text written once at the top of each generated header: licence comment,
# include guard, config include and the opening of the namespaces. It is used
# verbatim (never passed to str.format), so the literal brace is safe.
TEMPLATE_PROLOGUE = '''/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include "tensorrt_llm/common/config.h"

TRTLLM_NAMESPACE_BEGIN
namespace kernels {
'''

# Per-file declarations, filled in with str.format: one constant holding the
# escaped file content and one holding the file name.
TEMPLATE_CONTENT = '''inline constexpr const char* {content_var_name} = "{content}";
inline constexpr const char* {fname_var_name} = "{fname}";
'''

# Text written once at the end of each generated header; closes what the
# prologue opened.
TEMPLATE_EPILOGUE = '''}
TRTLLM_NAMESPACE_END

'''
|
|
|
|
# Group the entries by the header file they are written to.
D = defaultdict(list)
for entry in all_entries:
    D[entry.output_fname].append(entry)

for output_fname, entries in D.items():
    parts = [TEMPLATE_PROLOGUE]

    # One pair of constants (escaped content, file name) per input file.
    for entry in entries:
        # The content is embedded as UTF-8 bytes, so decode it as UTF-8
        # explicitly instead of relying on the locale's default encoding.
        with open(entry.input_fname, 'r', encoding='utf-8') as f:
            input_content = f.read()
        parts.append(
            TEMPLATE_CONTENT.format(
                content_var_name=entry.content_var_name,
                content=convert_to_raw_cpp_string(input_content),
                fname_var_name=entry.fname_var_name,
                fname=os.path.basename(entry.input_fname)))

    # Two parallel arrays listing every header (everything except the .cu
    # files): element i of one array belongs to element i of the other.
    headers = [entry for entry in entries if is_header(entry)]

    parts.append("inline constexpr char const* xqa_headers_content[] = {\n")
    parts.extend(" " + entry.content_var_name + ",\n" for entry in headers)
    parts.append("};\n")

    parts.append("inline constexpr char const* xqa_headers_name[] = {\n")
    parts.extend(" " + entry.fname_var_name + ",\n" for entry in headers)
    parts.append("};\n")

    parts.append(TEMPLATE_EPILOGUE)

    # Use the dictionary key for the destination rather than a variable left
    # over from the loops above.
    Path(output_fname).parent.mkdir(parents=True, exist_ok=True)

    with open(output_fname, 'w', encoding='utf-8') as f:
        f.write(''.join(parts))
|