/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
|
#include "tensorrt_llm/common/stringUtils.h"
|
|
#include "tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels.h"
|
|
|
|
using namespace tensorrt_llm::common;
|
|
|
|
namespace tensorrt_llm
|
|
{
|
|
namespace kernels
|
|
{
|
|
|
|
#define DO_SPLIT_SMALL_TOP_K_SOFTMAX
|
|
static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
|
|
|
|
#define TOPK_FP16_STORAGE 0
|
|
|
|
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
    // score = log(prob) / (length)^length_penalty
    if (length_penalty == 0.0f || length == 1)
    {
        return log_prob;
    }
    return log_prob / static_cast<T>(powf(length, length_penalty));
}
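
// Example: log_prob = -6.0f at length = 4 with length_penalty = 1.0f scores
// -6.0f / 4 = -1.5f, i.e. the average log-prob per token; since log-probs are
// negative, any length_penalty > 0 raises the relative score of longer hypotheses.
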
template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{
    int thread_id = threadIdx.x;
    int block_id = blockIdx.x;
    TopK<T, MAX_K> partial;
    if (thread_id == 0)
    {
        for (int i = 0; i < MAX_K; ++i)
        {
            partial.p[i] = -1;
            partial.u[i] = -FLT_MAX;
        }

        int index = block_id * MAX_K * MAX_K;
        for (int i = 0; i < MAX_K * MAX_K; i++)
        {
            partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
        }

        index = block_id * MAX_K;
        for (int i = 0; i < MAX_K; i++)
        {
            id_buf[index + i] = partial.p[i];
        }
    }
}

template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf,
    const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf)
{
    int thread_id = threadIdx.x;
    int block_id = blockIdx.x;
    TopK<T, MAX_K> partial;
    if (thread_id == 0)
    {
        for (int i = 0; i < MAX_K; ++i)
        {
            partial.p[i] = -1;
            partial.u[i] = -FLT_MAX;
        }

        int index = block_id * MAX_K * MAX_K;
        for (int i = 0; i < MAX_K * MAX_K; i++)
        {
            partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
        }

        index = block_id * MAX_K;
        for (int i = 0; i < MAX_K; i++)
        {
            id_buf[index + i] = partial.p[i];
            val_buf[index + i] = partial.u[i];
        }
    }
}
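
// Final beam-search selection: one block per batch entry reduces the V = 2 * K * K
// candidates gathered across the K beams down to the K live beams for the next step,
// moves candidates that end with end_id into beam_hyps, and maintains
// num_beams / min_normed_scores so finished hypotheses can trigger early stopping.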
template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void batch_topk_kernel(const int* __restrict x, const T* __restrict y, int** output_ids_ptr, float* __restrict v,
        float* output_log_probs, const bool* finished, const int* sequence_lengths, BeamHypotheses beam_hyps,
        const int V, const int K, const int vocab_size, const float* length_penalties, const float* diversity_rates)
{
    int thread_id = threadIdx.x;
    int vector_id = blockIdx.x;
    const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id};
    const T diversity_rate{diversity_rates[global_batch_idx]};
    const float length_penalty{length_penalties[global_batch_idx]};

    // reposition x, y to data for the current vector
    x += vector_id * V;
    y += vector_id * V;

    typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;

    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ int selected_beams;
    __shared__ float old_cum_log_probs[MAX_K];

    if (thread_id == 0)
    {
        selected_beams = 0;
    }
    if (thread_id < K)
    {
        old_cum_log_probs[thread_id] = v[vector_id * K + thread_id];
    }
    __syncthreads();
    if (beam_hyps.num_beams != nullptr)
    {
        if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0)
        {
            beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
        }
        else if (beam_hyps.num_beams[global_batch_idx] == K)
        {
            return;
        }
    }
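
    // Candidate ids in x are absolute over the batch: id = (vector_id * K + beam) * vocab_size + token,
    // so x[...] % vocab_size recovers the token and (x[...] / vocab_size) % K the source beam.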
    TopK<T, MAX_K> partial;
    for (int i = 0; i < MAX_K; ++i)
    {
        partial.p[i] = -1;
        partial.u[i] = -FLT_MAX;
    }

    for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
    {
        int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K;
        T elem = length_penalty == 0.0f ? y[elem_id]
                                        : apply_length_penalty(y[elem_id],
                                            finished[vector_id * K + i] ? sequence_lengths[vector_id * K + i]
                                                                        : sequence_lengths[vector_id * K + i] + 1,
                                            length_penalty);
        elem += diversity_rate * (T) i;
        int elem_idx = elem_id; // x[elem_id];
        partial.insert(elem, elem_idx);
    }

    TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);

    if (thread_id == 0)
    {
        v += vector_id * K;

        for (int i = 0; i < MAX_K; ++i)
        {
            if (beam_hyps.num_beams != nullptr && x[total.p[i]] % vocab_size == beam_hyps.end_ids[vector_id])
            {
                // If beam_token does not belong to the top num_beams tokens, it should
                // not be added. Refer to
                // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
                if (i >= K)
                {
                    // do nothing
                }
                else
                {
                    const float normed_score = (float) total.u[i];
                    const int num_beam = beam_hyps.num_beams[global_batch_idx];
                    int beam_idx = num_beam;
                    // If there are already beam_width finished sentences, check whether
                    // the score of the selected candidate is higher than
                    // min_normed_score. If so, replace the worst one and update
                    // min_normed_score.
                    if (num_beam == K)
                    {
                        if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
                        {
                            // stop tracking and exit this for loop
                            selected_beams = K;
                            break;
                        }
                        else
                        {
                            // find the beam whose score equals min_normed_score and erase it
                            for (int j = 0; j < K; j++)
                            {
                                if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
                                    == beam_hyps.min_normed_scores[global_batch_idx])
                                {
                                    beam_idx = j;
                                    beam_hyps.num_beams[global_batch_idx]--;

                                    beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
                                    beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
                                    for (int l = 0; l < K; l++)
                                    {
                                        beam_hyps.min_normed_scores[global_batch_idx]
                                            = min(beam_hyps.min_normed_scores[global_batch_idx],
                                                beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
                                    }
                                    break;
                                }
                            }
                        }
                    }
                    const int tgt_id_offset
                        = ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
                        * (beam_hyps.max_seq_len);

                    int prev_id = (x[total.p[i]] / vocab_size) % K;
                    const int current_step{sequence_lengths[vector_id * K + prev_id]};
                    beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];

                    if (beam_hyps.log_probs != nullptr)
                    {
                        beam_hyps.log_probs[tgt_id_offset + current_step]
                            = (float) y[total.p[i]] - old_cum_log_probs[(x[total.p[i]] / vocab_size) % K];
                    }

                    for (int j = current_step - 1; j >= 0; j--)
                    {
                        const int src_idx = j * beam_hyps.batch_size * K
                            + beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id;

                        beam_hyps.output_ids_tgt[tgt_id_offset + j]
                            = beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
                        if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
                        {
                            beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
                        }
                        prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
                    }
                    const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
                    beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
                    beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
                    beam_hyps.min_normed_scores[global_batch_idx]
                        = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);

                    beam_hyps.num_beams[global_batch_idx]++;
                    beam_hyps.cum_log_probs[tgt_beam_idx] = (float) y[total.p[i]];
                }
            }
            else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K))
            {
                const int current_step{sequence_lengths[vector_id * K + selected_beams]};
                output_ids_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] = x[total.p[i]];
                if (output_log_probs != nullptr)
                {
                    output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams]
                        = (float) y[total.p[i]] - old_cum_log_probs[(x[total.p[i]] / vocab_size) % K];
                }
                v[selected_beams] = (float) y[total.p[i]];
                selected_beams++;
            }
            __syncthreads();
            if (selected_beams >= K)
            {
                break;
            }
        }
    }
    if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr)
    {
        if (beam_hyps.num_beams[blockIdx.x] < K)
        {
            beam_hyps.is_done[blockIdx.x] = false;
        }
        else if (beam_hyps.early_stopping)
        {
            beam_hyps.is_done[blockIdx.x] = true;
        }
    }
}

struct __align__(8) MD
{
    float m;
    float d;
};

__device__ __forceinline__ MD reduce_md_op(MD a, MD b)
{
    bool a_bigger = (a.m > b.m);
    MD bigger_m = a_bigger ? a : b;
    MD smaller_m = a_bigger ? b : a;
    MD res;
    res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
    res.m = bigger_m.m;
    return res;
}
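
// reduce_md_op merges two partial softmax normalizers in one pass (the online
// softmax trick of Milakov & Gimelshein, 2018). For partial states (m, d) with
// m = max(x_i) and d = sum(exp(x_i - m)), the merge is
//     m = max(m_a, m_b);  d = d_a * exp(m_a - m) + d_b * exp(m_b - m)
// where the exp factor on the larger-m side equals 1, so only one __expf is needed.
// E.g. merging (m = 2, d = 1) with (m = 0, d = 1) gives m = 2, d = 1 + exp(-2) ~= 1.135.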

template <typename T, int MAX_K>
struct TopKMD
{
    MD md;
    TopK<T, MAX_K> topk;
};

template <typename T, int MAX_K>
__device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(const TopKMD<T, MAX_K>& a, const TopKMD<T, MAX_K>& b)
{
    TopKMD<T, MAX_K> res;
    res.md = reduce_md_op(a.md, b.md);
    res.topk = reduce_topk_op(a.topk, b.topk);
    return res;
}
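
// Fusing the two reductions lets one sweep over the vocabulary produce both the
// softmax normalizer (md) and the top-K candidate logits (topk) per thread block.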

template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x,
    const T* __restrict b, const float* __restrict c, const bool* __restrict finished, int* __restrict z,
    T* __restrict v, int V, int K, const int* __restrict end_ids)
{
    int thread_id = threadIdx.x;
    int vector_id = blockIdx.x;

    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    // reposition x to data for the current vector
    x += vector_id * V;

    typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopKMD<float, MAX_K> partial;
    bool finish = finished[vector_id];
    for (int i = 0; i < MAX_K; ++i)
    {
        partial.topk.p[i] = -1;
        partial.topk.u[i] = -MAX_T_VAL;
    }
    partial.md.m = -MAX_T_VAL;
    partial.md.d = 0.0F;

    if (finish)
    {
        // a finished beam only proposes its end token, with probability 1
        for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
        {
            float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
            MD new_elem{elem, 1.0F};
            partial.md = reduce_md_op(partial.md, new_elem);
            partial.topk.insert(elem, elem_id);
            // if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break;
        }
    }
    else
    {
        for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
        {
            float elem = x[elem_id] + b[elem_id];
            MD new_elem{elem, 1.0F};
            partial.md = reduce_md_op(partial.md, new_elem);
            partial.topk.insert(elem, elem_id);
        }
    }

    TopKMD<float, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<float, MAX_K>);

    if (thread_id == 0)
    {
        z += vector_id * K;
        v += vector_id * K;
        c += vector_id;

        // float d_total_inverse = __fdividef(1.0F, total.md.d);
        float d_total_log = logf(total.md.d);
        for (int i = 0; i < MAX_K; ++i)
        {
            // float val = __expf(total.topk.u[i] - total.md.m) * d_total_inverse;
            float val = total.topk.u[i] - total.md.m - d_total_log;
            if (i < K)
            {
                z[i] = total.topk.p[i] + vector_id * V; // FasterTransformer needs absolute ids
                v[i] = val + c[0];
            }
        }
    }
}
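
// Two-stage split variant used when DO_SPLIT_SMALL_TOP_K_SOFTMAX is defined:
// stage 1 below processes one section of the vocabulary per block and writes a
// packed (ids, scores, d, m) record per section; stage 2 merges the records.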

template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
    void beam_online_softmax_topk_stage1_kernel(const T* __restrict x, const T* __restrict b,
        const bool* __restrict finished, float* __restrict t, int V, int K, const int* __restrict end_ids)
{
    int thread_id = threadIdx.x;
    int vector_id = blockIdx.x; // batch beam index

    const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2;

    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    // one vector is covered by gridDim.y blocks, each handling one section of V
    const int v_local = (V + gridDim.y - 1) / gridDim.y;
    const int section_start = v_local * blockIdx.y;
    int section_end = section_start + v_local;
    section_end = (section_end > V) ? V : section_end;

    // reposition x to data for the current vector
    x += vector_id * V;
#if TOPK_FP16_STORAGE == 1
    typedef cub::BlockReduce<TopKMD<__half, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
#else
    typedef cub::BlockReduce<TopKMD<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
#endif
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result

#if TOPK_FP16_STORAGE == 1
    TopKMD<__half, MAX_K> partial;
#else
    TopKMD<T, MAX_K> partial;
#endif
    bool finish = finished[vector_id];
    for (int i = 0; i < MAX_K; ++i)
    {
        partial.topk.p[i] = -1;
        partial.topk.u[i] = -MAX_T_VAL;
    }
    partial.md.m = -MAX_T_VAL;
    partial.md.d = 0.0F;

    if (finish)
    {
#pragma unroll 1
        for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
        {
            float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
            MD new_elem{elem, 1.0F};
            partial.md = reduce_md_op(partial.md, new_elem);
            partial.topk.insert(elem, elem_id);
        }
    }
    else
    {
#pragma unroll 1
        for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
        {
            T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias
            T elem = x[elem_id] + bias;
            MD new_elem{elem, 1.0F};
            partial.md = reduce_md_op(partial.md, new_elem);
            partial.topk.insert(elem, elem_id);
        }
    }

#if TOPK_FP16_STORAGE == 1
    TopKMD<__half, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<__half, MAX_K>);
#else
    TopKMD<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<T, MAX_K>);
#endif

    if (thread_id == 0)
    {
        for (int i = 0; i < 2 * K; i++)
        {
            reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * V; // FasterTransformer needs absolute ids
            buf_s[MAX_K + i] = total.topk.u[i];
        }
        buf_s[2 * MAX_K] = total.md.d;
        buf_s[2 * MAX_K + 1] = total.md.m;
    }
    __syncthreads();
    for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
    {
        t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
    }
}
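
// Each packed record holds PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2 floats: candidate
// ids (as int bits) in [0, 2 * K), candidate scores in [MAX_K, MAX_K + 2 * K), then
// the softmax normalizer d and running max m. The launcher below instantiates these
// kernels with 2 * MAX_K as the MAX_K template argument, so 2 * K entries always fit.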

template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(
    const float* __restrict x, const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam)
{
    const int vector_id = blockIdx.x;
    const int thread_id = threadIdx.x;
    const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2;

    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    extern __shared__ char buf_s_[]; // intermediate result
    float* buf_s = reinterpret_cast<float*>(buf_s_);
    //__shared__ float buf_s[PACKED_TOP_KMD_SIZE * THREADBLOCK_SIZE]; // intermediate result

    typedef cub::BlockReduce<TopKMD<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;

    TopKMD<T, MAX_K> partial;
    for (int i = 0; i < MAX_K; ++i)
    {
        partial.topk.p[i] = -1;
        partial.topk.u[i] = -MAX_T_VAL;
    }
    partial.md.m = -MAX_T_VAL;
    partial.md.d = 0.0F;

    // load and unpack into registers through smem
    for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE)
    {
        buf_s[idx] = x[idx];
    }
    __syncthreads();

    if (threadIdx.x < parts_per_beam)
    {
        float* b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
        for (int i = 0; i < 2 * K; i++)
        {
            partial.topk.p[i] = reinterpret_cast<int*>(b_s)[i];
            partial.topk.u[i] = b_s[MAX_K + i];
        }
        partial.md.d = b_s[2 * MAX_K];
        partial.md.m = b_s[2 * MAX_K + 1];
    }
    __syncthreads();

    TopKMD<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<T, MAX_K>);

    if (thread_id == 0)
    {
        z += vector_id * 2 * K;
        v += vector_id * 2 * K;
        c += vector_id;

        float d_total_log = logf(total.md.d);
        for (int i = 0; i < MAX_K; ++i)
        {
            float val = (float) total.topk.u[i] - total.md.m - d_total_log;
            if (i < 2 * K)
            {
                z[i] = total.topk.p[i];
                v[i] = (float) val + (float) c[0];
            }
        }
    }
}

template <typename T, int MAX_K>
void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids,
    T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream)
{
    // might rewrite beam_online_softmax_topk_stage2_kernel not to depend on a
    // constant block size in order to reduce compilation time
    int smem_stage2_size = parts_per_beam * (2 * MAX_K + 2) * sizeof(float);

    if (parts_per_beam <= 32)
    {
        beam_online_softmax_topk_stage2_kernel<T, MAX_K, 32><<<batch_size * beam_width, 32, smem_stage2_size, stream>>>(
            temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
        return;
    }
    if (parts_per_beam <= 64)
    {
        beam_online_softmax_topk_stage2_kernel<T, MAX_K, 64><<<batch_size * beam_width, 64, smem_stage2_size, stream>>>(
            temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
        return;
    }
    if (parts_per_beam <= 128)
    {
        beam_online_softmax_topk_stage2_kernel<T, MAX_K, 128>
            <<<batch_size * beam_width, 128, smem_stage2_size, stream>>>(
                temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
        return;
    }
    assert(0);
}
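
// Note the dispatch above keeps THREADBLOCK_SIZE >= parts_per_beam: stage 2 unpacks
// one packed record per thread, so fewer threads than parts would drop sections.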

template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const bool* finished, const int* sequence_lengths,
    float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage,
    const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
    const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
    cudaStream_t stream)
{
    const int items_per_thread = 1;
    const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64;
    // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE;

    assert(temp_storage_size % 2 == 0);
    assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2);
    // Beam search needs the sequence lengths of beams to apply length penalty.
    assert(length_penalties == nullptr || sequence_lengths != nullptr);

    const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4;
    int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage);
    T* topk_tmp_val_buf = reinterpret_cast<T*>(topk_tmp_id_buf + topk_buf_offset);
    float* tmp_buffer = reinterpret_cast<float*>(topk_tmp_val_buf + topk_buf_offset);
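
    // temp_storage layout: [topk_buf_offset ints: candidate ids][topk_buf_offset
    // values of type T: candidate scores][float scratch for packed stage-1 partials];
    // topk_buf_offset rounds batch_size * beam_width * beam_width * 2 up to a
    // multiple of 4 entries.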

#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
    int voc_parts = 4;
    if (batch_size * beam_width < 256)
    {
        // Volta has 80 SMs, so we aim for three waves
        voc_parts = (240 + batch_size * beam_width - 1) / (batch_size * beam_width);
        voc_parts = std::min(128, voc_parts); // we implement up to 128
    }
    dim3 grid(batch_size * beam_width, voc_parts);
    cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel<T, items_per_thread, 2 * MAX_K, block_sz>,
        cudaFuncAttributePreferredSharedMemoryCarveout, cudaSharedmemCarveoutMaxL1);
    beam_online_softmax_topk_stage1_kernel<T, items_per_thread, 2 * MAX_K, block_sz>
        <<<grid, block_sz, 0, stream>>>(log_probs, bias, finished, tmp_buffer, vocab_size, beam_width, end_ids);
    sync_check_cuda_error();
#endif

#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
    beam_online_softmax_topk_stage2_kernelLauncher<T, 2 * MAX_K>(
        tmp_buffer, cum_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, batch_size, beam_width, voc_parts, stream);
    sync_check_cuda_error();
#else
    beam_online_softmax_topk_kernel<T, items_per_thread, MAX_K, block_sz>
        <<<batch_size * beam_width, block_sz, 0, stream>>>(log_probs, bias, cum_log_probs, finished, topk_tmp_id_buf,
            topk_tmp_val_buf, vocab_size, beam_width, end_ids);
#endif

    // We need 2 * MAX_K candidates per beam because up to K of them may be finished
    // hypotheses, which are not carried into the next iteration.
    batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, 0, stream>>>(topk_tmp_id_buf, topk_tmp_val_buf,
        output_ids_ptr, cum_log_probs, output_log_probs, finished, sequence_lengths, *beam_hyps,
        beam_width * beam_width * 2, beam_width, vocab_size, length_penalties, diversity_rates);
    sync_check_cuda_error();
}

#define INSTANTIATE_BEAMSEARCH_K(T, MAX_K)                                                                             \
    template void topK_softMax_kernelLauncher<T, MAX_K>(const T* log_probs, const T* bias, const bool* finished,      \
        const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr,             \
        void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,             \
        const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,                 \
        const float* length_penalties, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
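
// A minimal host-side sketch of driving this launcher (names such as d_log_probs,
// d_temp and temp_size are hypothetical placeholders; real call sites live in the
// TensorRT-LLM beam-search layers):
//
//     INSTANTIATE_BEAMSEARCH_K(float, 4); // emits the <float, 4> specialization
//
//     topK_softMax_kernelLauncher<float, 4>(d_log_probs, /*bias=*/nullptr, d_finished, d_seq_lens,
//         d_cum_log_probs, d_out_log_probs, d_output_ids_ptr, d_temp, temp_size, &beam_hyps,
//         batch_size, beam_width, vocab_size, d_end_ids, d_diversity_rates, d_length_penalties, stream);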