TensorRT-LLMs/tests/model/test_gptneox.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
from itertools import product

import numpy as np
import pytest
import tensorrt as trt
import torch
from parameterized import parameterized
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.gptneox.weight import load_from_hf_gpt_neox

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion


def compare_max_abs_error(ref, res, name):
    # Report the largest absolute difference between the HF reference and the TRT-LLM result.
    compare_HF = ref.cpu().numpy().flatten()
    compare_TRT_LLM = res.cpu().numpy().flatten()
    max_abs_error = np.max(np.abs(compare_TRT_LLM - compare_HF))
    print(name, "max abs error =", max_abs_error)


class TestGPTNeoX(unittest.TestCase):

    def _gen_hf_gpt_neox(self, hidden_act, n_layer, max_length, dtype):
gpt_config = GPTNeoXConfig(hidden_act=hidden_act,
num_hidden_layers=n_layer,
max_length=max_length,
torch_dtype=dtype,
hidden_size=4096,
intermediate_size=4096 * 4,
num_attention_heads=64,
rotary_pct=0.25)
hf_gpt = GPTNeoXForCausalLM(gpt_config).cuda().to(
tensorrt_llm._utils.str_dtype_to_torch(dtype)).eval()
return gpt_config, hf_gpt

    def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config,
batch_size, beam_width, input_len, output_len,
fp16, gpt_attention_plugin, rank,
tensor_parallel,
apply_query_key_layer_scaling):
num_layers = gpt_config.num_hidden_layers
num_heads = gpt_config.num_attention_heads
hidden_size = gpt_config.hidden_size
vocab_size = gpt_config.vocab_size
hidden_act = gpt_config.hidden_act
max_position_embeddings = gpt_config.max_position_embeddings
rotary_dim = int((hidden_size // num_heads) * gpt_config.rotary_pct)
with net_guard(network):
kv_dtype = trt.float16 if fp16 else trt.float32
# Initialize model
tensorrt_llm_gpt = tensorrt_llm.models.GPTNeoXForCausalLM(
num_layers=num_layers,
num_heads=num_heads,
hidden_size=hidden_size,
vocab_size=vocab_size,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
rotary_dim=rotary_dim,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=tensor_parallel,
tp_size=tensor_parallel),
apply_query_key_layer_scaling=apply_query_key_layer_scaling)
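            # Declare the network inputs (token ids, lengths, KV-cache tensors, ...)
            # that are fed to the forward pass below.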
inputs = tensorrt_llm_gpt.prepare_inputs(batch_size,
input_len,
output_len,
use_cache=True,
max_beam_width=beam_width)
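            # Copy the HF checkpoint weights into the TRT-LLM model,
            # sharded by rank when tensor parallelism is enabled.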
load_from_hf_gpt_neox(tensorrt_llm_gpt,
hf_gpt,
fp16=fp16,
rank=rank,
tp_size=tensor_parallel)
            # Register the named parameters, then trace a forward pass to populate the network graph.
network.set_named_parameters(tensorrt_llm_gpt.named_parameters())
tensorrt_llm_gpt(*inputs)
return network

    def _gen_tensorrt_llm_runtime(self,
log_level,
dtype,
world_size,
rank,
gpt_config,
hf_gpt,
model,
use_attention_plugin,
batch_size,
beam_width,
input_len,
output_len,
use_refit,
use_ln_gemm_plugin,
apply_query_key_layer_scaling,
context_fmha_flag=ContextFMHAType.disabled,
enable_remove_input_padding=False):
tensorrt_llm.logger.set_level('error')
mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
runtime = None
builder = Builder()
fp16 = (dtype == 'float16')
with tempfile.TemporaryDirectory() as tmpdirname:
network = builder.create_network()
if use_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(dtype)
if use_ln_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype)
network.plugin_config.set_layernorm_plugin(dtype)
if enable_remove_input_padding:
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_context_fmha(context_fmha_flag)
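            # Populate the network definition from the HF config and weights.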
self._gen_tensorrt_llm_network(network, builder, hf_gpt, gpt_config,
batch_size, beam_width, input_len,
output_len, fp16,
use_attention_plugin, rank,
world_size,
apply_query_key_layer_scaling)
builder_config = builder.create_builder_config(
name='gptneox',
precision=dtype,
timing_cache='model.cache',
tensor_parallel=world_size, # TP only
use_refit=use_refit,
)
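            # Build the serialized TensorRT engine and wrap it in a runtime for this rank.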
engine_buffer = builder.build_engine(network, builder_config)
runtime = tensorrt_llm.runtime.generation._Runtime(
engine_buffer, mapping)
ok = builder.save_timing_cache(builder_config, 'model.cache')
assert ok, "Failed to save timing cache."
return runtime, engine_buffer

    def load_test_cases():
test_cases = product([
ContextFMHAType.disabled, ContextFMHAType.enabled,
ContextFMHAType.enabled_with_fp32_acc
], [False, True])
return test_cases

    @parameterized.expand(load_test_cases)
def test_gptneox_plugin(self, context_fmha_flag,
enable_remove_input_padding):
        # Skip tests that are not supported on pre-Ampere architectures.
if getSMVersion() < 80:
if context_fmha_flag == ContextFMHAType.enabled:
pytest.skip(
"ContextFMHAType is not supported in pre-ampere architecture"
)
elif context_fmha_flag == ContextFMHAType.enabled_with_fp32_acc:
pytest.skip(
"ContextFMHAType with fp32 acc is not supported in pre-ampere architecture"
)
torch.random.manual_seed(0)
use_refit = False
apply_query_key_layer_scaling = False
model = 'gptneox'
log_level = 'error'
dtype = 'float16'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 6
max_length = 128
batch_size = 1
beam_width = 1
seq_len = 128
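        # The KV cache must hold the prompt tokens plus every generated token.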
total_seq_len = max_length + seq_len
use_attention_plugin = True
use_ln_gemm_plugin = True
gpt_config, hf_gpt = self._gen_hf_gpt_neox(hidden_act, n_layer,
seq_len + max_length, dtype)
runtime, _ = self._gen_tensorrt_llm_runtime(
log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
use_attention_plugin, batch_size, beam_width, seq_len, max_length,
use_refit, use_ln_gemm_plugin, apply_query_key_layer_scaling,
context_fmha_flag, enable_remove_input_padding)
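        # Pre-allocate one KV-cache buffer per layer with layout
        # [batch, 2 (key/value), num_heads, total_seq_len, head_size].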
key_value_cache_buffers = []
head_size = gpt_config.hidden_size // gpt_config.num_attention_heads
for i in range(gpt_config.num_hidden_layers):
key_value_cache_buffers.append(
torch.zeros((
batch_size,
2,
gpt_config.num_attention_heads,
total_seq_len,
head_size,
),
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
device='cuda'))
        # Context (prefill) phase: run the full prompt through HF and TRT-LLM and compare logits.
step = 0
ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
with torch.no_grad():
hf_outputs = hf_gpt.forward(ctx_ids, use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
ctx_context_lengths = seq_len * torch.ones(
(batch_size), dtype=torch.int32, device='cuda')
ctx_host_request_types = torch.tensor([0] * batch_size,
dtype=torch.int32)
ctx_position_ids = torch.tensor(range(seq_len),
dtype=torch.int32).reshape([
1, seq_len
]).expand([batch_size, seq_len]).cuda()
ctx_last_token_ids = ctx_context_lengths.clone()
        # sequence_lengths starts at context_lengths for step 0 and is
        # incremented by one after every generation step.
sequence_length_buffer = ctx_context_lengths.detach().clone()
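        # With remove_input_padding, the tokens of all sequences are packed into a
        # single row and last_token_ids holds cumulative token offsets into that row.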
if enable_remove_input_padding:
ctx_ids = ctx_ids.view([1, batch_size * seq_len])
ctx_position_ids = ctx_position_ids.view([1, batch_size * seq_len])
ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int()
cache_indirections = [
torch.full((
batch_size,
beam_width,
total_seq_len,
),
0,
dtype=torch.int32,
device='cuda'),
torch.full((
batch_size,
beam_width,
total_seq_len,
),
0,
dtype=torch.int32,
device='cuda')
] # ping-pong buffers
ctx_buffer = {
'input_ids': ctx_ids,
'context_lengths': ctx_context_lengths,
'host_request_types': ctx_host_request_types,
'position_ids': ctx_position_ids,
'last_token_ids': ctx_last_token_ids,
'cache_indirection': cache_indirections[0],
}
if enable_remove_input_padding:
ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu()
ctx_shape = {k: v.shape for k, v in ctx_buffer.items()}
shape = (batch_size, 2, gpt_config.num_attention_heads, total_seq_len,
gpt_config.hidden_size // gpt_config.num_attention_heads)
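        # Bind the same buffer as both past and present KV cache so it is updated in place.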
for i in range(gpt_config.num_hidden_layers):
ctx_shape[f'past_key_value_{i}'] = shape
ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i]
ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i]
sequence_length_buffer = torch.add(sequence_length_buffer, step)
ctx_buffer['sequence_length'] = sequence_length_buffer
ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape
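        # Nothing is cached before the context phase, so the past KV lengths start at zero.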
ctx_buffer['host_past_key_value_lengths'] = torch.tensor(
[0] * batch_size, dtype=torch.int32)
ctx_shape['host_past_key_value_lengths'] = ctx_buffer[
'host_past_key_value_lengths'].shape
context = runtime.ctx_context
runtime._set_shape(context, ctx_shape)
runtime._set_buffer(context, ctx_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = ctx_buffer['logits']
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
compare_max_abs_error(ref, res, "context logits")
for i in range(gpt_config.num_hidden_layers):
res_present_key_value = ctx_buffer[f'present_key_value_{i}']
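            # Move the key/value axis to the front so chunk(2) splits the buffer
            # into separate key and value tensors.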
past_key_value_tensor = res_present_key_value.permute(1, 0, 2, 3, 4)
key, value = past_key_value_tensor.chunk(2)
# TRT-LLM has the same cache layout for key and value:
# [bs, n_head, max_seq_len, head_size]
key = key.reshape(batch_size, gpt_config.num_attention_heads,
total_seq_len, head_size)
value = value.reshape(batch_size, gpt_config.num_attention_heads,
total_seq_len, head_size)
ref_present_key, ref_present_value = hf_outputs.past_key_values[i]
np.testing.assert_allclose(ref_present_key.cpu().numpy(),
key[:, :, :seq_len, :].cpu().numpy(),
atol=1e-1)
compare_max_abs_error(ref_present_key, key[:, :, :seq_len, :],
"ref_present_key")
np.testing.assert_allclose(ref_present_value.cpu().numpy(),
value[:, :, :seq_len, :].cpu().numpy(),
atol=1e-1)
compare_max_abs_error(ref_present_value, value[:, :, :seq_len, :],
"ref_present_value")
        # Generation (decode) phase: feed a single new token against the cached
        # keys/values and compare logits.
step = 1
step1_id = torch.randint(100, (batch_size, 1)).int().cuda()
gen_position_ids = torch.ones_like(step1_id).int().cuda() * seq_len
gen_context_lengths = ctx_context_lengths.clone()
gen_host_request_types = torch.tensor([1] * batch_size,
dtype=torch.int32)
gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda()
with torch.no_grad():
hf_outputs = hf_gpt.forward(
step1_id,
past_key_values=hf_outputs.past_key_values,
position_ids=gen_position_ids,
use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
if enable_remove_input_padding:
step1_id = step1_id.view([1, batch_size])
gen_position_ids = gen_position_ids.view([1, batch_size])
gen_last_token_ids = torch.ones_like(
gen_context_lengths).int().cuda()
gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int()
step1_buffer = {
'input_ids': step1_id,
'context_lengths': gen_context_lengths,
'host_request_types': gen_host_request_types,
'position_ids': gen_position_ids,
'last_token_ids': gen_last_token_ids,
'cache_indirection': cache_indirections[1],
}
if enable_remove_input_padding:
step1_buffer['host_context_lengths'] = gen_context_lengths.cpu()
step1_shape = {k: v.shape for k, v in step1_buffer.items()}
for i in range(gpt_config.num_hidden_layers):
step1_shape[f'past_key_value_{i}'] = shape
step1_shape['sequence_length'] = (batch_size, )
step1_shape['host_past_key_value_lengths'] = (batch_size, )
for i in range(gpt_config.num_hidden_layers):
step1_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i]
step1_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i]
# For step 1, the sequence_lengths = context_lengths + 1.
sequence_length_buffer = torch.add(sequence_length_buffer, step)
step1_buffer['sequence_length'] = sequence_length_buffer
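        # After the context phase the cache already holds seq_len tokens per sequence.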
step1_buffer['host_past_key_value_lengths'] = torch.tensor(
[seq_len + step - 1] * batch_size, dtype=torch.int32)
context = runtime.context_1
runtime._set_shape(context, step1_shape)
runtime._set_buffer(context, step1_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = step1_buffer['logits']
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
compare_max_abs_error(ref, res, "generation logits")


if __name__ == '__main__':
    unittest.main()