# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import os import random import sys import tempfile import unittest from itertools import product import numpy as np import pytest import tensorrt as trt import torch from parameterized import parameterized from transformers import GPT2Config, GPT2LMHeadModel import tensorrt_llm from tensorrt_llm import Builder from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.runtime import ModelConfig, SamplingConfig from tensorrt_llm.runtime.generation import _prepare_attention_mask from tensorrt_llm.runtime.kv_cache_manager import (GenerationSequence, KVCacheManager) sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) from examples.gpt.weight import load_from_hf_gpt sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import getSMVersion class TestGPT(unittest.TestCase): def _gen_hf_gpt(self, hidden_act, n_layer, max_length, dtype): gpt_config = GPT2Config( activation_function=hidden_act, n_layer=n_layer, max_length=max_length, torch_dtype=dtype, ) hf_gpt = GPT2LMHeadModel(gpt_config).cuda().eval() return gpt_config, hf_gpt def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, batch_size, input_len, output_len, fp16, gpt_attention_plugin, tensor_parallel, apply_query_key_layer_scaling, gather_all_token_logits): num_layers = gpt_config.n_layer num_heads = gpt_config.n_head hidden_size = gpt_config.n_embd vocab_size = gpt_config.vocab_size hidden_act = gpt_config.activation_function n_positions = gpt_config.n_positions tensor_parallel_group = list(range(tensor_parallel)) with net_guard(network): kv_dtype = trt.float16 if fp16 else trt.float32 # Initialize model tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel( num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act, max_position_embeddings=n_positions, dtype=kv_dtype, mapping=tensorrt_llm.Mapping(world_size=tensor_parallel, tp_size=tensor_parallel), apply_query_key_layer_scaling=apply_query_key_layer_scaling) inputs = tensorrt_llm_gpt.prepare_inputs( batch_size, input_len, output_len, use_cache=True, max_beam_width=1, gather_all_token_logits=gather_all_token_logits) load_from_hf_gpt(tensorrt_llm_gpt, hf_gpt, dtype="float16" if fp16 else "float32") # Prepare network.set_named_parameters(tensorrt_llm_gpt.named_parameters()) tensorrt_llm_gpt(*inputs) return network def _gen_tensorrt_llm_runtime(self, log_level, dtype, world_size, rank, gpt_config, hf_gpt, model, use_plugin, batch_size, input_len, output_len, use_refit, fast_building=False, apply_query_key_layer_scaling=False, context_fmha_type=ContextFMHAType.disabled, enable_remove_input_padding=False, enable_paged_kv_cache=False, tokens_per_block=64, gather_all_token_logits=False): mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size) runtime = None builder = Builder() fp16 = (dtype == 'float16') with tempfile.TemporaryDirectory() as tmpdirname: network = builder.create_network() if use_plugin: network.plugin_config.set_gpt_attention_plugin(dtype) if fast_building: network.plugin_config.set_gemm_plugin(dtype) network.plugin_config.set_layernorm_plugin(dtype) network.plugin_config.set_context_fmha(context_fmha_type) if enable_remove_input_padding: network.plugin_config.enable_remove_input_padding() if enable_paged_kv_cache: network.plugin_config.enable_paged_kv_cache(tokens_per_block) self._gen_tensorrt_llm_network(network, builder, hf_gpt, gpt_config, batch_size, input_len, output_len, fp16, use_plugin, world_size, apply_query_key_layer_scaling, gather_all_token_logits) builder_config = builder.create_builder_config( name='gpt', precision=dtype, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, gather_all_token_logits=gather_all_token_logits, ) engine_buffer = builder.build_engine(network, builder_config) runtime = tensorrt_llm.runtime.generation._Runtime( engine_buffer, mapping) return runtime, engine_buffer @parameterized.expand([(False)]) def test_gpt_float32(self, use_refit): model = 'gpt' log_level = 'error' dtype = 'float32' world_size = 1 rank = 0 hidden_act = 'gelu' n_layer = 2 max_length = 2 batch_size = 4 beam_width = 1 seq_len = 128 total_length = seq_len + max_length use_plugin = False gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer, max_length, dtype) runtime, _ = self._gen_tensorrt_llm_runtime( log_level, dtype, world_size, rank, gpt_config, hf_gpt, model, use_plugin, batch_size, seq_len, max_length, use_refit) # compare context pad_token_id = 50256 ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() ctx_ids[0][-1] = pad_token_id ctx_ids[1][-3:] = pad_token_id ctx_ids[2][-5:] = pad_token_id ctx_context_lengths = seq_len * torch.ones( (batch_size), dtype=torch.int32, device='cuda') ctx_host_context_lengths = ctx_context_lengths.cpu() ctx_host_request_types = torch.tensor([0] * batch_size, dtype=torch.int32, device='cpu') ctx_position_ids = torch.tensor(range(seq_len), dtype=torch.int32).reshape([ 1, seq_len ]).expand([batch_size, seq_len]).cuda() ctx_last_token_ids = ctx_context_lengths.clone() ctx_attention_mask = _prepare_attention_mask(ctx_ids) cache_indirections = [ torch.full(( batch_size, beam_width, total_length, ), 0, dtype=torch.int32, device='cuda'), torch.full(( batch_size, beam_width, total_length, ), 0, dtype=torch.int32, device='cuda') ] # ping-pong buffers ctx_shape = { 'input_ids': ctx_ids.shape, 'position_ids': ctx_position_ids.shape, 'context_lengths': ctx_context_lengths.shape, 'host_context_lengths': ctx_host_context_lengths.shape, 'last_token_ids': ctx_last_token_ids.shape, 'attention_mask': ctx_attention_mask.shape, 'host_request_types': ctx_host_request_types.shape, 'cache_indirection': cache_indirections[0].shape, } ctx_buffer = { 'input_ids': ctx_ids, 'position_ids': ctx_position_ids, 'context_lengths': ctx_context_lengths, 'host_context_lengths': ctx_host_context_lengths, 'last_token_ids': ctx_last_token_ids, 'attention_mask': ctx_attention_mask, 'host_request_types': ctx_host_request_types, 'cache_indirection': cache_indirections[0], } for i in range(gpt_config.n_layer): shape = (batch_size, 2, gpt_config.n_head, 0, gpt_config.n_embd // gpt_config.n_head) past_buffer = torch.zeros((1, ), dtype=str_dtype_to_torch(dtype), device='cuda') ctx_shape.update({ f'past_key_value_{i}': shape, }) shape = (batch_size, 2, gpt_config.n_head, seq_len, gpt_config.n_embd // gpt_config.n_head) ctx_buffer.update({ f'past_key_value_{i}': past_buffer, f'present_key_value_{i}': torch.zeros(shape, dtype=str_dtype_to_torch(dtype), device='cuda'), }) context = runtime.ctx_context runtime._set_shape(context, ctx_shape) runtime._set_buffer(context, ctx_buffer) runtime._run(context) torch.cuda.synchronize() res = ctx_buffer['logits'] with torch.no_grad(): hf_outputs = hf_gpt.forward(ctx_ids, attention_mask=ctx_attention_mask) torch.cuda.synchronize() ref = hf_outputs.logits[:, -1, :] np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), atol=1e-2) for i in range(gpt_config.n_layer): res_present_key_value = ctx_buffer[f'present_key_value_{i}'] ref_present_key, ref_present_value = hf_outputs.past_key_values[i] past_key_value_tensor = res_present_key_value.permute(1, 0, 2, 3, 4) key, value = past_key_value_tensor.chunk(2) head_size = gpt_config.n_embd // gpt_config.n_head key = key.to(torch.float32).reshape(batch_size, gpt_config.n_head, seq_len, head_size) value = value.reshape(batch_size, gpt_config.n_head, seq_len, head_size) np.testing.assert_allclose(ref_present_key.cpu().numpy(), key.cpu().numpy(), atol=1e-2) np.testing.assert_allclose(ref_present_value.cpu().numpy(), value.cpu().numpy(), atol=1e-2) # compare generation gen_id = torch.randint(100, (batch_size, 1)).int().cuda() gen_context_lengths = ctx_context_lengths.clone() gen_host_context_lengths = ctx_host_context_lengths.clone() gen_host_request_types = torch.tensor([1] * batch_size, dtype=torch.int32, device='cpu') gen_position_ids = torch.ones_like(gen_id).cuda() * seq_len gen_last_token_ids = torch.zeros_like(gen_context_lengths).cuda() gen_attention_mask = torch.cat([ ctx_attention_mask, ctx_attention_mask.new_ones((ctx_attention_mask.shape[0], 1)) ], dim=-1) step1_shape = { 'input_ids': gen_id.shape, 'context_lengths': gen_context_lengths.shape, 'host_context_lengths': gen_host_context_lengths.shape, 'host_request_types': gen_host_request_types.shape, 'position_ids': gen_position_ids.shape, 'last_token_ids': gen_last_token_ids.shape, 'attention_mask': gen_attention_mask.shape, 'cache_indirection': cache_indirections[1].shape, } step1_buffer = { 'input_ids': gen_id, 'context_lengths': gen_context_lengths.contiguous(), 'host_context_lengths': gen_host_context_lengths.contiguous(), 'host_request_types': gen_host_request_types.contiguous(), 'position_ids': gen_position_ids.contiguous(), 'last_token_ids': gen_last_token_ids.contiguous(), 'attention_mask': gen_attention_mask.contiguous(), 'cache_indirection': cache_indirections[1].contiguous(), } for i in range(gpt_config.n_layer): shape = (batch_size, 2, gpt_config.n_head, seq_len, gpt_config.n_embd // gpt_config.n_head) step1_shape.update({ f'past_key_value_{i}': shape, }) step1_buffer.update({ f'past_key_value_{i}': ctx_buffer[f'present_key_value_{i}'], }) context = runtime.context_1 runtime._set_shape(context, step1_shape) runtime._set_buffer(context, step1_buffer) runtime._run(context) torch.cuda.synchronize() res = step1_buffer['logits'] with torch.no_grad(): hf_outputs = hf_gpt.forward( gen_id, attention_mask=gen_attention_mask, past_key_values=hf_outputs.past_key_values, use_cache=True) torch.cuda.synchronize() ref = hf_outputs.logits[:, -1, :] np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), atol=1e-2) for i in range(gpt_config.n_layer): res_present_key_value = step1_buffer[f'present_key_value_{i}'] ref_present_key, ref_present_value = hf_outputs.past_key_values[i] past_key_value_tensor = res_present_key_value.permute(1, 0, 2, 3, 4) key, value = past_key_value_tensor.chunk(2) head_size = gpt_config.n_embd // gpt_config.n_head key = key.reshape(batch_size, gpt_config.n_head, seq_len + 1, head_size) value = value.reshape(batch_size, gpt_config.n_head, seq_len + 1, head_size) np.testing.assert_allclose(ref_present_key.cpu().numpy(), key.cpu().numpy(), atol=1e-2) np.testing.assert_allclose(ref_present_value.cpu().numpy(), value.cpu().numpy(), atol=1e-2) def load_test_cases(): test_cases = list( product([False, True], [False, True], [False, True], [ ContextFMHAType.disabled, ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc ], [False, True], [False, True], [False, True])) return test_cases @parameterized.expand(load_test_cases) def test_gpt_plugin(self, use_refit, fast_building, apply_query_key_layer_scaling, context_fmha_type, enable_remove_input_padding, enable_paged_kv_cache, gather_all_token_logits): # inflight batching mode only works with remove_input_padding and paged_kv_cache use_in_flight_batching = enable_remove_input_padding and enable_paged_kv_cache and not gather_all_token_logits # Skip tests that are not supported in pre-ampere architecture if getSMVersion() < 80: if context_fmha_type == ContextFMHAType.enabled: pytest.skip( "ContextFMHAType is not supported in pre-ampere architecture" ) elif context_fmha_type == ContextFMHAType.enabled_with_fp32_acc: pytest.skip( "ContextFMHAType with fp32 acc is not supported in pre-ampere architecture" ) torch.manual_seed(0) random.seed(0) model = 'gpt' log_level = 'error' dtype = 'float16' world_size = 1 rank = 0 hidden_act = 'gelu' n_layer = 1 max_length = 2 batch_size = 4 beam_width = 1 seq_len = 128 total_length = seq_len + max_length use_plugin = True tokens_per_block = 64 gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer, seq_len + max_length, dtype) runtime, _ = self._gen_tensorrt_llm_runtime( log_level, dtype, world_size, rank, gpt_config, hf_gpt, model, use_plugin, batch_size, seq_len, max_length, use_refit, fast_building, apply_query_key_layer_scaling, context_fmha_type, enable_remove_input_padding, enable_paged_kv_cache, tokens_per_block, gather_all_token_logits) key_value_cache_buffers = [] value_cache_buffers = [] head_size = gpt_config.n_embd // gpt_config.n_head for i in range(gpt_config.n_layer): if enable_paged_kv_cache: blocks = batch_size * beam_width * math.ceil( total_length / tokens_per_block) cache_shape = ( blocks, 2, gpt_config.n_head, tokens_per_block, head_size, ) else: cache_shape = ( batch_size, 2, gpt_config.n_head, total_length, head_size, ) key_value_cache_buffers.append( torch.zeros(cache_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), device='cuda')) value_cache_buffers.append( torch.zeros(( batch_size, gpt_config.n_head, total_length, head_size, ), dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), device='cuda')) cache_indirections = [ torch.full(( batch_size, beam_width, total_length, ), 0, dtype=torch.int32, device='cuda'), torch.full(( batch_size, beam_width, total_length, ), 0, dtype=torch.int32, device='cuda') ] # ping-pong buffers if enable_paged_kv_cache: max_blocks_per_seq = math.ceil(total_length / tokens_per_block) blocks = batch_size * beam_width * max_blocks_per_seq kv_cache_manager = KVCacheManager(key_value_cache_buffers, blocks, tokens_per_block, max_blocks_per_seq, beam_width) # Add sequences to the manager for bi in range(batch_size): generation_sequence = GenerationSequence(seq_idx=bi, batch_idx=bi) kv_cache_manager.add_sequence(generation_sequence, seq_len) def run_engine(context, input_ids, context_lengths, host_request_types, position_ids, last_token_ids, cache_indirection, host_past_key_value_lengths, sequence_length=None, host_context_lengths=None): ctx_buffer = { 'input_ids': input_ids, 'context_lengths': context_lengths, 'host_request_types': host_request_types, 'position_ids': position_ids, 'last_token_ids': last_token_ids, 'cache_indirection': cache_indirection, 'host_past_key_value_lengths': host_past_key_value_lengths, 'sequence_length': sequence_length, } assert host_request_types is not None if enable_remove_input_padding: assert host_context_lengths is not None, "host_context_lengths is required for ragged input" ctx_buffer['host_context_lengths'] = host_context_lengths if enable_paged_kv_cache: assert beam_width == 1 # for beam_width > 1 the argument must be '1' in ctx phase and 'beam_width' in gen phase kv_cache_block_pointers = kv_cache_manager.get_pointer_arrays(1) for idx in range(gpt_config.n_layer): shape = kv_cache_block_pointers[idx].shape shape = [shape[0] * shape[1], *shape[2:]] ctx_buffer[ f'kv_cache_block_pointers_{idx}'] = kv_cache_block_pointers[ idx].reshape(shape).contiguous() else: for i in range(gpt_config.n_layer): ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[ i] ctx_buffer[ f'present_key_value_{i}'] = key_value_cache_buffers[i] ctx_shape = { key: buffer.shape for key, buffer in ctx_buffer.items() } runtime._set_shape(context, ctx_shape) runtime._set_buffer(context, ctx_buffer) runtime._run(context) torch.cuda.synchronize() res = ctx_buffer['logits'] return res hf_outputs = None step0_ids = None step1_ids = None def compare_context(run_ref_only=False): nonlocal step0_ids step0_ids = torch.randint( 100, (batch_size, seq_len)).int().cuda() if step0_ids is None else step0_ids ctx_ids = step0_ids.clone() ctx_context_lengths = seq_len * torch.ones( (batch_size), dtype=torch.int32, device='cuda') ctx_position_ids = torch.tensor(range(seq_len), dtype=torch.int32).reshape([ 1, seq_len ]).expand([batch_size, seq_len]).cuda() ctx_last_token_ids = ctx_context_lengths.clone() nonlocal hf_outputs with torch.no_grad(): hf_outputs = hf_gpt.forward(ctx_ids) torch.cuda.synchronize() ref = hf_outputs.logits if run_ref_only: return ref[:, -1, :] if enable_remove_input_padding: ctx_ids = ctx_ids.view([1, batch_size * seq_len]) ctx_position_ids = ctx_position_ids.view( [1, batch_size * seq_len]) ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int() host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) host_context_lengths = ctx_context_lengths.cpu( ) if enable_remove_input_padding else None host_request_types = torch.tensor([0 for i in range(batch_size)], dtype=torch.int32).cpu() # We need sequence_lengths start as context_lengths for step 0 (context), # and it will be added one after each step. sequence_length = ctx_context_lengths.detach().clone() res = run_engine( context=runtime.ctx_context, input_ids=ctx_ids, context_lengths=ctx_context_lengths, position_ids=ctx_position_ids, last_token_ids=ctx_last_token_ids, cache_indirection=cache_indirections[0], host_past_key_value_lengths=host_past_key_value_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types) if gather_all_token_logits: np.testing.assert_allclose(ref.cpu().numpy().flatten(), res.cpu().numpy().flatten(), atol=1e-1) else: np.testing.assert_allclose(ref[:, -1, :].cpu().numpy(), res.cpu().numpy(), atol=1e-1) def compare_generation(run_ref_only=False): step = 1 nonlocal step1_ids step1_ids = torch.randint( 100, (batch_size, 1)).int().cuda() if step1_ids is None else step1_ids gen_ids = step1_ids.clone() gen_context_lengths = seq_len * torch.ones( (batch_size), dtype=torch.int32, device='cuda') gen_position_ids = torch.ones_like(gen_ids).int().cuda() * seq_len gen_last_token_ids = torch.zeros_like( gen_context_lengths).int().cuda() nonlocal hf_outputs with torch.no_grad(): hf_outputs = hf_gpt.forward( gen_ids, past_key_values=hf_outputs.past_key_values, use_cache=True) torch.cuda.synchronize() ref = hf_outputs.logits[:, -1, :] if run_ref_only: return ref if enable_remove_input_padding: gen_ids = gen_ids.view([1, batch_size]) gen_position_ids = gen_position_ids.view([1, batch_size]) gen_last_token_ids = torch.ones_like( gen_context_lengths).int().cuda() gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int() host_past_key_value_lengths = torch.tensor([seq_len + step - 1] * batch_size, dtype=torch.int32) host_context_lengths = gen_context_lengths.cpu( ) if enable_remove_input_padding else None host_request_types = torch.tensor([1 for i in range(batch_size)], dtype=torch.int32).cpu() # For step 1, the sequence_lengths = context_lengths + 1. sequence_length = torch.add(gen_context_lengths.detach().clone(), 1) res = run_engine( context=runtime.context_1, input_ids=gen_ids, context_lengths=gen_context_lengths, position_ids=gen_position_ids, last_token_ids=gen_last_token_ids, cache_indirection=cache_indirections[1], host_past_key_value_lengths=host_past_key_value_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types) np.testing.assert_allclose(ref.cpu().numpy().flatten(), res.cpu().numpy().flatten(), atol=1e-1) def compare_mixing_context_and_generation_phases(): num_context_input = 2 assert batch_size >= num_context_input num_generation_input = batch_size - num_context_input # retrieve the reference output ref_ctx_out = compare_context(True)[:num_context_input, :] ref_gen_out = compare_generation(True)[num_context_input:, :] ref_out = torch.cat([ref_ctx_out, ref_gen_out], dim=0) ref_ctx_out = None ref_gen_out = None # compare_context() # prepare the inputs for plugin-based gpt assert step0_ids is not None and step1_ids is not None input_ids = torch.cat([ step0_ids[:num_context_input, :].view( (-1, )), step1_ids[num_context_input:].view((-1, )) ], dim=0) input_ids = input_ids.view((1, -1)) ctx_position_ids = torch.tensor( range(seq_len), dtype=torch.int32).reshape( (1, seq_len)).expand([num_generation_input, seq_len]).cuda() gen_position_ids = torch.ones_like( step1_ids[num_context_input:].view( (-1, ))).int().cuda() * seq_len position_ids = torch.cat( [ctx_position_ids.view((-1, )), gen_position_ids], dim=0).view( (1, -1)) input_lengths = torch.tensor([seq_len] * num_context_input + [1] * num_generation_input, dtype=torch.int32).cuda() gen_last_token_ids = torch.cumsum(input_lengths, dim=0).int().cuda() # scalar of max_key_value_length for in-flight batching case host_past_key_value_lengths = torch.tensor( [0] * num_context_input + [seq_len] * num_generation_input, dtype=torch.int32) context_lengths = torch.tensor([seq_len] * batch_size, dtype=torch.int32).cuda() if enable_remove_input_padding: host_context_lengths = context_lengths.cpu() host_request_types = torch.tensor([0] * num_context_input + [1] * num_generation_input, dtype=torch.int32).cpu() # The sequence_lengths = context_lengths + step for generation stage. sequence_length = torch.tensor([seq_len] * num_context_input + [seq_len + 1] * num_generation_input, dtype=torch.int32).cuda() res = run_engine( context=runtime.context_1, input_ids=input_ids, context_lengths=context_lengths, position_ids=position_ids, last_token_ids=gen_last_token_ids, cache_indirection=cache_indirections[0], host_past_key_value_lengths=host_past_key_value_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types, ) np.testing.assert_allclose(ref_out.cpu().numpy(), res.cpu().numpy(), atol=1e-1) # Main logics compare_context() compare_generation() # Only inflight batching mode could accept the mixture of requests from both context and generation phases if use_in_flight_batching: compare_mixing_context_and_generation_phases() @parameterized.expand([(False, False), (False, True)]) def test_greedy_search_float32(self, use_refit, streaming): model = 'gpt' log_level = 'error' dtype = 'float32' world_size = 1 rank = 0 hidden_act = 'gelu' n_layer = 2 max_new_tokens = 1 batch_size = 4 seq_len = 128 use_plugin = False do_sample = False early_stoppping = False num_beams = 1 num_beam_groups = 1 temperature = 1 top_k = 0 top_p = 0.0 length_penalty = 1 repetition_penalty = 1 gpt_config, hf_gpt = self._gen_hf_gpt(hidden_act, n_layer, max_new_tokens, dtype) runtime, engine_buffer = self._gen_tensorrt_llm_runtime( log_level, dtype, world_size, rank, gpt_config, hf_gpt, model, use_plugin, batch_size, seq_len, max_new_tokens, use_refit) model_config = ModelConfig(vocab_size=gpt_config.vocab_size, num_layers=gpt_config.n_layer, num_heads=gpt_config.n_head, num_kv_heads=gpt_config.n_head, hidden_size=gpt_config.n_embd, gpt_attention_plugin=False, dtype=dtype) mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size) decoder = tensorrt_llm.runtime.GenerationSession( model_config, engine_buffer, mapping) pad_token_id = 50256 eos_token_id = 50257 sampling_config = SamplingConfig(end_id=eos_token_id, pad_id=pad_token_id, num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, length_penalty=length_penalty, repetition_penalty=repetition_penalty) input_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() input_ids[0][-1] = pad_token_id input_ids[1][-3:] = pad_token_id input_ids[2][-5:] = pad_token_id input_lengths = torch.ones( (batch_size)).type(torch.int32).cuda() * seq_len decoder.setup(batch_size, max_context_length=seq_len, max_new_tokens=max_new_tokens, beam_width=num_beams) if streaming: output_ids_gen = decoder.decode(input_ids, input_lengths, sampling_config, streaming=True) for output_ids in output_ids_gen: pass else: output_ids = decoder.decode(input_ids, input_lengths, sampling_config) #TODO: change to actual ragged tensor after GPT plugin supports it output_ids_x = decoder.decode(input_ids, input_lengths, sampling_config) # works because all requests in the batch has same # TODO: enable this when GPT Plugin attention works # output_ids_y = decoder.decode_batch([t[:input_lengths[i]] for i, t in enumerate(torch.split(input_ids, 1, dim=0))], sampling_config) torch.cuda.synchronize() torch.testing.assert_close(output_ids, output_ids_x) res = output_ids.squeeze() res = res[:, -max_new_tokens:] ref_output_ids = hf_gpt.generate(input_ids, do_sample=do_sample, early_stopping=early_stoppping, num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, num_beam_groups=num_beam_groups, max_new_tokens=max_new_tokens, length_penalty=length_penalty, repetition_penalty=repetition_penalty, pad_token_id=pad_token_id, eos_token_id=eos_token_id) torch.cuda.synchronize() ref = ref_output_ids[:, -max_new_tokens:] np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy()) if __name__ == '__main__': unittest.main()