TensorRT-LLMs/cpp/kernels/xqa/gen_cpp_header.py
Yihan Wang 9df4dad3b6
[None][fix] Introduce inline namespace to avoid symbol collision (#9541)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2025-12-12 23:32:15 +08:00

181 lines
5.9 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from collections import defaultdict, namedtuple
from pathlib import Path
# One record per embedded input file.
# Example (with the default --output): input_fname='mhaUtils.cuh',
# output_fname='generated/xqa_sources.h',
# content_var_name='mha_utils_cuh_content',
# fname_var_name='mha_utils_cuh_fname'
Entry = namedtuple(
    'Entry', 'input_fname output_fname content_var_name fname_var_name')
def _build_arg_parser():
    """Create the command-line parser for the header generator."""
    cli = argparse.ArgumentParser(
        description='Generate cpp headers from kernel source file for NVRTC.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('-o',
                     '--output',
                     default='generated/xqa_sources.h',
                     help='output header file name')
    # On/off pair writing to the same destination: embedding stays enabled
    # unless --no-embed-cuda-headers is passed.
    cli.add_argument('--embed-cuda-headers',
                     default=True,
                     action='store_true',
                     help='embed cuda headers')
    cli.add_argument('--no-embed-cuda-headers',
                     action='store_false',
                     dest='embed_cuda_headers')
    cli.add_argument(
        '--cuda_root',
        default='/usr/local/cuda',
        help='CUDA Toolkit path (applicable when --embed-cuda-headers is ON)')
    return cli


parser = _build_arg_parser()
# Parsed at import time; the rest of the script reads its settings from args.
args = parser.parse_args()
def convert_to_raw_cpp_string(s: str):
    """Return *s* as a run of three-digit octal escapes of its UTF-8 bytes.

    Every byte is escaped, printable ASCII included, so the result contains
    only backslashes and octal digits.
    """
    return ''.join("\\" + format(byte, "03o") for byte in s.encode('utf-8'))
def get_canonized_str(s: str):
    """Turn a file name into a lower-case, underscore-separated identifier.

    A word starts at any letter or digit and continues over the lower-case
    letters and digits that follow it; any other character, or a letter that
    is not lower-case, ends the word. Words are lower-cased and joined with
    '_', e.g. 'mhaUtils.cuh' -> 'mha_utils_cuh'.
    """
    words = []
    current = None  # characters of the word being built, or None between words
    for ch in s:
        if current is not None and (ch.islower() or ch.isdigit()):
            # Still inside the current word.
            current.append(ch)
            continue
        # The current word (if any) ends here.
        if current is not None:
            words.append(''.join(current).lower())
        # A letter or digit opens the next word; anything else is a separator.
        current = [ch] if (ch.isalpha() or ch.isdigit()) else None
    if current is not None:
        words.append(''.join(current).lower())
    return '_'.join(words)
def get_entry_from_input_fname(input_fname: str):
    """Build the Entry describing how *input_fname* is embedded.

    The C++ variable names are derived from the canonized base name of the
    file; every entry targets the header selected by --output.
    """
    var_prefix = get_canonized_str(os.path.basename(input_fname))
    return Entry(input_fname=input_fname,
                 output_fname=args.output,
                 content_var_name=var_prefix + '_content',
                 fname_var_name=var_prefix + "_fname")
def is_header(entry: Entry):
    """Return True unless the entry's input file name ends with '.cu'."""
    return not entry.input_fname.endswith(".cu")
# Kernel sources and headers to embed. The names carry no directory, so they
# are opened relative to the current working directory.
SOURCE_FILES = [
    'cuda_hint.cuh', 'defines.h', 'ldgsts.cuh', 'mha.cu', 'mha_sm90.cu',
    'mla_sm120.cu', 'mha.h', 'mhaUtils.cuh', 'mma.cuh', 'platform.h',
    'utils.cuh', 'utils.h', 'mha_stdheaders.cuh', 'mha_components.cuh',
    'mla_sm120.cuh', 'gmma.cuh', 'gmma_impl.cuh', 'barriers.cuh', 'tma.h',
    'specDec.h'
]
# CUDA Toolkit headers, looked up under <cuda_root>/include when embedding is
# enabled.
CUDA_HEADERS = [
    'cuda_bf16.h', 'cuda_bf16.hpp', 'cuda_fp16.h', 'cuda_fp16.hpp',
    'cuda_fp8.h', 'cuda_fp8.hpp', 'vector_types.h', 'vector_functions.h',
    'device_types.h'
]
# Work on a copy: extending an alias of SOURCE_FILES in place would silently
# append the CUDA header paths to the SOURCE_FILES constant as well.
all_files = list(SOURCE_FILES)
if args.embed_cuda_headers:
    all_files += [
        os.path.join(args.cuda_root, 'include', header)
        for header in CUDA_HEADERS
    ]
# Materialized as a list so the entries can be iterated more than once.
all_entries = [get_entry_from_input_fname(fname) for fname in all_files]
# Start of every generated header: license banner, include guard, the
# tensorrt_llm config include, and the opening of the TRTLLM namespace macro
# and `namespace kernels`. Emitted verbatim (never passed through str.format),
# so the literal '{' needs no escaping.
TEMPLATE_PROLOGUE = '''/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
TRTLLM_NAMESPACE_BEGIN
namespace kernels {
'''
# Per-file fragment, filled in with str.format: one constant holding the
# escaped file content and one holding the file's base name.
TEMPLATE_CONTENT = '''inline constexpr const char* {content_var_name} = "{content}";
inline constexpr const char* {fname_var_name} = "{fname}";
'''
# End of every generated header: closes `namespace kernels` and the TRTLLM
# namespace macro opened by the prologue.
TEMPLATE_EPILOGUE = '''}
TRTLLM_NAMESPACE_END
'''
# Group the entries by the header file they are written into.
D = defaultdict(list)
for entry in all_entries:
    D[entry.output_fname].append(entry)

for output_fname, entries in D.items():
    # Collect the pieces and join once instead of growing a string with +=.
    chunks = [TEMPLATE_PROLOGUE]
    for entry in entries:
        # Decode explicitly as UTF-8: the content is re-encoded as UTF-8 when
        # escaped, so the result must not depend on the locale's encoding.
        with open(entry.input_fname, 'r', encoding='utf-8') as f:
            input_content = f.read()
        chunks.append(
            TEMPLATE_CONTENT.format(
                content_var_name=entry.content_var_name,
                content=convert_to_raw_cpp_string(input_content),
                fname_var_name=entry.fname_var_name,
                fname=os.path.basename(entry.input_fname)))

    # Only inputs that are not .cu files are listed in the two parallel
    # arrays (contents and names, in the same order).
    header_entries = [entry for entry in entries if is_header(entry)]
    chunks.append("inline constexpr char const* xqa_headers_content[] = {\n")
    chunks.extend(" " + entry.content_var_name + ",\n"
                  for entry in header_entries)
    chunks.append("};\n")
    chunks.append("inline constexpr char const* xqa_headers_name[] = {\n")
    chunks.extend(" " + entry.fname_var_name + ",\n"
                  for entry in header_entries)
    chunks.append("};\n")
    chunks.append(TEMPLATE_EPILOGUE)
    output_content = ''.join(chunks)

    # Write to the group's own key rather than to whatever `entry` happened
    # to be left over from the loops above.
    output_dir = os.path.dirname(output_fname)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    with open(output_fname, 'w', encoding='utf-8') as f:
        f.write(output_content)