TensorRT-LLMs/tests/model/test_gpt.py
2024-12-24 15:58:43 +08:00

1100 lines
48 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
import sys
import tempfile
import unittest
from argparse import Namespace
from itertools import product
from unittest.mock import patch
import numpy as np
import pytest
# isort: off
import torch
# isort: on
from parameterized import parameterized
from transformers import GPT2Config, GPT2LMHeadModel
import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.functional import RotaryScalingType
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.models.gpt.convert import load_weights_from_hf_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.generation import _prepare_attention_mask
from tensorrt_llm.runtime.kv_cache_manager import GenerationSequence
from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \
PoolsKVCacheManager
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.gpt.convert_checkpoint import convert_and_save_hf
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
from utils.util import skip_fp32_accum_pre_ampere, unittest_name_func
from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \
MemoryPoolsAllocator
class TestGPT(unittest.TestCase):
def _gen_hf_gpt(self, hidden_act, n_layer, max_length, dtype):
    """Create the Hugging Face GPT-2 model used as the reference.

    The model is constructed from a ``GPT2Config`` only; no pretrained
    checkpoint is loaded by this method.

    Args:
        hidden_act: Forwarded to ``GPT2Config`` as ``activation_function``.
        n_layer: Forwarded to ``GPT2Config`` as ``n_layer``.
        max_length: Forwarded to ``GPT2Config`` as ``max_length``.
        dtype: Forwarded to ``GPT2Config`` as ``torch_dtype``.

    Returns:
        Tuple ``(gpt_config, hf_gpt)`` where ``hf_gpt`` has been moved to
        CUDA and switched to eval mode.
    """
    hf_config = GPT2Config(activation_function=hidden_act,
                           n_layer=n_layer,
                           max_length=max_length,
                           torch_dtype=dtype)
    # Expose a KV-head count on the config, equal to the attention-head count.
    hf_config.n_kv_head = hf_config.n_head

    reference_model = GPT2LMHeadModel(hf_config)
    reference_model = reference_model.cuda().eval()
    return hf_config, reference_model
def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config,
                              batch_size, input_len, output_len, dtype,
                              gpt_attention_plugin, tensor_parallel,
                              apply_query_key_layer_scaling,
                              gather_context_logits,
                              gather_generation_logits):
    """Trace a TensorRT-LLM GPT model, loaded from ``hf_gpt``, into ``network``.

    A ``GPTConfig`` is derived from the Hugging Face ``gpt_config``, the
    Hugging Face weights are converted and loaded, and the model is then
    invoked once under ``net_guard(network)`` so that its graph is
    recorded in ``network``.

    Args:
        network: TensorRT-LLM network that receives the graph.
        builder: Accepted for call-site symmetry; not used in this method.
        hf_gpt: Hugging Face model supplying the weights.
        gpt_config: Hugging Face config supplying the model dimensions.
        batch_size: Used as ``max_batch_size``.
        input_len: Used as ``max_input_len``.
        output_len: Added to ``input_len`` to obtain ``max_seq_len``.
        dtype: Data type string stored in the TensorRT-LLM config.
        gpt_attention_plugin: Not used in this method.
        tensor_parallel: Used for both ``world_size`` and ``tp_size``.
        apply_query_key_layer_scaling: Stored in the TensorRT-LLM config.
        gather_context_logits: Forwarded to ``prepare_inputs``.
        gather_generation_logits: Forwarded to ``prepare_inputs``.

    Returns:
        The same ``network`` object that was passed in.
    """
    trtllm_config = tensorrt_llm.models.GPTConfig.from_dict({
        'architecture': 'GPTForCausalLM',
        'dtype': dtype,
        'num_hidden_layers': gpt_config.n_layer,
        'num_attention_heads': gpt_config.n_head,
        'num_key_value_heads': gpt_config.n_head,
        'hidden_size': gpt_config.n_embd,
        'intermediate_size': gpt_config.n_embd * 4,
        'norm_epsilon': 1e-5,
        'vocab_size': gpt_config.vocab_size,
        'position_embedding_type': 'learned_absolute',
        'max_position_embeddings': gpt_config.n_positions,
        'hidden_act': gpt_config.activation_function,
        'mapping': {
            'world_size': tensor_parallel,
            'tp_size': tensor_parallel,
        },
        'bias': getattr(gpt_config, 'bias', True),
        'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
    })

    converted_weights = load_weights_from_hf_model(hf_gpt, trtllm_config)
    trtllm_model = tensorrt_llm.models.GPTForCausalLM(trtllm_config)
    trtllm_model.load(converted_weights)

    with net_guard(network):
        # Register the model parameters with the network before tracing.
        network.set_named_parameters(trtllm_model.named_parameters())
        model_inputs = trtllm_model.prepare_inputs(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_seq_len=input_len + output_len,
            max_num_tokens=batch_size * input_len,
            use_cache=True,
            max_beam_width=1,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits)
        # Calling the model records its graph into the guarded network.
        trtllm_model(**model_inputs)
    return network
def _gen_tensorrt_llm_runtime(self,
                              log_level,
                              dtype,
                              world_size,
                              rank,
                              gpt_config,
                              hf_gpt,
                              model,
                              use_plugin,
                              batch_size,
                              input_len,
                              output_len,
                              use_refit,
                              fast_building=False,
                              apply_query_key_layer_scaling=False,
                              context_fmha_type=ContextFMHAType.disabled,
                              enable_remove_input_padding=False,
                              enable_paged_kv_cache=False,
                              tokens_per_block=128,
                              gather_context_logits=False,
                              gather_generation_logits=False):
    """Build a GPT engine and wrap it in a low-level ``_Runtime``.

    Args:
        log_level: Not used in this method.
        dtype: Builder precision; also the plugin dtype when plugins are on.
        world_size: Tensor-parallel size (this helper is TP only).
        rank: Rank passed to ``tensorrt_llm.Mapping``.
        gpt_config: Hugging Face config of the reference model.
        hf_gpt: Hugging Face reference model supplying the weights.
        model: Not used in this method.
        use_plugin: Enable the GPT attention plugin with ``dtype``.
        batch_size: Maximum batch size of the engine.
        input_len: Maximum input length of the engine.
        output_len: Maximum number of generated tokens.
        use_refit: Forwarded to the builder config.
        fast_building: Enable the GEMM plugin with ``dtype``.
        apply_query_key_layer_scaling: Forwarded to the network generator.
        context_fmha_type: Passed to ``plugin_config.set_context_fmha``.
        enable_remove_input_padding: Turn on ``remove_input_padding``.
        enable_paged_kv_cache: Turn on the paged KV cache.
        tokens_per_block: Block size used when the paged KV cache is on.
        gather_context_logits: Forwarded to builder config and network.
        gather_generation_logits: Forwarded to builder config and network.

    Returns:
        Tuple ``(runtime, engine_buffer)``.
    """
    mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
    runtime = None
    builder = Builder()

    with tempfile.TemporaryDirectory():
        build_cfg = builder.create_builder_config(
            name='gpt',
            precision=dtype,
            timing_cache='model.cache',
            tensor_parallel=world_size,  # TP only
            use_refit=use_refit,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            strongly_typed=True,
        )

        net = builder.create_network()
        plugin_cfg = net.plugin_config
        plugin_cfg.to_legacy_setting()
        if use_plugin:
            plugin_cfg.gpt_attention_plugin = dtype
        if fast_building:
            plugin_cfg.gemm_plugin = dtype
        plugin_cfg.set_context_fmha(context_fmha_type)
        if enable_remove_input_padding:
            plugin_cfg.remove_input_padding = True
        if enable_paged_kv_cache:
            plugin_cfg.enable_paged_kv_cache(tokens_per_block)

        self._gen_tensorrt_llm_network(net, builder, hf_gpt, gpt_config,
                                       batch_size, input_len, output_len,
                                       dtype, use_plugin, world_size,
                                       apply_query_key_layer_scaling,
                                       gather_context_logits,
                                       gather_generation_logits)

        engine_buffer = builder.build_engine(net, build_cfg)
        runtime = tensorrt_llm.runtime.generation._Runtime(
            engine_buffer, mapping)
    return runtime, engine_buffer
@parameterized.expand([("other", False)], name_func=unittest_name_func)
def test_gpt_float32(self, test_partition, use_refit):
    """Compare a float32, plugin-free engine with Hugging Face GPT-2.

    One context step and one generation step are executed on the
    low-level runtime. After each step the last-token logits and the
    per-layer present key/value tensors are compared with the Hugging
    Face outputs using ``atol=1e-2``.
    """
    torch.manual_seed(42)
    model = 'gpt'
    log_level = 'error'
    dtype = 'float32'
    world_size = 1
    rank = 0
    hidden_act = 'gelu'
    n_layer = 2
    max_length = 2
    batch_size = 4
    beam_width = 1
    seq_len = 128
    total_length = seq_len + max_length
    use_plugin = False

    gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer, max_length,
                                          dtype)
    runtime, _ = self._gen_tensorrt_llm_runtime(
        log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
        use_plugin, batch_size, seq_len, max_length, use_refit)

    torch_dtype = str_dtype_to_torch(dtype)
    num_heads = gpt_config.n_head
    head_size = gpt_config.n_embd // num_heads

    def assert_kv_cache_matches(engine_buffers, hf_past_key_values,
                                cached_len):
        # Engine tensors are laid out as (batch, 2, heads, len, head_size);
        # move the key/value axis to the front and split it.
        kv_shape = (batch_size, num_heads, cached_len, head_size)
        for layer_idx in range(gpt_config.n_layer):
            present = engine_buffers[f'present_key_value_{layer_idx}']
            ref_key, ref_value = hf_past_key_values[layer_idx]
            key, value = present.permute(1, 0, 2, 3, 4).chunk(2)
            # The engine runs in float32 here, so the cast leaves values as-is.
            key = key.to(torch.float32).reshape(kv_shape)
            value = value.reshape(kv_shape)
            np.testing.assert_allclose(ref_key.cpu().numpy(),
                                       key.cpu().numpy(),
                                       atol=1e-2)
            np.testing.assert_allclose(ref_value.cpu().numpy(),
                                       value.cpu().numpy(),
                                       atol=1e-2)

    # ---- compare context ----
    pad_token_id = 50256
    ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
    # Pad the tails of the first three sequences by 1, 3 and 5 tokens.
    ctx_ids[0][-1] = pad_token_id
    ctx_ids[1][-3:] = pad_token_id
    ctx_ids[2][-5:] = pad_token_id

    ctx_context_lengths = torch.full((batch_size, ),
                                     seq_len,
                                     dtype=torch.int32,
                                     device='cuda')
    ctx_host_context_lengths = ctx_context_lengths.cpu()
    ctx_host_request_types = torch.tensor([0] * batch_size,
                                          dtype=torch.int32,
                                          device='cpu')
    ctx_position_ids = torch.tensor(
        range(seq_len), dtype=torch.int32).reshape([1, seq_len]).expand(
            [batch_size, seq_len]).cuda()
    ctx_last_token_ids = ctx_context_lengths.clone()
    ctx_attention_mask = _prepare_attention_mask(ctx_ids)

    # ping-pong buffers
    cache_indirections = [
        torch.full((batch_size, beam_width, total_length),
                   0,
                   dtype=torch.int32,
                   device='cuda') for _ in range(2)
    ]

    perf_knob_tensor_size = 16
    context_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size,
                                              dtype=torch.int64)
    host_context_progress = torch.tensor([0], dtype=torch.int64)

    ctx_buffer = {
        'input_ids': ctx_ids,
        'position_ids': ctx_position_ids,
        'context_lengths': ctx_context_lengths,
        'host_context_lengths': ctx_host_context_lengths,
        'last_token_ids': ctx_last_token_ids,
        'attention_mask': ctx_attention_mask,
        'host_request_types': ctx_host_request_types,
        'cache_indirection': cache_indirections[0],
        'host_runtime_perf_knobs': context_runtime_perf_knobs,
        'host_context_progress': host_context_progress,
    }
    ctx_shape = {name: tensor.shape for name, tensor in ctx_buffer.items()}

    # In the context phase the past KV has zero length; a one-element
    # placeholder tensor is bound while the declared shape has length 0.
    empty_past_shape = (batch_size, 2, num_heads, 0, head_size)
    ctx_present_shape = (batch_size, 2, num_heads, seq_len, head_size)
    for layer_idx in range(gpt_config.n_layer):
        past_name = f'past_key_value_{layer_idx}'
        ctx_shape[past_name] = empty_past_shape
        ctx_buffer[past_name] = torch.zeros((1, ),
                                            dtype=torch_dtype,
                                            device='cuda')
        ctx_buffer[f'present_key_value_{layer_idx}'] = torch.zeros(
            ctx_present_shape, dtype=torch_dtype, device='cuda')

    context = runtime.ctx_context
    runtime._set_shape(context, ctx_shape)
    runtime._set_buffer(context, ctx_buffer)
    runtime._run(context)
    torch.cuda.synchronize()
    res = ctx_buffer['logits']

    with torch.no_grad():
        hf_outputs = hf_gpt.forward(ctx_ids,
                                    attention_mask=ctx_attention_mask)
    torch.cuda.synchronize()
    ref = hf_outputs.logits[:, -1, :]
    np.testing.assert_allclose(ref.cpu().numpy(),
                               res.cpu().numpy(),
                               atol=1e-2)
    assert_kv_cache_matches(ctx_buffer, hf_outputs.past_key_values, seq_len)

    # ---- compare generation ----
    gen_id = torch.randint(100, (batch_size, 1)).int().cuda()
    gen_context_lengths = ctx_context_lengths.clone()
    gen_host_context_lengths = ctx_host_context_lengths.clone()
    gen_host_request_types = torch.tensor([1] * batch_size,
                                          dtype=torch.int32,
                                          device='cpu')
    gen_position_ids = torch.ones_like(gen_id).cuda() * seq_len
    gen_last_token_ids = torch.zeros_like(gen_context_lengths).cuda()
    # Extend the context mask by one attended position for the new token.
    new_token_mask = ctx_attention_mask.new_ones(
        (ctx_attention_mask.shape[0], 1))
    gen_attention_mask = torch.cat([ctx_attention_mask, new_token_mask],
                                   dim=-1)
    gen_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size,
                                          dtype=torch.int64)

    step1_buffer = {
        'input_ids': gen_id,
        'context_lengths': gen_context_lengths.contiguous(),
        'host_context_lengths': gen_host_context_lengths.contiguous(),
        'host_request_types': gen_host_request_types.contiguous(),
        'position_ids': gen_position_ids.contiguous(),
        'last_token_ids': gen_last_token_ids.contiguous(),
        'attention_mask': gen_attention_mask.contiguous(),
        'cache_indirection': cache_indirections[1].contiguous(),
        'host_runtime_perf_knobs': gen_runtime_perf_knobs,
        'host_context_progress': host_context_progress,
    }
    step1_shape = {
        name: tensor.shape
        for name, tensor in step1_buffer.items()
    }
    # The present KV produced by the context step is the past KV here.
    for layer_idx in range(gpt_config.n_layer):
        past_name = f'past_key_value_{layer_idx}'
        step1_shape[past_name] = ctx_present_shape
        step1_buffer[past_name] = ctx_buffer[
            f'present_key_value_{layer_idx}']

    context = runtime.context_1
    runtime._set_shape(context, step1_shape)
    runtime._set_buffer(context, step1_buffer)
    runtime._run(context)
    torch.cuda.synchronize()
    res = step1_buffer['logits']

    with torch.no_grad():
        hf_outputs = hf_gpt.forward(
            gen_id,
            attention_mask=gen_attention_mask,
            past_key_values=hf_outputs.past_key_values,
            use_cache=True)
    torch.cuda.synchronize()
    ref = hf_outputs.logits[:, -1, :]
    np.testing.assert_allclose(ref.cpu().numpy(),
                               res.cpu().numpy(),
                               atol=1e-2)
    assert_kv_cache_matches(step1_buffer, hf_outputs.past_key_values,
                            seq_len + 1)
def load_test_cases():
    """Return the parameter tuples consumed by ``test_gpt_plugin``.

    Every combination of seven boolean switches and the three
    ``ContextFMHAType`` values is generated. Each tuple is prefixed with a
    ``"partition<k>"`` label, assigned round-robin over four partitions.
    """
    on_off = [False, True]
    fmha_modes = [
        ContextFMHAType.disabled, ContextFMHAType.enabled,
        ContextFMHAType.enabled_with_fp32_acc
    ]
    combos = product(on_off, on_off, on_off, fmha_modes, on_off, on_off,
                     on_off, on_off)
    # split test cases into 4 partitions
    return [(f"partition{idx % 4}", *combo)
            for idx, combo in enumerate(combos)]
@parameterized.expand(load_test_cases, name_func=unittest_name_func)
def test_gpt_plugin(self, test_partition, use_refit, fast_building,
                    apply_query_key_layer_scaling, context_fmha_type,
                    enable_remove_input_padding, enable_paged_kv_cache,
                    gather_context_logits, gather_generation_logits):
    """Compare a float16 engine using the GPT attention plugin with HF GPT-2.

    The test runs a context step and a generation step on the low-level
    runtime and checks the logits against Hugging Face with ``atol=1e-1``.
    When in-flight batching is applicable (ragged input, paged KV cache and
    no logits gathering), a third run feeds a batch that mixes context and
    generation requests.

    Args:
        test_partition: Partition label produced by ``load_test_cases``.
        use_refit: Forwarded to the engine builder.
        fast_building: Enables the GEMM plugin in the engine builder.
        apply_query_key_layer_scaling: Forwarded to the model config.
        context_fmha_type: Context FMHA mode; ``enabled_with_fp32_acc`` also
            sets entry 1 of the runtime perf-knob tensor.
        enable_remove_input_padding: Use packed (ragged) inputs.
        enable_paged_kv_cache: Use the paged KV cache pools.
        gather_context_logits: Compare logits of every context position.
        gather_generation_logits: Forwarded to the engine builder.
    """
    # inflight batching mode only works with remove_input_padding and paged_kv_cache
    use_in_flight_batching = (enable_remove_input_padding
                              and enable_paged_kv_cache
                              and not (gather_context_logits
                                       or gather_generation_logits))
    # Skip tests that are not supported in pre-ampere architecture
    skip_fp32_accum_pre_ampere(context_fmha_type)

    torch.manual_seed(0)
    random.seed(0)
    model = 'gpt'
    log_level = 'error'
    dtype = 'float16'
    world_size = 1
    rank = 0
    hidden_act = 'gelu'
    n_layer = 1
    max_length = 2
    batch_size = 4
    beam_width = 1
    seq_len = 128
    total_length = seq_len + max_length
    use_plugin = True
    tokens_per_block = 128
    perf_knob_tensor_size = 16

    gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer,
                                          seq_len + max_length, dtype)
    runtime, _ = self._gen_tensorrt_llm_runtime(
        log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
        use_plugin, batch_size, seq_len, max_length, use_refit,
        fast_building, apply_query_key_layer_scaling, context_fmha_type,
        enable_remove_input_padding, enable_paged_kv_cache,
        tokens_per_block, gather_context_logits, gather_generation_logits)

    head_size = gpt_config.n_embd // gpt_config.n_head

    # ping-pong buffers
    cache_indirections = [
        torch.full((batch_size, beam_width, total_length),
                   0,
                   dtype=torch.int32,
                   device='cuda') for _ in range(2)
    ]

    # Contiguous per-layer KV buffers; only populated and read when the
    # paged KV cache is disabled.
    key_value_cache_buffers = []
    if enable_paged_kv_cache:
        max_blocks_per_seq = math.ceil(total_length / tokens_per_block)
        num_blocks = batch_size * beam_width * max_blocks_per_seq

        memory_pools_allocator = MemoryPoolsAllocator(
            num_blocks=num_blocks,
            tokens_per_block=tokens_per_block,
            head_size=head_size)
        num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer(
            gpt_config.n_head, gpt_config.n_layer)
        memory_pools_allocator.allocate(dtype, num_kv_heads_per_layer)

        pools_kv_cache_manager = PoolsKVCacheManager(
            memory_pools_allocator.pools_metadata,
            max_blocks_per_seq,
            num_blocks,
            tokens_per_block,
            head_size,
            max_attention_window_size=total_length,
            beam_width=beam_width,
            sink_token_len=0)
        host_kv_cache_pool_pointers = memory_pools_allocator.get_kv_cache_pool_pointers(
        )
        host_kv_cache_pool_mapping = memory_pools_allocator.pool_mapping

        # Add sequences to the manager
        for bi in range(batch_size):
            generation_sequence = GenerationSequence(seq_idx=bi,
                                                     batch_idx=bi)
            pools_kv_cache_manager.add_sequence(generation_sequence,
                                                seq_len)

        # Pre allocate the kv cache for the generated tokens.
        pools_kv_cache_manager.step([False] * batch_size)
    else:
        cache_shape = (
            batch_size,
            2,
            gpt_config.n_head,
            total_length,
            head_size,
        )
        key_value_cache_buffers = [
            torch.zeros(cache_shape,
                        dtype=str_dtype_to_torch(dtype),
                        device='cuda') for _ in range(gpt_config.n_layer)
        ]

    def make_runtime_perf_knobs():
        # Build the host perf-knob tensor shared by all engine invocations.
        knobs = torch.tensor([-1] * perf_knob_tensor_size,
                             dtype=torch.int64)
        if context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
            knobs[1] = 1  # enable_context_fmha_fp32_acc
        return knobs

    def run_engine(context,
                   input_ids,
                   context_lengths,
                   host_request_types,
                   position_ids,
                   last_token_ids,
                   cache_indirection,
                   host_past_key_value_lengths,
                   host_max_attention_window_sizes,
                   host_sink_token_length,
                   host_runtime_perf_knobs,
                   host_context_progress,
                   sequence_length=None,
                   host_context_lengths=None):
        # Bind all inputs to the given execution context, run it, and
        # return the 'logits' tensor found in the buffer dict afterwards.
        ctx_buffer = {
            'input_ids': input_ids,
            'context_lengths': context_lengths,
            'host_request_types': host_request_types,
            'position_ids': position_ids,
            'last_token_ids': last_token_ids,
            'cache_indirection': cache_indirection,
            'host_past_key_value_lengths': host_past_key_value_lengths,
            'sequence_length': sequence_length,
            'host_sink_token_length': host_sink_token_length,
            'host_runtime_perf_knobs': host_runtime_perf_knobs,
            'host_context_progress': host_context_progress,
        }

        assert host_request_types is not None
        if enable_remove_input_padding:
            assert host_context_lengths is not None, "host_context_lengths is required for ragged input"
            ctx_buffer['host_context_lengths'] = host_context_lengths

        if enable_paged_kv_cache:
            assert beam_width == 1
            # for beam_width > 1 the argument must be '1' in ctx phase and 'beam_width' in gen phase
            host_kv_cache_block_offsets = pools_kv_cache_manager.get_block_offsets(
                beam_width=1)
            kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda')

            # Merge dimensions 1 and 2 of the block-offset tensor.
            shape = kv_cache_block_offsets.shape
            shape = [shape[0], shape[1] * shape[2], *shape[3:]]
            ctx_buffer[
                'kv_cache_block_offsets'] = kv_cache_block_offsets.reshape(
                    shape).contiguous()
            ctx_buffer[
                'host_kv_cache_block_offsets'] = host_kv_cache_block_offsets.reshape(
                    shape).contiguous()
            ctx_buffer[
                'host_kv_cache_pool_pointers'] = host_kv_cache_pool_pointers.contiguous(
                )
            ctx_buffer[
                'host_kv_cache_pool_mapping'] = host_kv_cache_pool_mapping.contiguous(
                )
        else:
            # The same tensor is bound as past and present for each layer.
            for i in range(gpt_config.n_layer):
                ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[
                    i]
                ctx_buffer[
                    f'present_key_value_{i}'] = key_value_cache_buffers[i]
        ctx_buffer[
            'host_max_attention_window_sizes'] = host_max_attention_window_sizes

        ctx_shape = {
            key: buffer.shape
            for key, buffer in ctx_buffer.items()
        }

        runtime._set_shape(context, ctx_shape)
        runtime._set_buffer(context, ctx_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = ctx_buffer['logits']
        return res

    # State shared between the comparison helpers below.
    hf_outputs = None
    step0_ids = None
    step1_ids = None

    def compare_context(run_ref_only=False):
        # Context phase. With run_ref_only=True only the Hugging Face
        # last-token logits are computed and returned.
        nonlocal step0_ids, hf_outputs
        if step0_ids is None:
            step0_ids = torch.randint(100,
                                      (batch_size, seq_len)).int().cuda()
        ctx_ids = step0_ids.clone()

        ctx_context_lengths = seq_len * torch.ones(
            (batch_size), dtype=torch.int32, device='cuda')
        ctx_position_ids = torch.tensor(
            range(seq_len), dtype=torch.int32).reshape([1, seq_len]).expand(
                [batch_size, seq_len]).cuda()
        ctx_last_token_ids = ctx_context_lengths.clone()

        with torch.no_grad():
            hf_outputs = hf_gpt.forward(ctx_ids)
        torch.cuda.synchronize()
        ref = hf_outputs.logits
        if run_ref_only:
            return ref[:, -1, :]

        if enable_remove_input_padding:
            # Packed layout: flatten tokens and use cumulative end offsets.
            ctx_ids = ctx_ids.view([batch_size * seq_len])
            ctx_position_ids = ctx_position_ids.view([batch_size * seq_len])
            ctx_last_token_ids = torch.cumsum(ctx_last_token_ids,
                                              dim=0).int()

        host_max_attention_window_sizes = torch.tensor([total_length] *
                                                       gpt_config.n_layer,
                                                       dtype=torch.int32)
        host_sink_token_length = torch.tensor([0], dtype=torch.int32)
        host_context_lengths = ctx_context_lengths.cpu(
        ) if enable_remove_input_padding else None
        host_request_types = torch.tensor([0] * batch_size,
                                          dtype=torch.int32).cpu()
        host_past_key_value_lengths = ctx_context_lengths.detach().clone(
        ).cpu()
        # We need sequence_lengths start as context_lengths for step 0 (context),
        # and it will be added one after each step.
        sequence_length = ctx_context_lengths.detach().clone()
        host_context_progress = torch.tensor([0], dtype=torch.int64)

        res = run_engine(
            context=runtime.ctx_context,
            input_ids=ctx_ids,
            context_lengths=ctx_context_lengths,
            position_ids=ctx_position_ids,
            last_token_ids=ctx_last_token_ids,
            cache_indirection=cache_indirections[0],
            host_past_key_value_lengths=host_past_key_value_lengths,
            host_max_attention_window_sizes=host_max_attention_window_sizes,
            host_sink_token_length=host_sink_token_length,
            sequence_length=sequence_length,
            host_context_lengths=host_context_lengths,
            host_request_types=host_request_types,
            host_runtime_perf_knobs=make_runtime_perf_knobs(),
            host_context_progress=host_context_progress)

        if gather_context_logits:
            np.testing.assert_allclose(ref.cpu().numpy().flatten(),
                                       res.cpu().numpy().flatten(),
                                       atol=1e-1)
        else:
            np.testing.assert_allclose(ref[:, -1, :].cpu().numpy(),
                                       res.cpu().numpy(),
                                       atol=1e-1)

    def compare_generation(run_ref_only=False):
        # Generation phase (step 1). Relies on hf_outputs left behind by
        # compare_context(); with run_ref_only=True only the Hugging Face
        # logits are computed and returned.
        step = 1
        nonlocal step1_ids, hf_outputs
        if step1_ids is None:
            step1_ids = torch.randint(100, (batch_size, 1)).int().cuda()
        gen_ids = step1_ids.clone()

        gen_context_lengths = seq_len * torch.ones(
            (batch_size), dtype=torch.int32, device='cuda')
        gen_position_ids = torch.ones_like(gen_ids).int().cuda() * seq_len
        gen_last_token_ids = torch.zeros_like(
            gen_context_lengths).int().cuda()

        with torch.no_grad():
            hf_outputs = hf_gpt.forward(
                gen_ids,
                past_key_values=hf_outputs.past_key_values,
                use_cache=True)
        torch.cuda.synchronize()
        ref = hf_outputs.logits[:, -1, :]
        if run_ref_only:
            return ref

        if enable_remove_input_padding:
            # Packed layout: one token per sequence, cumulative end offsets.
            gen_ids = gen_ids.view([batch_size])
            gen_position_ids = gen_position_ids.view([batch_size])
            gen_last_token_ids = torch.ones_like(
                gen_context_lengths).int().cuda()
            gen_last_token_ids = torch.cumsum(gen_last_token_ids,
                                              dim=0).int()

        host_past_key_value_lengths = torch.tensor([seq_len + step - 1] *
                                                   batch_size,
                                                   dtype=torch.int32)
        host_max_attention_window_sizes = torch.tensor([seq_len + step] *
                                                       gpt_config.n_layer,
                                                       dtype=torch.int32)
        host_sink_token_length = torch.tensor([0], dtype=torch.int32)
        host_context_lengths = gen_context_lengths.cpu(
        ) if enable_remove_input_padding else None
        host_request_types = torch.tensor([1] * batch_size,
                                          dtype=torch.int32).cpu()
        # For step 1, the sequence_lengths = context_lengths + 1.
        sequence_length = torch.add(gen_context_lengths.detach().clone(), 1)
        host_context_progress = torch.tensor([0], dtype=torch.int64)

        res = run_engine(
            context=runtime.context_1,
            input_ids=gen_ids,
            context_lengths=gen_context_lengths,
            position_ids=gen_position_ids,
            last_token_ids=gen_last_token_ids,
            cache_indirection=cache_indirections[1],
            host_past_key_value_lengths=host_past_key_value_lengths,
            host_max_attention_window_sizes=host_max_attention_window_sizes,
            host_sink_token_length=host_sink_token_length,
            sequence_length=sequence_length,
            host_context_lengths=host_context_lengths,
            host_request_types=host_request_types,
            host_runtime_perf_knobs=make_runtime_perf_knobs(),
            host_context_progress=host_context_progress)

        np.testing.assert_allclose(ref.cpu().numpy().flatten(),
                                   res.cpu().numpy().flatten(),
                                   atol=1e-1)

    def compare_mixing_context_and_generation_phases():
        # Run one batch whose first num_context_input requests are in the
        # context phase and whose remaining requests are in generation.
        num_context_input = 2
        assert batch_size >= num_context_input
        num_generation_input = batch_size - num_context_input

        # retrieve the reference output (context first: the generation
        # reference consumes the KV cache produced by the context reference)
        ref_out = torch.cat([
            compare_context(True)[:num_context_input, :],
            compare_generation(True)[num_context_input:, :]
        ],
                            dim=0)

        # prepare the inputs for plugin-based gpt
        assert step0_ids is not None and step1_ids is not None
        input_ids = torch.cat([
            step0_ids[:num_context_input, :].view(
                (-1, )), step1_ids[num_context_input:].view((-1, ))
        ],
                              dim=0)

        # One row of positions per *context* request, so that position_ids
        # lines up token-for-token with input_ids.
        ctx_position_ids = torch.tensor(
            range(seq_len), dtype=torch.int32).reshape(
                (1, seq_len)).expand([num_context_input, seq_len]).cuda()
        gen_position_ids = torch.ones_like(
            step1_ids[num_context_input:].view(
                (-1, ))).int().cuda() * seq_len
        position_ids = torch.cat(
            [ctx_position_ids.view((-1, )), gen_position_ids], dim=0)

        input_lengths = torch.tensor([seq_len] * num_context_input +
                                     [1] * num_generation_input,
                                     dtype=torch.int32).cuda()
        last_token_ids = torch.cumsum(input_lengths, dim=0).int().cuda()

        # Past KV length per request: 0 for context requests, seq_len for
        # generation requests.
        host_past_key_value_lengths = torch.tensor(
            [0] * num_context_input + [seq_len] * num_generation_input,
            dtype=torch.int32)
        host_max_attention_window_sizes = torch.tensor([total_length] *
                                                       gpt_config.n_layer,
                                                       dtype=torch.int32)
        host_sink_token_length = torch.tensor([0], dtype=torch.int32)

        context_lengths = torch.tensor([seq_len] * batch_size,
                                       dtype=torch.int32).cuda()
        # Always bind the name so the run_engine call below cannot hit an
        # unbound local when ragged input is disabled.
        host_context_lengths = context_lengths.cpu(
        ) if enable_remove_input_padding else None
        host_request_types = torch.tensor([0] * num_context_input +
                                          [1] * num_generation_input,
                                          dtype=torch.int32).cpu()

        # The sequence_lengths = context_lengths + step for generation stage.
        sequence_length = torch.tensor([seq_len] * num_context_input +
                                       [seq_len + 1] * num_generation_input,
                                       dtype=torch.int32).cuda()
        host_context_progress = torch.tensor([0], dtype=torch.int64)

        res = run_engine(
            context=runtime.context_1,
            input_ids=input_ids,
            context_lengths=context_lengths,
            position_ids=position_ids,
            last_token_ids=last_token_ids,
            cache_indirection=cache_indirections[0],
            host_past_key_value_lengths=host_past_key_value_lengths,
            host_max_attention_window_sizes=host_max_attention_window_sizes,
            host_sink_token_length=host_sink_token_length,
            sequence_length=sequence_length,
            host_context_lengths=host_context_lengths,
            host_request_types=host_request_types,
            host_runtime_perf_knobs=make_runtime_perf_knobs(),
            host_context_progress=host_context_progress)

        np.testing.assert_allclose(ref_out.cpu().numpy(),
                                   res.cpu().numpy(),
                                   atol=1e-1)

    # Main logics
    compare_context()
    compare_generation()

    # Only inflight batching mode could accept the mixture of requests from both context and generation phases
    if use_in_flight_batching:
        compare_mixing_context_and_generation_phases()
@parameterized.expand([("other", False, False), ("other", False, True)],
                      name_func=unittest_name_func)
def test_greedy_search_float32(self, test_partition, use_refit, streaming):
    """Check greedy decoding of a float32 engine against ``hf_gpt.generate``.

    The engine is decoded twice with ``GenerationSession``: once according
    to ``streaming`` and once in plain mode. Both results must agree, and
    the newly generated tokens must equal those produced by Hugging Face.
    """
    torch.manual_seed(42)
    model = 'gpt'
    log_level = 'error'
    dtype = 'float32'
    world_size = 1
    rank = 0
    hidden_act = 'gelu'
    n_layer = 2
    max_new_tokens = 1
    batch_size = 4
    seq_len = 128
    use_plugin = False

    # Sampling settings shared by the TensorRT-LLM and Hugging Face runs.
    do_sample = False
    early_stopping = False
    num_beams = 1
    num_beam_groups = 1
    temperature = 1
    top_k = 0
    top_p = 0.0
    length_penalty = 1
    repetition_penalty = 1

    gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer,
                                          max_new_tokens, dtype)
    runtime, engine_buffer = self._gen_tensorrt_llm_runtime(
        log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
        use_plugin, batch_size, seq_len, max_new_tokens, use_refit)

    model_config = ModelConfig(max_batch_size=batch_size,
                               max_beam_width=num_beams,
                               vocab_size=gpt_config.vocab_size,
                               num_layers=gpt_config.n_layer,
                               num_heads=gpt_config.n_head,
                               num_kv_heads=gpt_config.n_head,
                               hidden_size=gpt_config.n_embd,
                               gpt_attention_plugin=False,
                               dtype=dtype)
    mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
    decoder = tensorrt_llm.runtime.GenerationSession(
        model_config, engine_buffer, mapping)

    pad_token_id = 50256
    eos_token_id = 50257
    sampling_config = SamplingConfig(end_id=eos_token_id,
                                     pad_id=pad_token_id,
                                     num_beams=num_beams,
                                     temperature=temperature,
                                     top_k=top_k,
                                     top_p=top_p,
                                     length_penalty=length_penalty,
                                     repetition_penalty=repetition_penalty)

    input_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()
    # Pad the tails of the first three prompts by 1, 3 and 5 tokens.
    input_ids[0][-1] = pad_token_id
    input_ids[1][-3:] = pad_token_id
    input_ids[2][-5:] = pad_token_id

    unit_lengths = torch.ones((batch_size)).type(torch.int32).cuda()
    input_lengths = unit_lengths * seq_len

    decoder.setup(batch_size,
                  max_context_length=seq_len,
                  max_new_tokens=max_new_tokens,
                  beam_width=num_beams)

    if not streaming:
        output_ids = decoder.decode(input_ids, input_lengths,
                                    sampling_config)
    else:
        output_ids_gen = decoder.decode(input_ids,
                                        input_lengths,
                                        sampling_config,
                                        streaming=True)
        # Drain the stream; the last yielded value is the one compared.
        for output_ids in output_ids_gen:
            pass

    #TODO: change to actual ragged tensor after GPT plugin supports it
    # Decoding the padded batch again works because every request in the
    # batch is given the same input length (seq_len).
    output_ids_x = decoder.decode(input_ids, input_lengths, sampling_config)

    # TODO: enable this when GPT Plugin attention works
    # output_ids_y = decoder.decode_batch([t[:input_lengths[i]] for i, t in enumerate(torch.split(input_ids, 1, dim=0))], sampling_config)

    torch.cuda.synchronize()
    torch.testing.assert_close(output_ids, output_ids_x)

    res = output_ids.squeeze()[:, -max_new_tokens:]

    generate_kwargs = dict(do_sample=do_sample,
                           early_stopping=early_stopping,
                           num_beams=num_beams,
                           temperature=temperature,
                           top_k=top_k,
                           top_p=top_p,
                           num_beam_groups=num_beam_groups,
                           max_new_tokens=max_new_tokens,
                           length_penalty=length_penalty,
                           repetition_penalty=repetition_penalty,
                           pad_token_id=pad_token_id,
                           eos_token_id=eos_token_id)
    ref_output_ids = hf_gpt.generate(input_ids, **generate_kwargs)
    torch.cuda.synchronize()
    ref = ref_output_ids[:, -max_new_tokens:]

    np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
@parameterized.expand(["other"], name_func=unittest_name_func)
def test_rope_scaling_is_set_in_attention(self, test_partition):
    """Verify that rotary settings in the config reach every attention layer.

    A ``GPTForCausalLM`` is built from a config with ``rope_gpt_neox``
    position embeddings and linear rotary scaling. Each layer's attention
    module must then report the configured base, scale factor, scaling
    type and position embedding type.
    """
    num_layers = 2
    position_embedding_type = 'rope_gpt_neox'
    rotary_embedding_percentage = 0.3
    rotary_base = 99999.1
    rotary_scaling = {"type": "linear", "factor": 2.72}

    config = tensorrt_llm.models.PretrainedConfig.from_dict({
        'architecture': 'GPTForCausalLM',
        'dtype': 'float16',
        'num_hidden_layers': num_layers,
        'num_attention_heads': 4,
        'hidden_size': 128,
        'vocab_size': 256,
        'max_position_embeddings': 1024,
        'hidden_act': 'gelu',
        'position_embedding_type': position_embedding_type,
        'rotary_pct': rotary_embedding_percentage,
        'rotary_base': rotary_base,
        'rotary_scaling': rotary_scaling,
    })
    tensorrt_llm_gpt = tensorrt_llm.models.GPTForCausalLM(config)

    expected_scale = rotary_scaling["factor"]
    for layer_i in range(num_layers):
        attention = tensorrt_llm_gpt.transformer.layers[layer_i].attention
        assert attention.rotary_embedding_base == rotary_base
        assert attention.rotary_embedding_scale == expected_scale
        assert attention.rotary_embedding_scale_type == RotaryScalingType.linear
        assert attention.position_embedding_type == PositionEmbeddingType.rope_gpt_neox
@parameterized.expand(["other"], name_func=unittest_name_func)
def test_gpt_variant_is_overridden(self, test_partition):
    """Verify the CLI ``gpt_variant`` is forwarded to ``from_hugging_face``.

    ``GPTConfig.from_hugging_face`` is patched with a wrapper that asserts
    on the ``gpt_variant`` keyword and then delegates to the real
    implementation. The test is skipped when ``llm_models_root()`` returns
    ``None``.
    """
    model_root = llm_models_root()
    if model_root is None:
        pytest.skip("Skipping since real weights are unavailable.")

    # Keep a handle on the real implementation before it gets patched.
    from_hugging_face = tensorrt_llm.models.GPTConfig.from_hugging_face

    with tempfile.TemporaryDirectory() as tempdir:
        cli_args = Namespace(tp_size=1,
                             pp_size=1,
                             model_dir=f"{model_root}/starcoder2-3b",
                             output_dir=tempdir,
                             gpt_variant="starcoder2",
                             dtype="float16",
                             load_model_on_cpu=False,
                             use_parallel_embedding=False,
                             embedding_sharding_dim=0,
                             use_weight_only=False,
                             int8_kv_cache=False,
                             smoothquant=None,
                             workers=1)

        def check_gpt_variant(*args, **kwargs):
            received_variant = kwargs.get("gpt_variant", "")
            self.assertEqual(received_variant, cli_args.gpt_variant)
            return from_hugging_face(*args, **kwargs)

        with patch('tensorrt_llm.models.GPTConfig.from_hugging_face',
                   side_effect=check_gpt_variant):
            convert_and_save_hf(cli_args)
# Run the unittest test runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()