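'''
Build a standalone TensorRT engine around a Triton-generated fused multi-head
attention (FMHA) plugin produced by the TensorRT-LLM plugin-gen workflow: the
plugin is wrapped in a single-layer network with dynamic batch and sequence
dimensions, built into an engine, and serialized together with its config and
timing cache.
'''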
import argparse
import math

# yapf: disable
import os
import sys
import time
from pathlib import Path
from typing import List, OrderedDict

import tensorrt as trt

# from plugin import LAYER_NAME, FmhaLayer, get_engine_name
import tensorrt_llm
from tensorrt_llm import Module, str_dtype_to_trt
from tensorrt_llm.builder import Builder, BuilderConfig
from tensorrt_llm.functional import Tensor
from tensorrt_llm.logger import logger
from tensorrt_llm.network import net_guard

# include plugins: import the auto-generated plugin wrapper from the
# plugin-gen workspace (PLUGIN_GEN_WORKSPACE, defaulting to ./tmp).
sys.path.append(os.environ.get('PLUGIN_GEN_WORKSPACE', './tmp'))
from functional import fused_attention_kernel  # isort:skip
# yapf: enable


def get_engine_name(head_size: int, dtype: str) -> str:
    return f'fmha_{head_size}_{dtype}.engine'
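

# For instance, the defaults (--head_size 64 --dtype float16) yield an engine
# file named 'fmha_64_float16.engine'.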


class FmhaLayer(Module):

    def __init__(self, num_heads: int, head_size: int, softmax_scale: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.softmax_scale = softmax_scale
        self.dtype = str_dtype_to_trt('float16')

    def forward(self, Q: Tensor, K: Tensor, V: Tensor):
        inputs = [Q, K, V]
        Out, L, M = fused_attention_kernel(self.softmax_scale, self.num_heads,
                                           *[p.trt_tensor for p in inputs])
        Out.mark_output('out', self.dtype)
        L.mark_output('L', self.dtype)
        M.mark_output('M', self.dtype)
        return Out, L, M
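
    # Note: L and M are assumed to be the per-row softmax statistics (running
    # sum/log-sum-exp and running max) that Flash-Attention-style kernels emit
    # alongside the attention output; marking them as engine outputs lets the
    # runtime retrieve them together with 'out'.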

    def prepare_inputs(self, max_batch_size: int,
                       max_len: int) -> List[Tensor]:
        '''
        @brief: Prepare the input Tensors for the model; the given sizes are
        used to determine the ranges of the dimensions when using TRT dynamic
        shapes.

        @return: a list of Tensors that can be fed into self.forward()
        '''

        bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
        max_len_range = [1, (max_len + 1) // 2, max_len]

        # batch_size and seq_len are dynamic (-1); num_heads and head_size
        # are fixed at build time.
        dynamic_shape = [-1, self.num_heads, -1, self.head_size]

        def make_input(name: str) -> Tensor:
            # Q, K and V share the same shape and optimization profile.
            return Tensor(name=name,
                          dtype=trt.float16,
                          shape=dynamic_shape,
                          dim_range=OrderedDict([
                              ('batch_size', [bs_range]),
                              ('num_heads', [self.num_heads]),
                              ('seq_len', [max_len_range]),
                              ('head_size', [self.head_size]),
                          ]))

        return [make_input(name) for name in ('Q', 'K', 'V')]
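
    # With the defaults (max_batch_size=4, max_seq_len=256, num_heads=8,
    # head_size=64) each input gets the optimization profile
    # (min, opt, max) = ([1, 8, 1, 64], [2, 8, 128, 64], [4, 8, 256, 64]).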


def build_engine(builder: Builder, builder_config: BuilderConfig,
                 engine_name: str,
                 args: argparse.Namespace) -> trt.IHostMemory:
    '''
    @brief: Build a TensorRT engine.
    @param args: The cmd line arguments.
    @return: The built or refitted engine.
    '''

    # Initialize Module
    softmax_scale = 1.0 / math.sqrt(args.head_size)
    layer = FmhaLayer(args.num_heads, args.head_size, softmax_scale)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    network.plugin_config.to_legacy_setting()
    with net_guard(network):
        # Prepare
        inputs = layer.prepare_inputs(args.max_batch_size, args.max_seq_len)
        # Forward
        logger.debug(f'model inputs: {inputs}')
        layer(*inputs)

    # Debug output: dump the network as Graphviz dot and inspect the plugin
    # metadata of the first (and only) layer.
    print('dot:')
    print(network.to_dot())

    plugin_layer = network.get_layer_by_name(next(
        network.get_layers()).name).as_layer()
    print('layer', plugin_layer.plugin.plugin_type)
    print('layer', plugin_layer.plugin.plugin_version)
    print('layer', plugin_layer.plugin.plugin_namespace)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    config_path = Path(args.output_dir) / 'config.json'
    builder.save_config(builder_config, str(config_path))
    return engine
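

# A minimal sketch of loading the serialized engine back (an assumption for
# illustration: it requires the generated plugin shared library to already be
# registered with the TensorRT plugin registry, as the plugin-gen workflow
# arranges):
#
#   runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
#   with open('outputs/fmha_64_float16.engine', 'rb') as f:
#       engine = runtime.deserialize_cuda_engine(f.read())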


def build(args):
    tensorrt_llm.logger.set_level(args.log_level)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    builder = Builder()
    builder_config = builder.create_builder_config(
        name='fmha_triton',
        precision=args.dtype,
        timing_cache=args.timing_cache,
        profiling_verbosity=args.profiling_verbosity)

    engine_name = get_engine_name(args.head_size, args.dtype)
    engine = build_engine(builder, builder_config, engine_name, args)
    assert engine is not None

    engine_path = output_dir / engine_name
    logger.info(f'Serializing engine to {str(engine_path)}...')
    tik = time.time()
    with engine_path.open('wb') as f:
        f.write(engine)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')

    ok = builder.save_timing_cache(builder_config,
                                   Path(args.output_dir) / "model.cache")
    assert ok, "Failed to save timing cache."


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--max_batch_size', type=int, default=4)
    parser.add_argument('--max_seq_len', type=int, default=256)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--head_size', type=int, default=64)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help='The path to read the timing cache from; ignored if the file '
        'does not exist')
    parser.add_argument(
        '--profiling_verbosity',
        type=str,
        default='layer_names_only',
        choices=['layer_names_only', 'detailed', 'none'],
        help='The profiling verbosity for the generated TRT engine. Set to '
        '"detailed" to inspect tactic choices and kernel parameters.')
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='outputs',
        help='The path to save the serialized engine files, timing cache '
        'file and model configs')
    args = parser.parse_args()

    logger.set_level(args.log_level)
    logger.info('Parameters'.center(40, '='))
    for k, v in vars(args).items():
        logger.info(f' - {k.ljust(15, ".")}: {v}')
    logger.info(''.center(40, '='))

    tik = time.time()
    logger.info('Build TensorRT engine.')
    build(args)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of building TRT engine: {t}')
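
# Example invocation (the script filename and workspace path are illustrative;
# point PLUGIN_GEN_WORKSPACE at the directory where plugin-gen emitted
# functional.py):
#
#   PLUGIN_GEN_WORKSPACE=./tmp python build_engine.py \
#       --num_heads 8 --head_size 64 --dtype float16 \
#       --max_batch_size 4 --max_seq_len 256 --output_dir outputs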