TensorRT-LLM/cpp/tensorrt_llm/kernels/banRepeatNgram.cu

/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/banRepeatNgram.h"
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
template <typename T>
__global__ void ban_repeat_ngram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
int const** parent_ids_buf, int const* batch_slots, int batch_size, int beam_width, int max_seq_len,
int const* no_repeat_ngram_size_buf, int vocab_size_padded, int const* sequence_lengths)
{
/**
* Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching
* subsequences are banned in the next-token logits, so that n-grams already present in the generated
* sequence are not generated again.
*
* Note 1: the no-repeat restriction is per-beam, not across beams.
* Note 2: since beam search results are stored and retrieved by backtracking (parent_ids), the entire
* sequence for the current beam can only be obtained by backtracking all the way to the first token.
* Note 3: for greedy search, a more efficient implementation is possible: without the beam-backtracking
* constraint, all threads could load into shared memory in parallel and traverse in normal order. For
* simplicity and consistency, the greedy and beam search implementations are kept the same.
*
* Adjacent threads read overlapping token windows, so fetching each window from global memory would waste
* bandwidth; shared memory is used instead, and its benefit grows with ngram_size. Shared memory holds
* in-bound positions only; tokens out of bounds on the left of the block are read from global memory.
*/
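// Worked example (illustrative, not from the original comment): with no_repeat_ngram_size = 3 and a beam
// that has generated [5, 8, 5, 8] (step = 4), the last two tokens are (5, 8). The window starting at
// position 0 also reads (5, 8), and the token that followed it was 5, so logit 5 is set to -INF and the
// 3-gram (5, 8, 5) cannot be generated again.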
int const output_idx = blockIdx.x * blockDim.x + threadIdx.x;
int const local_batch_idx = blockIdx.y / beam_width;
auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx;
int const beam_idx = blockIdx.y % beam_width;
bool const beam_search = beam_width > 1;
int const no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
int const step = sequence_lengths[batch_slot];
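// step: number of tokens generated so far for this request; the token at position step is about to be sampled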
// case 1: ngram_size == 0 --> this means no ngram limit
// case 2: generated length must be greater than ngram_size to do ngram check
if (no_repeat_ngram_size == 0 || step < no_repeat_ngram_size)
{
return;
}
// if the beam has already finished, skip ngram check
if ((finished_buf != nullptr) && (finished_buf[batch_slot * beam_width + beam_idx].isFinished()))
{
return;
}
// shared mem layout: per-thread token that each thread is mapped to, plus (ngram_size - 1) extra tokens beyond
// block boundary, plus (ngram_size - 1) most recent generated tokens in the beam
extern __shared__ int shared_tokens[];
int shared_tokens_length = blockDim.x + no_repeat_ngram_size - 1;
int* last_tokens = &shared_tokens[shared_tokens_length];
int last_tokens_length = no_repeat_ngram_size - 1;
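// layout sketch (B = blockDim.x, N = no_repeat_ngram_size):
//   shared_tokens[0 .. B + N - 2] : tokens at positions covered by this block, plus N - 1 tokens past the
//                                   block's right boundary
//   last_tokens[0 .. N - 2]       : the N - 1 most recently generated tokens of this beam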
// retrieve the entire beam by backtracking from the last token down to this block's first position (in
// reverse order). A single thread is as good as parallel threads here: the work is bound by the longest
// backtracking chain.
if (threadIdx.x == 0)
{
int parent_id = beam_idx;
int start_record_idx = min(output_idx + shared_tokens_length, (int) step);
int shared_token_idx = start_record_idx == step ? step - output_idx - 1 : shared_tokens_length - 1;
int last_token_idx = last_tokens_length - 1;
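// shared_tokens covers positions [output_idx, output_idx + shared_tokens_length); start_record_idx is the
// exclusive end of that range, clipped to step when the block extends past the generated length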
// write to shared mem in reverse order; handle the boundary case where the thread block covers more
// positions than have been generated
for (int curr_idx = step - 1; curr_idx >= output_idx; curr_idx--)
{
if (last_token_idx >= 0)
{
last_tokens[last_token_idx--] = output_ids_buf[batch_slot][parent_id * max_seq_len + curr_idx];
}
// before reaching the positions covered by this block, only traverse the parent chain; once inside the
// block's range, also record the tokens
if (curr_idx < start_record_idx)
{
shared_tokens[shared_token_idx--] = output_ids_buf[batch_slot][parent_id * max_seq_len + curr_idx];
}
if (beam_search)
{
parent_id = parent_ids_buf[batch_slot][parent_id * max_seq_len + curr_idx];
}
}
}
__syncthreads();
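// only threads whose (ngram_size - 1)-token window plus the following token lie fully inside the generated
// sequence have a complete n-gram to check; all later threads exit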
if (output_idx > step - no_repeat_ngram_size)
{
return;
}
bool ban_ngram = true;
// ngram check (in regular order)
for (int ngram_idx = 0; ngram_idx < no_repeat_ngram_size - 1; ngram_idx++)
{
if (shared_tokens[threadIdx.x + ngram_idx] != last_tokens[ngram_idx])
{
ban_ngram = false;
break;
}
}
// set the banned next token's logit to -INF
if (ban_ngram)
{
int banned_token = shared_tokens[threadIdx.x + no_repeat_ngram_size - 1]; // ban the last token in the ngram
logits[local_batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token]
= static_cast<T>(-INFINITY); // note: "logits" passed in is already with batchxbeam offset
}
}
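// Equivalent CPU reference of the per-beam check (a sketch useful for unit testing, not part of the kernel):
// for start in [0, step - ngram_size]:
//     if tokens[start .. start + ngram_size - 2] == tokens[step - ngram_size + 1 .. step - 1]:
//         ban tokens[start + ngram_size - 1] in the next-token logits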
template <typename T>
void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
{
// each request in the local batch can have a different no_repeat_ngram_size, so a maximum is needed for the
// shared-memory allocation. Computing the actual maximum of the current batch would be ideal, but the ngram
// sizes live on the GPU while the kernel launch needs the value on the CPU. Instead of finding the real
// maximum and paying an extra GPU-to-CPU memcpy, simply use a constant; in practice the ngram size is
// usually very small, e.g. 3 or 4.
int max_no_repeat_ngram_size = 32;
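// note: a request with no_repeat_ngram_size > 32 would need more shared memory than is allocated below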
// step (current generated length, excluding the start token) ranges from 1 to max_seq_len
dim3 block, grid;
constexpr size_t max_block_size{256};
block.x = min(((max_step + 32 - 1) / 32) * 32, max_block_size); // round up to a warp multiple, cap at 256 threads
grid.x = (max_step + block.x - 1) / block.x;
grid.y = batch_size * beam_width;
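// e.g. max_step = 1000: block.x = min(1024, 256) = 256 threads, grid.x = (1000 + 255) / 256 = 4 blocks,
// and grid.y gives one row of blocks per (batch, beam) pair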
// dynamically allocate shared memory of int[blockDim.x + 2 * (ngram_size - 1)]: one (ngram_size - 1) chunk
// for boundary tokens beyond the block and one for the most recent tokens of the beam
ban_repeat_ngram<<<grid, block, (block.x + 2 * (max_no_repeat_ngram_size - 1)) * sizeof(int), stream>>>(logits,
output_ids_buf, finished_buf, parent_ids_buf, batch_slot, batch_size, beam_width, max_seq_len,
no_repeat_ngram_size_buf, vocab_size_padded, sequence_lengths);
sync_check_cuda_error();
}
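// Example launch (an illustrative sketch; the buffer names are assumptions, not defined in this file).
// All pointer arguments are device buffers prepared by the sampling layer:
//   invokeBanRepeatNgram(logits, output_ids, finished_states, parent_ids, batch_slots, seq_lengths,
//       batch_size, beam_width, max_seq_len, no_repeat_ngram_sizes, vocab_size_padded, max_step, stream);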
#define INVOKE_BAN_REPEAT_NGRAM(T) \
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, \
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, \
int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, \
cudaStream_t stream);
INVOKE_BAN_REPEAT_NGRAM(float)
INVOKE_BAN_REPEAT_NGRAM(half)
#ifdef ENABLE_BF16
INVOKE_BAN_REPEAT_NGRAM(__nv_bfloat16)
#endif
#undef INVOKE_BAN_REPEAT_NGRAM
} // namespace kernels
} // namespace tensorrt_llm