TensorRT-LLMs/examples/gpt/visualize.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
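
"""Build a TensorRT-LLM GPT network and export its graph to ONNX.

Each TensorRT layer is emitted as a custom node in the "com.nvidia" ONNX
domain and no weights are embedded, so the resulting test.onnx captures the
network topology for inspection (e.g., in a graph viewer such as Netron)
rather than being a runnable ONNX model.
"""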
import argparse
import os
from collections import OrderedDict

import onnx
import tensorrt as trt
from onnx import TensorProto, helper

import tensorrt_llm
from tensorrt_llm.builder import Builder
from tensorrt_llm.functional import assertion, shape
from tensorrt_llm.network import net_guard
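

# Map the TensorRT dtypes used in this example (fp16 / fp32 / int32) to the
# corresponding ONNX TensorProto enum values.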
def trt_dtype_to_onnx(dtype):
    if dtype == trt.float16:
        return TensorProto.DataType.FLOAT16
    elif dtype == trt.float32:
        return TensorProto.DataType.FLOAT
    elif dtype == trt.int32:
        return TensorProto.DataType.INT32
    else:
        raise TypeError("%s is not supported" % dtype)
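

# Serialize a TensorRT INetworkDefinition as an ONNX graph: network inputs and
# outputs become value infos, and every layer becomes a "com.nvidia" custom
# node, preserving the topology for visualization.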
def to_onnx(network, path):
    inputs = []
    for i in range(network.num_inputs):
        network_input = network.get_input(i)
        inputs.append(
            helper.make_tensor_value_info(
                network_input.name, trt_dtype_to_onnx(network_input.dtype),
                list(network_input.shape)))

    outputs = []
    for i in range(network.num_outputs):
        network_output = network.get_output(i)
        outputs.append(
            helper.make_tensor_value_info(
                network_output.name, trt_dtype_to_onnx(network_output.dtype),
                list(network_output.shape)))

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        layer_inputs = []
        for j in range(layer.num_inputs):
            ipt = layer.get_input(j)
            if ipt is not None:
                layer_inputs.append(ipt.name)
        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    onnx_model = helper.make_model(helper.make_graph(nodes,
                                                     'attention',
                                                     inputs,
                                                     outputs,
                                                     initializer=None),
                                   producer_name='NVIDIA')
    onnx.save(onnx_model, path)
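

# CLI flags cover the GPT model geometry (n_layer, n_head, n_embd, ...), the
# maximum runtime shapes, and which TensorRT-LLM plugins to enable.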
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size; only tensor parallelism is '
                        'supported for now')
    parser.add_argument('--dtype', type=str, default='float32')
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=51200)
    parser.add_argument('--n_layer', type=int, default=24)
    parser.add_argument('--n_positions', type=int, default=1024)
    parser.add_argument('--n_embd', type=int, default=1024)
    parser.add_argument('--n_head', type=int, default=16)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument('--max_batch_size', type=int, default=256)
    parser.add_argument('--max_input_len', type=int, default=200)
    parser.add_argument('--max_output_len', type=int, default=200)
    parser.add_argument('--use_gpt_attention_plugin',
                        default=False,
                        action='store_true')
    parser.add_argument('--use_gemm_plugin', default=False, action='store_true')
    parser.add_argument('--use_layernorm_plugin',
                        default=False,
                        action='store_true')
    parser.add_argument('--output_dir', type=str, default='gpt_outputs')
    return parser.parse_args()
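

# Build the network input tensors (input_ids plus per-layer KV-cache inputs)
# with dynamic-shape ranges; the KV-cache layout depends on whether the GPT
# attention plugin is enabled.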
def prepare_inputs(args):
    # Prepare inputs
    head_size = args.n_embd // args.n_head
    max_len = args.max_input_len + args.max_output_len
    bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
    inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
    max_len_range = [1, (max_len + 1) // 2, max_len]
    step_range = [1, 1, args.max_input_len + 1]
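
    # Each *_range is a [min, opt, max] triple for one dynamic dimension.
    # Passing two ranges per dimension below appears to declare two TensorRT
    # optimization profiles (roughly, a context phase and a single-token
    # generation phase where input_len is pinned to 1).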
    input_ids = tensorrt_llm.Tensor(name='input_ids',
                                    dtype=trt.int32,
                                    shape=[-1, -1],
                                    dim_range=OrderedDict([
                                        ('batch_size', [bs_range, bs_range]),
                                        ('input_len', [inlen_range, 1]),
                                    ]))
    kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32
    past_key_value = []
    sequence_length = None
    shape_tensor = None

    if not args.use_gpt_attention_plugin:
        for i in range(args.n_layer):
            kv_dim_range = OrderedDict([
                ('batch_size', [bs_range, bs_range]),
                ('num_heads', [args.n_head, args.n_head]),
                ('past_key_len', [0, max_len_range]),
                ('head_size', [head_size, head_size]),
            ])
            k = tensorrt_llm.Tensor(name=f'past_key_{i}',
                                    dtype=kv_dtype,
                                    shape=[-1, args.n_head, -1, head_size],
                                    dim_range=kv_dim_range)
            v = tensorrt_llm.Tensor(name=f'past_value_{i}',
                                    dtype=kv_dtype,
                                    shape=[-1, args.n_head, -1, head_size],
                                    dim_range=kv_dim_range)
            past_key_value.append((k, v))

            # TODO(kaiyu): Remove this when TRT fixes the named dimension
            assertion(shape(input_ids, 0) == shape(k, 0), 'batch size')
            assertion(shape(k, 2) == shape(v, 2), 'kv cache len')
    else:
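        # With the attention plugin enabled, each layer's KV cache is a single
        # packed input of shape [2, batch, heads, seq, head_size]; the leading
        # "2" presumably stacks the key and value halves that the non-plugin
        # path exposes as separate past_key_/past_value_ tensors.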
        for i in range(args.n_layer):
            past_key_value.append(
                tensorrt_llm.Tensor(
                    name=f'past_{i}',
                    dtype=kv_dtype,
                    shape=[2, -1, args.n_head, -1, head_size],
                    dim_range=OrderedDict([
                        ('2', [2, 2]),
                        ('batch_size', [bs_range, bs_range]),
                        ('num_heads', [args.n_head, args.n_head]),
                        ('past_key_len', [max_len_range, max_len_range]),
                        ('head_size', [head_size, head_size]),
                    ]),
                ))
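
        # Additional inputs that only the GPT attention plugin path consumes.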
        sequence_length = tensorrt_llm.Tensor(
            name='sequence_length',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([('batch_size', [bs_range, bs_range])]),
        )
        shape_tensor = tensorrt_llm.Tensor(
            name='shape_tensor',
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict([('step', [step_range, max_len_range]),
                                   ('cur_seq_len', [0, max_len_range])]))
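
    # This tuple is unpacked directly into the GPTLMHeadModel call in
    # __main__; the None and True entries presumably fill optional forward()
    # arguments (position_ids and use_cache).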
    return (input_ids, None, past_key_value, sequence_length, shape_tensor,
            True)
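

# Example invocation (flags are illustrative; all have defaults):
#   python visualize.py --n_layer 2 --dtype float16 --use_gpt_attention_plugin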
if __name__ == '__main__':
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)
    tensorrt_llm.set_default_dtype(args.dtype)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32

    builder = Builder()

    # Initialize Module
    apply_query_key_layer_scaling = False
    tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        dtype=kv_dtype,
        tensor_parallel=args.world_size,  # TP only
        tensor_parallel_group=list(range(args.world_size)),  # TP only
        apply_query_key_layer_scaling=apply_query_key_layer_scaling)

    # Module -> Network
    network = builder.create_network()
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin()
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin()
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin()

    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_gpt.named_parameters())

        # Forward
        inputs = prepare_inputs(args)
        lm_logits, presents = tensorrt_llm_gpt(*inputs)

        # Mark outputs
        lm_logits.mark_output('logits', kv_dtype)
        if not args.use_gpt_attention_plugin:
            for i, present in enumerate(presents):
                k, v = present
                k.mark_output(f'present_key_{i}', kv_dtype)
                v.mark_output(f'present_value_{i}', kv_dtype)
        else:
            for i, present in enumerate(presents):
                present.mark_output(f'present_{i}', kv_dtype)

    model_path = os.path.join(args.output_dir, 'test.onnx')
    to_onnx(network.trt_network, model_path)