# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
from itertools import product

import numpy as np
import pytest
import tensorrt as trt
import torch
from parameterized import parameterized
from transformers import GPTJConfig, GPTJForCausalLM

import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.gptj.weight import load_from_hf_gpt_j

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion
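

# Compares TensorRT-LLM GPT-J (plugin and out-of-the-box paths) against the
# HuggingFace GPTJForCausalLM reference on context and generation logits.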
class TestGPTJ(unittest.TestCase):
def _gen_hf_gpt_j(self, hidden_act, n_layer, max_length, dtype):
gpt_config = GPTJConfig(activation_function=hidden_act,
n_layer=n_layer,
max_length=max_length,
torch_dtype=dtype,
n_embd=4096,
n_head=16,
rotary_dim=64)
hf_gpt = GPTJForCausalLM(gpt_config).cuda().to(
tensorrt_llm._utils.str_dtype_to_torch(dtype)).eval()
return gpt_config, hf_gpt
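
    # Builds the TensorRT-LLM GPT-J model inside `network` from the HuggingFace
    # config, copies over the HuggingFace weights, and traces the forward pass
    # to populate the graph.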
def _gen_tensorrt_llm_network(self, network, hf_gpt, gpt_config, batch_size,
beam_width, input_len, output_len, fp16,
tensor_parallel):
num_layers = gpt_config.n_layer
num_heads = gpt_config.n_head
hidden_size = gpt_config.n_embd
vocab_size = gpt_config.vocab_size
hidden_act = gpt_config.activation_function
n_positions = gpt_config.n_positions
rotary_dim = gpt_config.rotary_dim
with net_guard(network):
kv_dtype = trt.float16 if fp16 else trt.float32
# Initialize model
tensorrt_llm_gpt = tensorrt_llm.models.GPTJForCausalLM(
num_layers=num_layers,
num_heads=num_heads,
hidden_size=hidden_size,
vocab_size=vocab_size,
hidden_act=hidden_act,
max_position_embeddings=n_positions,
rotary_dim=rotary_dim,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=tensor_parallel,
tp_size=tensor_parallel),
)
inputs = tensorrt_llm_gpt.prepare_inputs(batch_size,
input_len,
output_len,
use_cache=True,
max_beam_width=beam_width)
load_from_hf_gpt_j(tensorrt_llm_gpt, hf_gpt, fp16=fp16)
            # Prepare the network: register named parameters and trace the forward pass
network.set_named_parameters(tensorrt_llm_gpt.named_parameters())
tensorrt_llm_gpt(*inputs)
return network
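
    # Creates a network with the requested plugins (GPT attention, GEMM and
    # LayerNorm, context FMHA, remove-input-padding), builds a serialized
    # engine, and wraps it in a generation _Runtime for the given mapping.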
def _gen_tensorrt_llm_runtime(self,
dtype,
world_size,
rank,
gpt_config,
hf_gpt,
use_attention_plugin,
batch_size,
beam_width,
input_len,
output_len,
use_refit,
use_ln_gemm_plugin,
context_fmha_flag=ContextFMHAType.disabled,
enable_remove_input_padding=False):
tensorrt_llm.logger.set_level('error')
mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
runtime = None
builder = Builder()
fp16 = (dtype == 'float16')
with tempfile.TemporaryDirectory() as tmpdirname:
network = builder.create_network()
if use_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(dtype)
if use_ln_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype)
network.plugin_config.set_layernorm_plugin(dtype)
if enable_remove_input_padding:
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_context_fmha(context_fmha_flag)
self._gen_tensorrt_llm_network(network, hf_gpt, gpt_config,
batch_size, beam_width, input_len,
output_len, fp16, world_size)
builder_config = builder.create_builder_config(
name='gptj',
precision=dtype,
timing_cache='model.cache',
tensor_parallel=world_size, # TP only
use_refit=use_refit,
)
engine_buffer = builder.build_engine(network, builder_config)
assert engine_buffer is not None
runtime = tensorrt_llm.runtime.generation._Runtime(
engine_buffer, mapping)
ok = builder.save_timing_cache(builder_config, 'model.cache')
assert ok, "Failed to save timing cache."
return runtime, engine_buffer
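
    # Test cases: every ContextFMHAType crossed with remove-input-padding on/off.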
def load_test_cases():
test_cases = list(
product([
ContextFMHAType.disabled, ContextFMHAType.enabled,
ContextFMHAType.enabled_with_fp32_acc
], [False, True]))
return test_cases
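
    # Plugin path: build the engine, run the context phase and one generation
    # step, and compare logits and KV cache against the HuggingFace reference.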
@parameterized.expand(load_test_cases)
def test_gptj_plugin(self, context_fmha_flag, enable_remove_input_padding):
        # Skip tests that are not supported on pre-Ampere architectures
if getSMVersion() < 80:
if context_fmha_flag == ContextFMHAType.enabled:
pytest.skip(
"ContextFMHAType is not supported in pre-ampere architecture"
)
elif context_fmha_flag == ContextFMHAType.enabled_with_fp32_acc:
pytest.skip(
"ContextFMHAType with fp32 acc is not supported in pre-ampere architecture"
)
torch.random.manual_seed(0)
use_refit = False
dtype = 'float16'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 2
max_length = 2
batch_size = 1
beam_width = 1
seq_len = 12
total_seq_len = max_length + seq_len
use_attention_plugin = True
use_ln_gemm_plugin = True
gpt_config, hf_gpt = self._gen_hf_gpt_j(hidden_act, n_layer,
seq_len + max_length, dtype)
runtime, _ = self._gen_tensorrt_llm_runtime(
dtype,
world_size,
rank,
gpt_config,
hf_gpt,
use_attention_plugin,
batch_size,
beam_width,
seq_len,
max_length,
use_refit,
use_ln_gemm_plugin,
context_fmha_flag,
enable_remove_input_padding=enable_remove_input_padding)
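
        # Pre-allocate one KV cache buffer per layer with layout
        # [batch, 2 (K and V), n_head, total_seq_len, head_size]; the engine
        # binds past_key_value and present_key_value to the same buffer, so the
        # cache is updated in place.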
key_value_cache_buffers = []
head_size = gpt_config.n_embd // gpt_config.n_head
for i in range(gpt_config.n_layer):
key_value_cache_buffers.append(
torch.zeros((
batch_size,
2,
gpt_config.n_head,
total_seq_len,
head_size,
),
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
device='cuda'))
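
        # Helper: bind all input/output buffers to a TRT execution context,
        # set the input shapes, run the engine, and return the logits.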
def run_engine(context,
input_ids,
context_lengths,
host_request_types,
position_ids,
last_token_ids,
cache_indirection,
host_past_key_value_lengths,
sequence_length,
host_context_lengths=None):
ctx_buffer = {
'input_ids': input_ids,
'context_lengths': context_lengths,
'host_request_types': host_request_types,
'position_ids': position_ids,
'last_token_ids': last_token_ids,
'cache_indirection': cache_indirection,
'host_past_key_value_lengths': host_past_key_value_lengths,
'sequence_length': sequence_length,
}
for i in range(gpt_config.n_layer):
ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i]
ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[
i]
if enable_remove_input_padding:
assert host_context_lengths is not None, "host_context_lengths is required for ragged input"
ctx_buffer['host_context_lengths'] = host_context_lengths
ctx_shape = {
key: buffer.shape
for key, buffer in ctx_buffer.items()
}
runtime._set_shape(context, ctx_shape)
runtime._set_buffer(context, ctx_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = ctx_buffer['logits']
return res
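
        # Cache indirection tables required by the attention inputs for beam
        # search; zero-filled here since beam_width == 1.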
cache_indirections = [
torch.full((
batch_size,
beam_width,
total_seq_len,
),
0,
dtype=torch.int32,
device='cuda'),
torch.full((
batch_size,
beam_width,
total_seq_len,
),
0,
dtype=torch.int32,
device='cuda')
] # ping-pong buffers
ctx_context_lengths = seq_len * torch.ones(
(batch_size), dtype=torch.int32, device='cuda')
        # sequence_lengths starts at context_lengths and is incremented by one
        # at each generation step.
sequence_length_buffer = ctx_context_lengths.detach().clone()
hf_outputs = None
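
        # Context phase: run a random prompt through HuggingFace and the engine,
        # then compare last-token logits and the per-layer KV cache.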
def compare_context():
ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
with torch.no_grad():
nonlocal hf_outputs
hf_outputs = hf_gpt.forward(ctx_ids, use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
ctx_position_ids = torch.tensor(range(seq_len),
dtype=torch.int32).reshape([
1, seq_len
]).expand([batch_size,
seq_len]).cuda()
ctx_last_token_ids = ctx_context_lengths.clone()
if enable_remove_input_padding:
ctx_ids = ctx_ids.view([1, batch_size * seq_len])
ctx_position_ids = ctx_position_ids.view(
[1, batch_size * seq_len])
ctx_last_token_ids = torch.cumsum(ctx_last_token_ids,
dim=0).int()
host_request_types = torch.tensor([0 for i in range(batch_size)],
dtype=torch.int32).cpu()
host_past_key_value_lengths = torch.tensor([0] * batch_size,
dtype=torch.int32)
host_context_lengths = ctx_context_lengths.cpu(
) if enable_remove_input_padding else None
res = run_engine(
context=runtime.ctx_context,
input_ids=ctx_ids,
context_lengths=ctx_context_lengths,
position_ids=ctx_position_ids,
last_token_ids=ctx_last_token_ids,
cache_indirection=cache_indirections[0],
host_past_key_value_lengths=host_past_key_value_lengths,
sequence_length=sequence_length_buffer,
host_context_lengths=host_context_lengths,
host_request_types=host_request_types)
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
v_inner = 16 // (2 if dtype == 'float16' else 4)
for i in range(gpt_config.n_layer):
res_present_key_value = key_value_cache_buffers[i]
past_key_value_tensor = res_present_key_value.permute(
1, 0, 2, 3, 4)
key, value = past_key_value_tensor.chunk(2)
# TRT-LLM has the same cache layout for key and value:
# [bs, n_head, max_seq_len, head_size]
head_size = gpt_config.n_embd // gpt_config.n_head
key = key.reshape(batch_size, gpt_config.n_head, total_seq_len,
head_size)
value = value.reshape(batch_size, gpt_config.n_head,
total_seq_len, head_size)
ref_present_key, ref_present_value = hf_outputs.past_key_values[
i]
np.testing.assert_allclose(ref_present_key.cpu().numpy(),
key[:, :, :seq_len, :].cpu().numpy(),
atol=1e-1)
np.testing.assert_allclose(
ref_present_value.cpu().numpy(),
value[:, :, :seq_len, :].cpu().numpy(),
atol=1e-1)
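
        # Generation phase: feed a single new token, reusing the HuggingFace
        # past_key_values as the reference, and compare the step-1 logits.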
def compare_generation():
step1_id = torch.randint(100, (batch_size, 1)).int().cuda()
gen_position_ids = torch.ones_like(step1_id).int().cuda() * seq_len
gen_context_lengths = ctx_context_lengths.clone()
gen_last_token_ids = torch.zeros_like(
gen_context_lengths).int().cuda()
with torch.no_grad():
nonlocal hf_outputs
hf_outputs = hf_gpt.forward(
step1_id,
past_key_values=hf_outputs.past_key_values,
position_ids=gen_position_ids,
use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
if enable_remove_input_padding:
step1_id = step1_id.view([1, batch_size])
gen_position_ids = gen_position_ids.view([1, batch_size])
gen_last_token_ids = torch.ones_like(
gen_context_lengths).int().cuda()
gen_last_token_ids = torch.cumsum(gen_last_token_ids,
dim=0).int()
host_past_key_value_lengths = torch.tensor([seq_len] * batch_size,
dtype=torch.int32)
host_request_types = torch.tensor([1] * batch_size,
dtype=torch.int32).cpu()
host_context_lengths = gen_context_lengths.cpu(
) if enable_remove_input_padding else None
# For step 1, the sequence_lengths = context_lengths + 1.
sequence_length_buffer = torch.add(ctx_context_lengths, 1)
res = run_engine(
context=runtime.context_1,
input_ids=step1_id,
                # Note: the context lengths (not the growing sequence lengths)
                # must be passed in the generation phase as well.
context_lengths=ctx_context_lengths,
position_ids=gen_position_ids,
last_token_ids=gen_last_token_ids,
cache_indirection=cache_indirections[1],
host_past_key_value_lengths=host_past_key_value_lengths,
sequence_length=sequence_length_buffer,
host_context_lengths=host_context_lengths,
host_request_types=host_request_types)
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
compare_context()
compare_generation()
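
    # Without the attention plugin (out-of-the-box attention) the build should
    # still succeed while the GEMM/LayerNorm plugins are enabled; disabling
    # those as well is expected to trip an assertion on TensorRT 8.6.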
def test_gptj_noplugin_supported(self):
use_refit = False
dtype = 'float16'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 1
max_length = 2
batch_size = 4
seq_len = 128
use_attention_plugin = False
use_ln_gemm_plugin = True
beam_width = 1
gpt_config, hf_gpt = self._gen_hf_gpt_j(hidden_act, n_layer,
seq_len + max_length, dtype)
runtime, _ = self._gen_tensorrt_llm_runtime(
dtype, world_size, rank, gpt_config, hf_gpt, use_attention_plugin,
batch_size, beam_width, seq_len, max_length, use_refit,
use_ln_gemm_plugin)
use_ln_gemm_plugin = False
if trt.__version__[:3] == '8.6':
with self.assertRaisesRegex(
AssertionError,
"You need to enable the LayerNorm plugin for GPT-J with TensorRT"
):
runtime, _ = self._gen_tensorrt_llm_runtime(
dtype, world_size, rank, gpt_config, hf_gpt,
use_attention_plugin, batch_size, beam_width, seq_len,
max_length, use_refit, use_ln_gemm_plugin)
if __name__ == '__main__':
unittest.main()