TensorRT-LLMs/tests/model/test_gpt.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
import sys
import tempfile
import unittest
from itertools import product
import numpy as np
import pytest
import tensorrt as trt
import torch
from parameterized import parameterized
from transformers import GPT2Config, GPT2LMHeadModel
import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.generation import _prepare_attention_mask
from tensorrt_llm.runtime.kv_cache_manager import (GenerationSequence,
KVCacheManager)
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.gpt.weight import load_from_hf_gpt
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion
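# These tests build small GPT engines with TensorRT-LLM and compare their
# context-phase and generation-phase outputs against HuggingFace GPT2LMHeadModel.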
class TestGPT(unittest.TestCase):
def _gen_hf_gpt(self, hidden_act, n_layer, max_length, dtype):
gpt_config = GPT2Config(
activation_function=hidden_act,
n_layer=n_layer,
max_length=max_length,
torch_dtype=dtype,
)
hf_gpt = GPT2LMHeadModel(gpt_config).cuda().eval()
return gpt_config, hf_gpt
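# Builds the TensorRT-LLM GPT network inside `network`: instantiates
# GPTLMHeadModel, loads the HuggingFace weights via load_from_hf_gpt, and
# traces the forward pass to populate the graph.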
def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config,
batch_size, input_len, output_len, fp16,
gpt_attention_plugin, tensor_parallel,
apply_query_key_layer_scaling,
gather_all_token_logits):
num_layers = gpt_config.n_layer
num_heads = gpt_config.n_head
hidden_size = gpt_config.n_embd
vocab_size = gpt_config.vocab_size
hidden_act = gpt_config.activation_function
n_positions = gpt_config.n_positions
tensor_parallel_group = list(range(tensor_parallel))
with net_guard(network):
kv_dtype = trt.float16 if fp16 else trt.float32
# Initialize model
tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel(
num_layers=num_layers,
num_heads=num_heads,
hidden_size=hidden_size,
vocab_size=vocab_size,
hidden_act=hidden_act,
max_position_embeddings=n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=tensor_parallel,
tp_size=tensor_parallel),
apply_query_key_layer_scaling=apply_query_key_layer_scaling)
inputs = tensorrt_llm_gpt.prepare_inputs(
batch_size,
input_len,
output_len,
use_cache=True,
max_beam_width=1,
gather_all_token_logits=gather_all_token_logits)
load_from_hf_gpt(tensorrt_llm_gpt,
hf_gpt,
dtype="float16" if fp16 else "float32")
# Register the parameters with the network and trace the forward pass
network.set_named_parameters(tensorrt_llm_gpt.named_parameters())
tensorrt_llm_gpt(*inputs)
return network
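# Creates a network with the requested plugins enabled, builds a TensorRT
# engine from it, and returns a low-level _Runtime plus the serialized engine buffer.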
def _gen_tensorrt_llm_runtime(self,
log_level,
dtype,
world_size,
rank,
gpt_config,
hf_gpt,
model,
use_plugin,
batch_size,
input_len,
output_len,
use_refit,
fast_building=False,
apply_query_key_layer_scaling=False,
context_fmha_type=ContextFMHAType.disabled,
enable_remove_input_padding=False,
enable_paged_kv_cache=False,
tokens_per_block=64,
gather_all_token_logits=False):
mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
runtime = None
builder = Builder()
fp16 = (dtype == 'float16')
with tempfile.TemporaryDirectory() as tmpdirname:
network = builder.create_network()
if use_plugin:
network.plugin_config.set_gpt_attention_plugin(dtype)
if fast_building:
network.plugin_config.set_gemm_plugin(dtype)
network.plugin_config.set_layernorm_plugin(dtype)
network.plugin_config.set_context_fmha(context_fmha_type)
if enable_remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if enable_paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(tokens_per_block)
self._gen_tensorrt_llm_network(network, builder, hf_gpt, gpt_config,
batch_size, input_len, output_len,
fp16, use_plugin, world_size,
apply_query_key_layer_scaling,
gather_all_token_logits)
builder_config = builder.create_builder_config(
name='gpt',
precision=dtype,
timing_cache='model.cache',
tensor_parallel=world_size, # TP only
use_refit=use_refit,
gather_all_token_logits=gather_all_token_logits,
)
engine_buffer = builder.build_engine(network, builder_config)
runtime = tensorrt_llm.runtime.generation._Runtime(
engine_buffer, mapping)
return runtime, engine_buffer
@parameterized.expand([(False, )])
def test_gpt_float32(self, use_refit):
model = 'gpt'
log_level = 'error'
dtype = 'float32'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 2
max_length = 2
batch_size = 4
beam_width = 1
seq_len = 128
total_length = seq_len + max_length
use_plugin = False
gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer, max_length,
dtype)
runtime, _ = self._gen_tensorrt_llm_runtime(
log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
use_plugin, batch_size, seq_len, max_length, use_refit)
# compare context
pad_token_id = 50256
ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
ctx_ids[0][-1] = pad_token_id
ctx_ids[1][-3:] = pad_token_id
ctx_ids[2][-5:] = pad_token_id
ctx_context_lengths = seq_len * torch.ones(
(batch_size), dtype=torch.int32, device='cuda')
ctx_host_context_lengths = ctx_context_lengths.cpu()
ctx_host_request_types = torch.tensor([0] * batch_size,
dtype=torch.int32,
device='cpu')
ctx_position_ids = torch.tensor(range(seq_len),
dtype=torch.int32).reshape([
1, seq_len
]).expand([batch_size, seq_len]).cuda()
ctx_last_token_ids = ctx_context_lengths.clone()
ctx_attention_mask = _prepare_attention_mask(ctx_ids)
cache_indirections = [
torch.full((
batch_size,
beam_width,
total_length,
),
0,
dtype=torch.int32,
device='cuda'),
torch.full((
batch_size,
beam_width,
total_length,
),
0,
dtype=torch.int32,
device='cuda')
] # ping-pong buffers
ctx_shape = {
'input_ids': ctx_ids.shape,
'position_ids': ctx_position_ids.shape,
'context_lengths': ctx_context_lengths.shape,
'host_context_lengths': ctx_host_context_lengths.shape,
'last_token_ids': ctx_last_token_ids.shape,
'attention_mask': ctx_attention_mask.shape,
'host_request_types': ctx_host_request_types.shape,
'cache_indirection': cache_indirections[0].shape,
}
ctx_buffer = {
'input_ids': ctx_ids,
'position_ids': ctx_position_ids,
'context_lengths': ctx_context_lengths,
'host_context_lengths': ctx_host_context_lengths,
'last_token_ids': ctx_last_token_ids,
'attention_mask': ctx_attention_mask,
'host_request_types': ctx_host_request_types,
'cache_indirection': cache_indirections[0],
}
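# For the context step the past KV cache is empty: the input shape declares a
# sequence length of 0 and a one-element dummy buffer is bound, while the
# present KV output has shape (batch_size, 2, n_head, seq_len, head_size).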
for i in range(gpt_config.n_layer):
shape = (batch_size, 2, gpt_config.n_head, 0,
gpt_config.n_embd // gpt_config.n_head)
past_buffer = torch.zeros((1, ),
dtype=str_dtype_to_torch(dtype),
device='cuda')
ctx_shape.update({
f'past_key_value_{i}': shape,
})
shape = (batch_size, 2, gpt_config.n_head, seq_len,
gpt_config.n_embd // gpt_config.n_head)
ctx_buffer.update({
f'past_key_value_{i}':
past_buffer,
f'present_key_value_{i}':
torch.zeros(shape,
dtype=str_dtype_to_torch(dtype),
device='cuda'),
})
context = runtime.ctx_context
runtime._set_shape(context, ctx_shape)
runtime._set_buffer(context, ctx_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = ctx_buffer['logits']
with torch.no_grad():
hf_outputs = hf_gpt.forward(ctx_ids,
attention_mask=ctx_attention_mask)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-2)
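# The engine returns K and V packed in one tensor of shape
# (batch_size, 2, n_head, seq_len, head_size); permute and chunk it into
# key/value halves before comparing with the HuggingFace past_key_values.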
for i in range(gpt_config.n_layer):
res_present_key_value = ctx_buffer[f'present_key_value_{i}']
ref_present_key, ref_present_value = hf_outputs.past_key_values[i]
past_key_value_tensor = res_present_key_value.permute(1, 0, 2, 3, 4)
key, value = past_key_value_tensor.chunk(2)
head_size = gpt_config.n_embd // gpt_config.n_head
key = key.to(torch.float32).reshape(batch_size, gpt_config.n_head,
seq_len, head_size)
value = value.reshape(batch_size, gpt_config.n_head, seq_len,
head_size)
np.testing.assert_allclose(ref_present_key.cpu().numpy(),
key.cpu().numpy(),
atol=1e-2)
np.testing.assert_allclose(ref_present_value.cpu().numpy(),
value.cpu().numpy(),
atol=1e-2)
# compare generation
gen_id = torch.randint(100, (batch_size, 1)).int().cuda()
gen_context_lengths = ctx_context_lengths.clone()
gen_host_context_lengths = ctx_host_context_lengths.clone()
gen_host_request_types = torch.tensor([1] * batch_size,
dtype=torch.int32,
device='cpu')
gen_position_ids = torch.ones_like(gen_id).cuda() * seq_len
gen_last_token_ids = torch.zeros_like(gen_context_lengths).cuda()
gen_attention_mask = torch.cat([
ctx_attention_mask,
ctx_attention_mask.new_ones((ctx_attention_mask.shape[0], 1))
],
dim=-1)
step1_shape = {
'input_ids': gen_id.shape,
'context_lengths': gen_context_lengths.shape,
'host_context_lengths': gen_host_context_lengths.shape,
'host_request_types': gen_host_request_types.shape,
'position_ids': gen_position_ids.shape,
'last_token_ids': gen_last_token_ids.shape,
'attention_mask': gen_attention_mask.shape,
'cache_indirection': cache_indirections[1].shape,
}
step1_buffer = {
'input_ids': gen_id,
'context_lengths': gen_context_lengths.contiguous(),
'host_context_lengths': gen_host_context_lengths.contiguous(),
'host_request_types': gen_host_request_types.contiguous(),
'position_ids': gen_position_ids.contiguous(),
'last_token_ids': gen_last_token_ids.contiguous(),
'attention_mask': gen_attention_mask.contiguous(),
'cache_indirection': cache_indirections[1].contiguous(),
}
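# The present KV cache produced by the context step is fed back as the past
# KV cache for the first generation step.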
for i in range(gpt_config.n_layer):
shape = (batch_size, 2, gpt_config.n_head, seq_len,
gpt_config.n_embd // gpt_config.n_head)
step1_shape.update({
f'past_key_value_{i}': shape,
})
step1_buffer.update({
f'past_key_value_{i}':
ctx_buffer[f'present_key_value_{i}'],
})
context = runtime.context_1
runtime._set_shape(context, step1_shape)
runtime._set_buffer(context, step1_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = step1_buffer['logits']
with torch.no_grad():
hf_outputs = hf_gpt.forward(
gen_id,
attention_mask=gen_attention_mask,
past_key_values=hf_outputs.past_key_values,
use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
np.testing.assert_allclose(ref.cpu().numpy(),
res.cpu().numpy(),
atol=1e-2)
for i in range(gpt_config.n_layer):
res_present_key_value = step1_buffer[f'present_key_value_{i}']
ref_present_key, ref_present_value = hf_outputs.past_key_values[i]
past_key_value_tensor = res_present_key_value.permute(1, 0, 2, 3, 4)
key, value = past_key_value_tensor.chunk(2)
head_size = gpt_config.n_embd // gpt_config.n_head
key = key.reshape(batch_size, gpt_config.n_head, seq_len + 1,
head_size)
value = value.reshape(batch_size, gpt_config.n_head, seq_len + 1,
head_size)
np.testing.assert_allclose(ref_present_key.cpu().numpy(),
key.cpu().numpy(),
atol=1e-2)
np.testing.assert_allclose(ref_present_value.cpu().numpy(),
value.cpu().numpy(),
atol=1e-2)
def load_test_cases():
test_cases = list(
product([False, True], [False, True], [False, True], [
ContextFMHAType.disabled, ContextFMHAType.enabled,
ContextFMHAType.enabled_with_fp32_acc
], [False, True], [False, True], [False, True]))
return test_cases
@parameterized.expand(load_test_cases)
def test_gpt_plugin(self, use_refit, fast_building,
apply_query_key_layer_scaling, context_fmha_type,
enable_remove_input_padding, enable_paged_kv_cache,
gather_all_token_logits):
# In-flight batching only works with remove_input_padding and paged_kv_cache enabled, and not with gather_all_token_logits
use_in_flight_batching = enable_remove_input_padding and enable_paged_kv_cache and not gather_all_token_logits
# Skip tests that are not supported on pre-Ampere architectures
if getSMVersion() < 80:
if context_fmha_type == ContextFMHAType.enabled:
pytest.skip(
"ContextFMHAType is not supported in pre-ampere architecture"
)
elif context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
pytest.skip(
"ContextFMHAType with fp32 acc is not supported in pre-ampere architecture"
)
torch.manual_seed(0)
random.seed(0)
model = 'gpt'
log_level = 'error'
dtype = 'float16'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 1
max_length = 2
batch_size = 4
beam_width = 1
seq_len = 128
total_length = seq_len + max_length
use_plugin = True
tokens_per_block = 64
gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer,
seq_len + max_length, dtype)
runtime, _ = self._gen_tensorrt_llm_runtime(
log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
use_plugin, batch_size, seq_len, max_length, use_refit,
fast_building, apply_query_key_layer_scaling, context_fmha_type,
enable_remove_input_padding, enable_paged_kv_cache,
tokens_per_block, gather_all_token_logits)
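# Allocate KV cache buffers: block layout
# (blocks, 2, n_head, tokens_per_block, head_size) when the paged KV cache is
# enabled, otherwise contiguous (batch_size, 2, n_head, total_length, head_size).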
key_value_cache_buffers = []
value_cache_buffers = []
head_size = gpt_config.n_embd // gpt_config.n_head
for i in range(gpt_config.n_layer):
if enable_paged_kv_cache:
blocks = batch_size * beam_width * math.ceil(
total_length / tokens_per_block)
cache_shape = (
blocks,
2,
gpt_config.n_head,
tokens_per_block,
head_size,
)
else:
cache_shape = (
batch_size,
2,
gpt_config.n_head,
total_length,
head_size,
)
key_value_cache_buffers.append(
torch.zeros(cache_shape,
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
device='cuda'))
value_cache_buffers.append(
torch.zeros((
batch_size,
gpt_config.n_head,
total_length,
head_size,
),
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
device='cuda'))
cache_indirections = [
torch.full((
batch_size,
beam_width,
total_length,
),
0,
dtype=torch.int32,
device='cuda'),
torch.full((
batch_size,
beam_width,
total_length,
),
0,
dtype=torch.int32,
device='cuda')
] # ping-pong buffers
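# With the paged KV cache, a KVCacheManager owns the blocks and tracks which
# blocks belong to each sequence; every batch entry is registered with its
# initial context length.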
if enable_paged_kv_cache:
max_blocks_per_seq = math.ceil(total_length / tokens_per_block)
blocks = batch_size * beam_width * max_blocks_per_seq
kv_cache_manager = KVCacheManager(key_value_cache_buffers, blocks,
tokens_per_block,
max_blocks_per_seq, beam_width)
# Add sequences to the manager
for bi in range(batch_size):
generation_sequence = GenerationSequence(seq_idx=bi,
batch_idx=bi)
kv_cache_manager.add_sequence(generation_sequence, seq_len)
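# Helper that binds all plugin inputs (including the paged or contiguous KV
# cache), sets the input shapes, runs the given execution context, and
# returns the logits.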
def run_engine(context,
input_ids,
context_lengths,
host_request_types,
position_ids,
last_token_ids,
cache_indirection,
host_past_key_value_lengths,
sequence_length=None,
host_context_lengths=None):
ctx_buffer = {
'input_ids': input_ids,
'context_lengths': context_lengths,
'host_request_types': host_request_types,
'position_ids': position_ids,
'last_token_ids': last_token_ids,
'cache_indirection': cache_indirection,
'host_past_key_value_lengths': host_past_key_value_lengths,
'sequence_length': sequence_length,
}
assert host_request_types is not None
if enable_remove_input_padding:
assert host_context_lengths is not None, "host_context_lengths is required for ragged input"
ctx_buffer['host_context_lengths'] = host_context_lengths
if enable_paged_kv_cache:
assert beam_width == 1
# for beam_width > 1 the argument must be '1' in ctx phase and 'beam_width' in gen phase
kv_cache_block_pointers = kv_cache_manager.get_pointer_arrays(1)
for idx in range(gpt_config.n_layer):
shape = kv_cache_block_pointers[idx].shape
shape = [shape[0] * shape[1], *shape[2:]]
ctx_buffer[
f'kv_cache_block_pointers_{idx}'] = kv_cache_block_pointers[
idx].reshape(shape).contiguous()
else:
for i in range(gpt_config.n_layer):
ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[
i]
ctx_buffer[
f'present_key_value_{i}'] = key_value_cache_buffers[i]
ctx_shape = {
key: buffer.shape
for key, buffer in ctx_buffer.items()
}
runtime._set_shape(context, ctx_shape)
runtime._set_buffer(context, ctx_buffer)
runtime._run(context)
torch.cuda.synchronize()
res = ctx_buffer['logits']
return res
hf_outputs = None
step0_ids = None
step1_ids = None
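# Runs the context (prefill) phase and compares the logits with HuggingFace;
# with remove_input_padding the ids and position ids are flattened to shape
# [1, batch_size * seq_len].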
def compare_context(run_ref_only=False):
nonlocal step0_ids
step0_ids = torch.randint(
100, (batch_size,
seq_len)).int().cuda() if step0_ids is None else step0_ids
ctx_ids = step0_ids.clone()
ctx_context_lengths = seq_len * torch.ones(
(batch_size), dtype=torch.int32, device='cuda')
ctx_position_ids = torch.tensor(range(seq_len),
dtype=torch.int32).reshape([
1, seq_len
]).expand([batch_size,
seq_len]).cuda()
ctx_last_token_ids = ctx_context_lengths.clone()
nonlocal hf_outputs
with torch.no_grad():
hf_outputs = hf_gpt.forward(ctx_ids)
torch.cuda.synchronize()
ref = hf_outputs.logits
if run_ref_only:
return ref[:, -1, :]
if enable_remove_input_padding:
ctx_ids = ctx_ids.view([1, batch_size * seq_len])
ctx_position_ids = ctx_position_ids.view(
[1, batch_size * seq_len])
ctx_last_token_ids = torch.cumsum(ctx_last_token_ids,
dim=0).int()
host_past_key_value_lengths = torch.tensor([0] * batch_size,
dtype=torch.int32)
host_context_lengths = ctx_context_lengths.cpu(
) if enable_remove_input_padding else None
host_request_types = torch.tensor([0 for i in range(batch_size)],
dtype=torch.int32).cpu()
# sequence_length starts as context_lengths for step 0 (context) and is
# incremented by one after each generation step.
sequence_length = ctx_context_lengths.detach().clone()
res = run_engine(
context=runtime.ctx_context,
input_ids=ctx_ids,
context_lengths=ctx_context_lengths,
position_ids=ctx_position_ids,
last_token_ids=ctx_last_token_ids,
cache_indirection=cache_indirections[0],
host_past_key_value_lengths=host_past_key_value_lengths,
sequence_length=sequence_length,
host_context_lengths=host_context_lengths,
host_request_types=host_request_types)
if gather_all_token_logits:
np.testing.assert_allclose(ref.cpu().numpy().flatten(),
res.cpu().numpy().flatten(),
atol=1e-1)
else:
np.testing.assert_allclose(ref[:, -1, :].cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
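# Runs a single generation step on top of the context step's KV cache and
# compares the next-token logits with HuggingFace (using past_key_values).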
def compare_generation(run_ref_only=False):
step = 1
nonlocal step1_ids
step1_ids = torch.randint(
100, (batch_size,
1)).int().cuda() if step1_ids is None else step1_ids
gen_ids = step1_ids.clone()
gen_context_lengths = seq_len * torch.ones(
(batch_size), dtype=torch.int32, device='cuda')
gen_position_ids = torch.ones_like(gen_ids).int().cuda() * seq_len
gen_last_token_ids = torch.zeros_like(
gen_context_lengths).int().cuda()
nonlocal hf_outputs
with torch.no_grad():
hf_outputs = hf_gpt.forward(
gen_ids,
past_key_values=hf_outputs.past_key_values,
use_cache=True)
torch.cuda.synchronize()
ref = hf_outputs.logits[:, -1, :]
if run_ref_only:
return ref
if enable_remove_input_padding:
gen_ids = gen_ids.view([1, batch_size])
gen_position_ids = gen_position_ids.view([1, batch_size])
gen_last_token_ids = torch.ones_like(
gen_context_lengths).int().cuda()
gen_last_token_ids = torch.cumsum(gen_last_token_ids,
dim=0).int()
host_past_key_value_lengths = torch.tensor([seq_len + step - 1] *
batch_size,
dtype=torch.int32)
host_context_lengths = gen_context_lengths.cpu(
) if enable_remove_input_padding else None
host_request_types = torch.tensor([1 for i in range(batch_size)],
dtype=torch.int32).cpu()
# For step 1, the sequence_lengths = context_lengths + 1.
sequence_length = torch.add(gen_context_lengths.detach().clone(), 1)
res = run_engine(
context=runtime.context_1,
input_ids=gen_ids,
context_lengths=gen_context_lengths,
position_ids=gen_position_ids,
last_token_ids=gen_last_token_ids,
cache_indirection=cache_indirections[1],
host_past_key_value_lengths=host_past_key_value_lengths,
sequence_length=sequence_length,
host_context_lengths=host_context_lengths,
host_request_types=host_request_types)
np.testing.assert_allclose(ref.cpu().numpy().flatten(),
res.cpu().numpy().flatten(),
atol=1e-1)
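# Mixes phases in one batch for in-flight batching: the first
# num_context_input requests run their context phase while the rest run a
# generation step, and the concatenated logits are checked against the
# per-phase references.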
def compare_mixing_context_and_generation_phases():
num_context_input = 2
assert batch_size >= num_context_input
num_generation_input = batch_size - num_context_input
# retrieve the reference output
ref_ctx_out = compare_context(True)[:num_context_input, :]
ref_gen_out = compare_generation(True)[num_context_input:, :]
ref_out = torch.cat([ref_ctx_out, ref_gen_out], dim=0)
ref_ctx_out = None
ref_gen_out = None
# compare_context()
# prepare the inputs for plugin-based gpt
assert step0_ids is not None and step1_ids is not None
input_ids = torch.cat([
step0_ids[:num_context_input, :].view(
(-1, )), step1_ids[num_context_input:].view((-1, ))
],
dim=0)
input_ids = input_ids.view((1, -1))
ctx_position_ids = torch.tensor(
range(seq_len), dtype=torch.int32).reshape(
(1, seq_len)).expand([num_generation_input,
seq_len]).cuda()
gen_position_ids = torch.ones_like(
step1_ids[num_context_input:].view(
(-1, ))).int().cuda() * seq_len
position_ids = torch.cat(
[ctx_position_ids.view((-1, )), gen_position_ids], dim=0).view(
(1, -1))
input_lengths = torch.tensor([seq_len] * num_context_input +
[1] * num_generation_input,
dtype=torch.int32).cuda()
gen_last_token_ids = torch.cumsum(input_lengths, dim=0).int().cuda()
# past KV lengths per request for the in-flight batching case: 0 for context inputs, seq_len for generation inputs
host_past_key_value_lengths = torch.tensor(
[0] * num_context_input + [seq_len] * num_generation_input,
dtype=torch.int32)
context_lengths = torch.tensor([seq_len] * batch_size,
dtype=torch.int32).cuda()
if enable_remove_input_padding:
host_context_lengths = context_lengths.cpu()
host_request_types = torch.tensor([0] * num_context_input +
[1] * num_generation_input,
dtype=torch.int32).cpu()
# The sequence_lengths = context_lengths + step for generation stage.
sequence_length = torch.tensor([seq_len] * num_context_input +
[seq_len + 1] * num_generation_input,
dtype=torch.int32).cuda()
res = run_engine(
context=runtime.context_1,
input_ids=input_ids,
context_lengths=context_lengths,
position_ids=position_ids,
last_token_ids=gen_last_token_ids,
cache_indirection=cache_indirections[0],
host_past_key_value_lengths=host_past_key_value_lengths,
sequence_length=sequence_length,
host_context_lengths=host_context_lengths,
host_request_types=host_request_types,
)
np.testing.assert_allclose(ref_out.cpu().numpy(),
res.cpu().numpy(),
atol=1e-1)
# Main test logic
compare_context()
compare_generation()
# Only in-flight batching mode can accept a mixture of requests from the context and generation phases
if use_in_flight_batching:
compare_mixing_context_and_generation_phases()
@parameterized.expand([(False, False), (False, True)])
def test_greedy_search_float32(self, use_refit, streaming):
model = 'gpt'
log_level = 'error'
dtype = 'float32'
world_size = 1
rank = 0
hidden_act = 'gelu'
n_layer = 2
max_new_tokens = 1
batch_size = 4
seq_len = 128
use_plugin = False
do_sample = False
early_stopping = False
num_beams = 1
num_beam_groups = 1
temperature = 1
top_k = 0
top_p = 0.0
length_penalty = 1
repetition_penalty = 1
gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer,
max_new_tokens, dtype)
runtime, engine_buffer = self._gen_tensorrt_llm_runtime(
log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
use_plugin, batch_size, seq_len, max_new_tokens, use_refit)
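# Exercise the high-level runtime: wrap the engine in a GenerationSession and
# compare decode() output ids with HuggingFace generate().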
model_config = ModelConfig(vocab_size=gpt_config.vocab_size,
num_layers=gpt_config.n_layer,
num_heads=gpt_config.n_head,
num_kv_heads=gpt_config.n_head,
hidden_size=gpt_config.n_embd,
gpt_attention_plugin=False,
dtype=dtype)
mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
decoder = tensorrt_llm.runtime.GenerationSession(
model_config, engine_buffer, mapping)
pad_token_id = 50256
eos_token_id = 50257
sampling_config = SamplingConfig(end_id=eos_token_id,
pad_id=pad_token_id,
num_beams=num_beams,
temperature=temperature,
top_k=top_k,
top_p=top_p,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty)
input_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
input_ids[0][-1] = pad_token_id
input_ids[1][-3:] = pad_token_id
input_ids[2][-5:] = pad_token_id
input_lengths = torch.ones(
(batch_size)).type(torch.int32).cuda() * seq_len
decoder.setup(batch_size,
max_context_length=seq_len,
max_new_tokens=max_new_tokens,
beam_width=num_beams)
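# With streaming=True, decode() returns a generator of intermediate output
# ids; draining it leaves the final result in output_ids.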
if streaming:
output_ids_gen = decoder.decode(input_ids,
input_lengths,
sampling_config,
streaming=True)
for output_ids in output_ids_gen:
pass
else:
output_ids = decoder.decode(input_ids, input_lengths,
sampling_config)
#TODO: change to actual ragged tensor after GPT plugin supports it
output_ids_x = decoder.decode(input_ids, input_lengths, sampling_config)
# works because all requests in the batch have the same input length
# TODO: enable this when GPT Plugin attention works
# output_ids_y = decoder.decode_batch([t[:input_lengths[i]] for i, t in enumerate(torch.split(input_ids, 1, dim=0))], sampling_config)
torch.cuda.synchronize()
torch.testing.assert_close(output_ids, output_ids_x)
res = output_ids.squeeze()
res = res[:, -max_new_tokens:]
ref_output_ids = hf_gpt.generate(input_ids,
do_sample=do_sample,
early_stopping=early_stopping,
num_beams=num_beams,
temperature=temperature,
top_k=top_k,
top_p=top_p,
num_beam_groups=num_beam_groups,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id)
torch.cuda.synchronize()
ref = ref_output_ids[:, -max_new_tokens:]
np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
if __name__ == '__main__':
unittest.main()