mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
237 lines
9.5 KiB
Python
237 lines
9.5 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
import os
|
|
from collections import OrderedDict
|
|
|
|
import onnx
|
|
import tensorrt as trt
|
|
from onnx import TensorProto, helper
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm.builder import Builder
|
|
from tensorrt_llm.functional import assertion, shape
|
|
from tensorrt_llm.network import net_guard
|
|
|
|
|
|
def trt_dtype_to_onnx(dtype):
    """Translate a TensorRT data type into the matching ONNX element type.

    Args:
        dtype: One of ``trt.float16``, ``trt.float32`` or ``trt.int32``.

    Returns:
        The corresponding ``TensorProto.DataType`` member.

    Raises:
        TypeError: If *dtype* is none of the three supported types.
    """
    # Ordered (TensorRT type, ONNX type) pairs; matched with ``==`` so the
    # comparison semantics are those of the TensorRT dtype objects themselves.
    supported = (
        (trt.float16, TensorProto.DataType.FLOAT16),
        (trt.float32, TensorProto.DataType.FLOAT),
        (trt.int32, TensorProto.DataType.INT32),
    )
    for trt_type, onnx_type in supported:
        if dtype == trt_type:
            return onnx_type
    raise TypeError("%s is not supported" % dtype)
|
|
|
|
|
|
def to_onnx(network, path):
    """Write the layer graph of a TensorRT network to an ONNX model file.

    Every network input and output becomes an ONNX value-info entry (name,
    converted dtype, shape), and every layer becomes one ONNX node whose
    op type is ``str(layer.type)`` in the ``com.nvidia`` domain. The graph
    is created with ``initializer=None`` and is named ``'attention'``.

    Args:
        network: Object exposing ``num_inputs``/``get_input``,
            ``num_outputs``/``get_output`` and ``num_layers``/``get_layer``
            (the caller passes ``network.trt_network``).
        path: Destination handed to ``onnx.save``.

    Raises:
        TypeError: Propagated from ``trt_dtype_to_onnx`` when an input or
            output tensor has a dtype other than float16/float32/int32.
    """

    def _value_info(tensor):
        # Describe one network I/O tensor in ONNX terms.
        return helper.make_tensor_value_info(tensor.name,
                                             trt_dtype_to_onnx(tensor.dtype),
                                             list(tensor.shape))

    graph_inputs = [
        _value_info(network.get_input(idx))
        for idx in range(network.num_inputs)
    ]
    graph_outputs = [
        _value_info(network.get_output(idx))
        for idx in range(network.num_outputs)
    ]

    graph_nodes = []
    for idx in range(network.num_layers):
        layer = network.get_layer(idx)
        # A layer input slot may be None; such slots are left out of the
        # node's input list.
        slots = (layer.get_input(slot) for slot in range(layer.num_inputs))
        input_names = [tensor.name for tensor in slots if tensor is not None]
        output_names = [
            layer.get_output(slot).name for slot in range(layer.num_outputs)
        ]
        graph_nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=input_names,
                             outputs=output_names,
                             domain="com.nvidia"))

    graph = helper.make_graph(graph_nodes,
                              'attention',
                              graph_inputs,
                              graph_outputs,
                              initializer=None)
    onnx.save(helper.make_model(graph, producer_name='NVIDIA'), path)
|
|
|
|
|
|
def parse_arguments():
    """Define the script's command-line options and parse the command line.

    Returns:
        argparse.Namespace: Parsed options. Model hyper-parameters
        (``vocab_size``, ``n_layer``, ``n_positions``, ``n_embd``,
        ``n_head``, ``hidden_act``), build limits (``max_batch_size``,
        ``max_input_len``, ``max_output_len``), ``world_size``, ``dtype``,
        ``log_level``, the three ``use_*_plugin`` switches and
        ``output_dir``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--world_size',
                     type=int,
                     default=1,
                     help='world size, only support tensor parallelism now')

    # (flag, type, default) triples. They are registered in this exact order
    # so the generated ``--help`` listing stays the same.
    valued_options = (
        ('--dtype', str, 'float32'),
        ('--log_level', str, 'info'),
        ('--vocab_size', int, 51200),
        ('--n_layer', int, 24),
        ('--n_positions', int, 1024),
        ('--n_embd', int, 1024),
        ('--n_head', int, 16),
        ('--hidden_act', str, 'gelu'),
        ('--max_batch_size', int, 256),
        ('--max_input_len', int, 200),
        ('--max_output_len', int, 200),
    )
    for flag, kind, fallback in valued_options:
        cli.add_argument(flag, type=kind, default=fallback)

    # Boolean switches: absent -> False, present -> True.
    plugin_switches = (
        '--use_gpt_attention_plugin',
        '--use_gemm_plugin',
        '--use_layernorm_plugin',
    )
    for switch in plugin_switches:
        cli.add_argument(switch, default=False, action='store_true')

    cli.add_argument('--output_dir', type=str, default='gpt_outputs')
    return cli.parse_args()
|
|
|
|
|
|
def prepare_inputs(args):
    """Declare the input tensors that are passed to the GPT model.

    Reads ``n_embd``, ``n_head``, ``n_layer``, ``max_batch_size``,
    ``max_input_len``, ``max_output_len``, ``dtype`` and
    ``use_gpt_attention_plugin`` from *args*.

    Without the GPT attention plugin, each layer gets a separate
    ``past_key_{i}`` / ``past_value_{i}`` tensor pair and
    ``sequence_length`` / ``shape_tensor`` stay ``None``. With the plugin,
    each layer gets a single ``past_{i}`` tensor whose leading dimension is
    2, and ``sequence_length`` and ``shape_tensor`` are created as well.

    Returns:
        The 6-tuple ``(input_ids, None, past_key_value, sequence_length,
        shape_tensor, True)``. The caller unpacks it positionally into the
        model call; the meaning of the fixed ``None`` and ``True`` entries
        is defined by that call signature, which is not visible here.
    """

    def _low_mid_high(upper):
        # [1, half of the maximum (rounded up), maximum].
        # NOTE(review): presumably the (min, opt, max) values of a dynamic
        # dimension -- confirm against tensorrt_llm.Tensor's dim_range.
        return [1, (upper + 1) // 2, upper]

    per_head = args.n_embd // args.n_head
    total_len = args.max_input_len + args.max_output_len
    batch_span = _low_mid_high(args.max_batch_size)
    prompt_span = _low_mid_high(args.max_input_len)
    total_span = _low_mid_high(total_len)
    step_span = [1, 1, args.max_input_len + 1]

    # Each dim_range entry carries two values per named dimension.
    # NOTE(review): these look like one setting per optimization profile --
    # confirm.
    input_ids = tensorrt_llm.Tensor(
        name='input_ids',
        dtype=trt.int32,
        shape=[-1, -1],
        dim_range=OrderedDict([
            ('batch_size', [batch_span, batch_span]),
            ('input_len', [prompt_span, 1]),
        ]),
    )

    # Anything other than 'float16' falls back to float32.
    cache_dtype = trt.float32
    if args.dtype == 'float16':
        cache_dtype = trt.float16

    past_key_value = []
    sequence_length = None
    shape_tensor = None

    if args.use_gpt_attention_plugin:
        # Plugin path: one combined past tensor per layer.
        for layer_idx in range(args.n_layer):
            combined_dims = OrderedDict([
                ('2', [2, 2]),
                ('batch_size', [batch_span, batch_span]),
                ('num_heads', [args.n_head, args.n_head]),
                ('past_key_len', [total_span, total_span]),
                ('head_size', [per_head, per_head]),
            ])
            past_key_value.append(
                tensorrt_llm.Tensor(
                    name=f'past_{layer_idx}',
                    dtype=cache_dtype,
                    shape=[2, -1, args.n_head, -1, per_head],
                    dim_range=combined_dims,
                ))

        sequence_length = tensorrt_llm.Tensor(
            name='sequence_length',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([('batch_size', [batch_span, batch_span])]),
        )

        shape_tensor = tensorrt_llm.Tensor(
            name='shape_tensor',
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict([
                ('step', [step_span, total_span]),
                ('cur_seq_len', [0, total_span]),
            ]),
        )
    else:
        # Non-plugin path: separate key and value tensors per layer, both
        # described by the same dim_range object.
        for layer_idx in range(args.n_layer):
            split_dims = OrderedDict([
                ('batch_size', [batch_span, batch_span]),
                ('num_heads', [args.n_head, args.n_head]),
                ('past_key_len', [0, total_span]),
                ('head_size', [per_head, per_head]),
            ])
            key_cache = tensorrt_llm.Tensor(
                name=f'past_key_{layer_idx}',
                dtype=cache_dtype,
                shape=[-1, args.n_head, -1, per_head],
                dim_range=split_dims,
            )
            value_cache = tensorrt_llm.Tensor(
                name=f'past_value_{layer_idx}',
                dtype=cache_dtype,
                shape=[-1, args.n_head, -1, per_head],
                dim_range=split_dims,
            )
            past_key_value.append((key_cache, value_cache))
            # TODO(kaiyu): Remove these checks once TRT fixes named
            # dimensions.
            assertion(
                shape(input_ids, 0) == shape(key_cache, 0), 'batch size')
            assertion(
                shape(key_cache, 2) == shape(value_cache, 2), 'kv cache len')

    return (input_ids, None, past_key_value, sequence_length, shape_tensor,
            True)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the GPT network definition from the CLI
    # options and write its layer graph to <output_dir>/test.onnx.
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)
    tensorrt_llm.set_default_dtype(args.dtype)
    # exist_ok=True replaces the former exists()-then-makedirs() pair, which
    # could fail if the directory was created between the check and the call.
    os.makedirs(args.output_dir, exist_ok=True)

    # Anything other than 'float16' falls back to float32.
    kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32

    builder = Builder()

    # Initialize Module
    apply_query_key_layer_scaling = False
    tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        dtype=kv_dtype,
        tensor_parallel=args.world_size,  # TP only
        tensor_parallel_group=list(range(args.world_size)),  # TP only
        apply_query_key_layer_scaling=apply_query_key_layer_scaling)

    # Module -> Network
    network = builder.create_network()
    # Enable only the plugins requested on the command line.
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin()
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin()
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin()
    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_gpt.named_parameters())

        # Forward
        inputs = prepare_inputs(args)
        lm_logits, presents = tensorrt_llm_gpt(*inputs)

        # Mark outputs
        lm_logits.mark_output('logits', kv_dtype)
        if not args.use_gpt_attention_plugin:
            # Non-plugin path: each entry of presents is a (key, value) pair.
            for i, present in enumerate(presents):
                k, v = present
                k.mark_output(f'present_key_{i}', kv_dtype)
                v.mark_output(f'present_value_{i}', kv_dtype)
        else:
            # Plugin path: each entry of presents is a single tensor.
            for i, present in enumerate(presents):
                present.mark_output(f'present_{i}', kv_dtype)

    model_path = os.path.join(args.output_dir, 'test.onnx')
    to_onnx(network.trt_network, model_path)
|