TensorRT-LLMs/cpp/tensorrt_llm/kernels/decodingKernels.cu

/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decodingKernels.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace kernels
{
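// Reconstruct each finished beam: starting from its last token, walk the parent pointers
// (parentIds) backwards and gather the corresponding token ids (stepIds) into a contiguous
// sequence in outputIds; trailing positions are padded with the end token.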
__global__ void gatherTree(gatherTreeParam param)
{
for (int batchbeamIdx = blockIdx.x * blockDim.x + threadIdx.x; batchbeamIdx < param.batchSize * param.beamWidth;
batchbeamIdx += gridDim.x * blockDim.x)
{
int const batch = batchbeamIdx / param.beamWidth;
int const beam = batchbeamIdx % param.beamWidth;
int const inputLen = param.inputLengths == nullptr ? 0 : param.inputLengths[batchbeamIdx];
int const* parentIds = param.parentIds;
int const* stepIds = param.stepIds;
// TODO optimize the reduce_max operation for large beamWidth
int maxLen = -1;
bool updateResponseInputLength = param.responseInputLengths != nullptr;
for (int beamIdx = 0; beamIdx < param.beamWidth; beamIdx++)
{
int tmpLen
= param.sequenceLengths[batch * param.beamWidth + beamIdx] + param.maxSequenceLengthFinalStep - 1;
param.sequenceLengths[batch * param.beamWidth + beamIdx] = tmpLen;
if (updateResponseInputLength)
{
param.responseInputLengths[batch * param.beamWidth + beamIdx] = inputLen;
}
if (tmpLen > maxLen)
{
maxLen = tmpLen;
}
}
int const maxSeqLenB = min(param.maxSeqLen, maxLen);
if (maxSeqLenB <= 0)
{
continue;
}
int const initialTgtIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1;
int const initialParentIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1;
param.outputIds[initialTgtIx] = __ldg(stepIds + initialParentIx);
int parent = parentIds == nullptr ? 0 : __ldg(parentIds + initialParentIx) % param.beamWidth;
bool foundBad = false;
for (int level = maxSeqLenB - 2; level >= 0; --level)
{
int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + level;
int const levelParentIx = batch * param.beamWidth * param.maxSeqLen + parent * param.maxSeqLen + level;
if (parent < 0 || parent >= param.beamWidth)
{
param.outputIds[levelBeamIx] = param.endTokens[batch];
parent = -1;
foundBad = true;
}
else
{
param.outputIds[levelBeamIx] = __ldg(stepIds + levelParentIx);
parent = parentIds == nullptr ? 0 : __ldg(parentIds + levelParentIx) % param.beamWidth;
}
}
// Set the padded positions after the end of the sequence to end_token
for (int index = maxLen; index < param.maxSeqLen; ++index)
{
param.outputIds[batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + index]
= param.endTokens[batch];
}
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in a possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
if (!foundBad)
{
bool finished = false;
// Skip step 0 because it is often the start token
int startStep = 1;
for (int time = startStep; time < maxSeqLenB; ++time)
{
int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + time;
if (finished)
{
param.outputIds[levelBeamIx] = param.endTokens[batch];
}
else if (param.outputIds[levelBeamIx] == param.endTokens[batch])
{
finished = true;
}
}
}
}
}
struct RankNorm
{
int rank;
float norm;
};
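// One compare-exchange step of a bitonic sorting network across warp lanes: each lane
// exchanges its RankNorm with the lane `mask` positions away and keeps its own or the
// partner's value depending on the direction bit `dir`.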
inline __device__ RankNorm swap(RankNorm const& rankNorm, int mask, int dir)
{
// Exchange the rank and norm inside the warp.
RankNorm other;
other.rank = __shfl_xor_sync(unsigned(-1), rankNorm.rank, mask);
other.norm = __shfl_xor_sync(unsigned(-1), rankNorm.norm, mask);
// Update the sorted values.
bool doSwap = (rankNorm.norm != other.norm) && ((rankNorm.norm > other.norm) == dir);
RankNorm res;
res.rank = doSwap ? other.rank : rankNorm.rank;
res.norm = doSwap ? other.norm : rankNorm.norm;
return res;
}
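// Extract `len` bits of `a` starting at bit `start` (PTX bfe.u32).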
inline __device__ uint32_t bfe(uint32_t a, uint32_t start, uint32_t len = 1)
{
uint32_t d;
asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(start), "r"(len));
return d;
}
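// Sort the beams of each request by length-penalized cumulative log-probability (best
// beam first) and reorder outputIds, sequenceLengths and cumLogProbs accordingly.
// One block per request; the sort runs as a bitonic network inside warp 0, which is why
// beamWidth must not exceed 32.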
__global__ void finalized(gatherTreeParam param)
{
int const beamIdx = static_cast<int>(threadIdx.x);
int const beamWidth{param.beamWidth};
extern __shared__ char array[];
int* sRank = (int*) (array);
int* sLength = (int*) (sRank + beamWidth);
float* sScores = (float*) (sLength + beamWidth);
float* sNormedScores = (float*) (sScores + beamWidth);
int* sIds = (int*) (sNormedScores + beamWidth);
if (beamIdx < beamWidth)
{
int const idx = blockIdx.x * param.beamWidth + beamIdx;
int const numGeneratedToken{param.sequenceLengths[idx] - param.inputLengths[idx]};
sNormedScores[beamIdx] = applyLengthPenalty(param.cumLogProbs[idx], numGeneratedToken, param.lengthPenalty);
sLength[beamIdx] = param.sequenceLengths[idx];
sScores[beamIdx] = param.cumLogProbs[idx];
}
for (int idx = beamIdx; idx < beamWidth * param.maxSeqLen; idx += blockDim.x)
{
sIds[idx] = param.outputIds[blockIdx.x * param.beamWidth * param.maxSeqLen + idx];
}
__syncthreads();
RankNorm rankNorm;
rankNorm.rank = beamIdx;
rankNorm.norm = beamIdx < beamWidth ? sNormedScores[beamIdx] : -FLT_MAX;
if (beamWidth <= 32) // the bitonic network below handles up to 32 beams in warp 0
{
int warpid = threadIdx.x / 32;
int laneid = threadIdx.x % 32;
if (warpid == 0 && beamWidth > 1)
{
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 1) ^ bfe(laneid, 0)); // 2
}
if (warpid == 0 && beamWidth > 2)
{
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 2) ^ bfe(laneid, 1)); // 3~4
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 2) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 4)
{
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 3) ^ bfe(laneid, 2)); // 5~8
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 3) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 3) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 8)
{
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3)); // 9~16
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 16)
{
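// bfe(laneid, 4) ^ bfe(laneid, 4) is deliberately 0: the final merge stage of the
// bitonic network runs in a single direction for all lanes (the same pattern appears
// in finalizeKernel below).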
rankNorm = swap(rankNorm, 0x10, bfe(laneid, 4) ^ bfe(laneid, 4)); // 17~32
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3));
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
}
else
{
// Not supported! invokeGatherTree checks beamWidth <= 32 before launching this kernel.
}
if (beamIdx < beamWidth)
{
sRank[beamIdx] = rankNorm.rank;
}
__syncthreads();
if (beamIdx < beamWidth)
{
auto srcIdx{rankNorm.rank};
auto tgtIdx{blockIdx.x * param.beamWidth + beamIdx};
param.sequenceLengths[tgtIdx] = sLength[srcIdx];
param.cumLogProbs[tgtIdx] = sScores[srcIdx];
}
for (int beamIdx = 0; beamIdx < beamWidth; beamIdx++)
{
for (int i = threadIdx.x; i < sLength[sRank[beamIdx]]; i += blockDim.x)
{
param.outputIds[blockIdx.x * beamWidth * param.maxSeqLen + beamIdx * param.maxSeqLen + i]
= sIds[sRank[beamIdx] * param.maxSeqLen + i];
}
}
}
void invokeGatherTree(gatherTreeParam param)
{
int batchbeam = param.batchSize * param.beamWidth;
dim3 grid(1), block(batchbeam);
// though the decoder does not support batchbeam > 1024 for now
if (batchbeam > 1024)
{
grid.x = divUp(batchbeam, 1024);
block.x = 1024;
}
gatherTree<<<grid, block, 0, param.stream>>>(param);
sync_check_cuda_error();
if (param.beamWidth > 1)
{
TLLM_CHECK_WITH_INFO(param.beamWidth <= 32, "TRT-LLM does not support beam width > 32 now");
// sort results by normalized cumLogProbs
dim3 grid(param.batchSize);
dim3 block(divUp(param.beamWidth, 32) * 32);
auto shm_size = param.beamWidth * (sizeof(float) * 2 + sizeof(int) * 2 + sizeof(int) * param.maxSeqLen);
finalized<<<grid, block, shm_size, param.stream>>>(param);
}
}
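// For requests that are not finished yet, append ALL their unfinished beams to the
// candidate-beam array (CBA) so that finalizeKernel can rank them together with the
// finished candidates.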
__global__ void insertUnfinishedPath(BeamHypotheses bh)
{
int const bid = blockIdx.x;
int const nBS{bh.nBatchSize};
int const nBM{bh.nBeamWidth};
int const tgt_start_idx{bh.numBeamsCBA[bid]};
int const nMaxSeqLen{bh.nMaxSeqLen};
// TODO: the nullptr comes from GptDecoder<T>::gatherTree in [gptDecoder.cpp] and needs to be fixed
float const length_penalty{bh.lengthPenalties == nullptr ? 1.0f : bh.lengthPenalties[bid]};
if (bh.batchDones[bid])
{
return;
}
// Move ALL unfinished beams from bh.outputIdsUnfinish to bh.outputIdsCBA
// So there might be more than `nBM` beams in bh.outputIdsCBA
for (int i = 0; i < nBM; ++i)
{
int const src_beam_idx = bid * nBM + i;
int const tgt_beam_idx = bid * nBM * 2 + i + tgt_start_idx;
int const current_step = bh.sequenceLengths[src_beam_idx] - 1;
bh.outputIdsCBA[tgt_beam_idx * nMaxSeqLen + current_step]
= bh.outputIdsUnfinish[src_beam_idx * nMaxSeqLen + current_step];
if (bh.logProbsCBA != nullptr && bh.logProbs != nullptr)
{
bh.logProbsCBA[tgt_beam_idx * nMaxSeqLen + current_step]
= bh.logProbs[current_step * nBS * nBM + src_beam_idx];
}
int prev_id = bh.parentIdsUnfinish[src_beam_idx * nMaxSeqLen + current_step];
for (int j = current_step - 1; j >= 0; --j)
{
bh.outputIdsCBA[tgt_beam_idx * nMaxSeqLen + j]
= bh.outputIdsUnfinish[bid * nBM * nMaxSeqLen + prev_id * nMaxSeqLen + j];
if (bh.logProbsCBA != nullptr && bh.logProbs != nullptr)
{
bh.logProbsCBA[tgt_beam_idx * nMaxSeqLen + j] = bh.logProbs[j * nBS * nBM + bid * nBM + prev_id];
}
prev_id = bh.parentIdsUnfinish[bid * nBM * nMaxSeqLen + prev_id * nMaxSeqLen + j];
}
bh.sequenceLengthsCBA[tgt_beam_idx] = bh.sequenceLengths[src_beam_idx];
bh.normedScoresCBA[tgt_beam_idx] = applyLengthPenalty(
bh.cumLogProbs[src_beam_idx], current_step - bh.inputLengths[src_beam_idx], length_penalty);
bh.cumLogProbsCBA[tgt_beam_idx] = bh.cumLogProbs[src_beam_idx];
bh.numBeamsCBA[bid]++;
}
}
void invokeInsertUnfinishedPath(BeamHypotheses& bh, cudaStream_t stream)
{
insertUnfinishedPath<<<bh.nBatchSize, 1, 0, stream>>>(bh);
}
__global__ void finalizeKernel(BeamHypotheses bh)
{
// Do index sort on bh.normedScoresCBA, then move buffers from CBA to output by the order of index
// bh.outputIdsCBA -> bh.outputIds
// bh.sequenceLengthsCBA -> bh.sequenceLengths
// bh.cumLogProbsCBA -> bh.cumLogProbs
// bh.logProbsCBA -> bh.logProbs
int const bid = blockIdx.x;
int const tid = threadIdx.x;
int const nBM{bh.nBeamWidth};
int const nMaxSeqLen{bh.nMaxSeqLen};
int const nBeam{bh.numBeamsCBA[bid]};
int const* inputLengths{bh.inputLengths};
extern __shared__ char array[];
int* sRank = (int*) (array); // [nBM]
float* sScores = (float*) (sRank + nBM); // [2*nBM]
int* sSequenceLengths = (int*) (sScores + nBM * 2); // [nBM]
if (tid < nBeam)
{
sScores[tid] = bh.normedScoresCBA[bid * nBM * 2 + tid];
}
__syncthreads();
if (nBeam < 32)
{
int const warpid = tid / 32;
int const laneid = tid % 32;
RankNorm rankNorm{tid, tid < nBeam ? sScores[tid] : -FLT_MAX};
if (warpid == 0 && nBeam > 1)
{
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 1) ^ bfe(laneid, 0)); // 2
}
if (warpid == 0 && nBeam > 2)
{
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 2) ^ bfe(laneid, 1)); // 3~4
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 2) ^ bfe(laneid, 0));
}
if (warpid == 0 && nBeam > 4)
{
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 3) ^ bfe(laneid, 2)); // 5~8
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 3) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 3) ^ bfe(laneid, 0));
}
if (warpid == 0 && nBeam > 8)
{
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3)); // 9~16
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
if (warpid == 0 && nBeam > 16)
{
rankNorm = swap(rankNorm, 0x10, bfe(laneid, 4) ^ bfe(laneid, 4)); // 17~32
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3));
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
if (tid < nBM)
{
sRank[tid] = rankNorm.rank;
}
__syncthreads();
}
else
{
for (int i = 0; i < nBM; ++i)
{
float const score = tid < bh.numBeamsCBA[bid] ? sScores[tid] : -FLT_MAX;
float const maxScore = blockReduceMax<float>(score);
if (tid == 0)
{
for (int j = 0; j < nBM * 2; ++j)
{
if (sScores[j] == maxScore)
{
sRank[i] = j;
sScores[j] = -FLT_MAX;
break;
}
}
}
__syncthreads();
}
}
if (tid < nBM)
{
sSequenceLengths[tid] = bh.sequenceLengthsCBA[bid * nBM * 2 + sRank[tid]];
bh.sequenceLengths[bid * nBM + tid] = sSequenceLengths[tid];
if (bh.cumLogProbs != nullptr)
{
bh.cumLogProbs[bid * nBM + tid] = bh.cumLogProbsCBA[bid * nBM * 2 + sRank[tid]];
}
}
__syncthreads();
for (int beamIdx = 0; beamIdx < nBM; beamIdx++)
{
for (int i = tid; i < sSequenceLengths[beamIdx]; i += blockDim.x)
{
bh.outputIds[bid * nBM * nMaxSeqLen + beamIdx * nMaxSeqLen + i]
= bh.outputIdsCBA[bid * (nBM * 2) * nMaxSeqLen + sRank[beamIdx] * nMaxSeqLen + i];
if (bh.logProbs != nullptr)
{
int const inputLen = inputLengths[bid * nBM + beamIdx];
if (i >= inputLen)
{
bh.logProbs[bid * nBM * nMaxSeqLen + beamIdx * nMaxSeqLen + i - inputLen]
= bh.logProbsCBA[bid * (nBM * 2) * nMaxSeqLen + sRank[beamIdx] * nMaxSeqLen + i];
}
}
}
}
}
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__);
int const nBM = bh.nBeamWidth;
size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2;
finalizeKernel<<<bh.nBatchSize, roundUp(nBM * 2, 32), smem_size, stream>>>(bh);
}
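// Pre-fill every output sequence with its end token so that positions the decoder never
// writes already hold endId.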
__global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType const nMaxSeqLen)
{
for (int i = threadIdx.x; i < nMaxSeqLen; i += blockDim.x)
{
finalOutputIds[blockIdx.x * nMaxSeqLen + i] = endIds[blockIdx.x];
}
}
void invokeInitializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType const batchBeam,
SizeType const nMaxSeqLen, cudaStream_t stream)
{
initializeOutput<<<batchBeam, 256, 0, stream>>>(finalOutputIds, endIds, nMaxSeqLen);
}
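// Gather the token(s) produced at the current step from the per-request output buffers
// (indexed through batchSlots) into the contiguous nextStepIds tensor of shape
// [maxTokensPerStep, maxBatchSize, beamWidth].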
__global__ void copyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr,
SizeType32 const* sequenceLengths, SizeType32 const* numNewTokens, SizeType32 const* batchSlots, SizeType batchSize,
SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxTokensPerStep)
{
for (auto index = static_cast<SizeType>(blockIdx.x * blockDim.x + threadIdx.x);
index < batchSize * beamWidth * maxTokensPerStep; index += static_cast<SizeType>(blockDim.x * gridDim.x))
{
auto const batchIdx{index / (beamWidth * maxTokensPerStep)};
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
auto const remainder{index % (beamWidth * maxTokensPerStep)};
auto const beamIdx{remainder / maxTokensPerStep};
auto const tokenIdx{remainder % maxTokensPerStep};
auto const newTokens = numNewTokens == nullptr ? 1 : numNewTokens[batchSlot];
auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
auto const tokenBatchBeamIdx = tokenIdx * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx;
auto const index_src = beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - newTokens + tokenIdx;
if (tokenIdx >= newTokens || index_src < 0)
{
continue;
}
nextStepIds[tokenBatchBeamIdx] = outputIdsPtr[batchSlot][index_src];
}
}
void invokeCopyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr,
SizeType32 const* sequenceLengths, SizeType32 const* numNewTokens, SizeType32 const* batchSlots, SizeType batchSize,
SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxTokensPerStep, cudaStream_t stream)
{
auto const numElems = batchSize * beamWidth * maxTokensPerStep;
dim3 block(min(256, numElems));
dim3 grid(divUp(numElems, block.x));
copyNextStepIds<<<grid, block, 0, stream>>>(nextStepIds, outputIdsPtr, sequenceLengths, numNewTokens, batchSlots,
batchSize, maxBatchSize, beamWidth, maxSeqLen, maxTokensPerStep);
}
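// Transpose the log-probs from the step-major layout [maxSeqLen, maxBatchSize, beamWidth]
// written during decoding into the request-major layout [maxBatchSize, beamWidth, maxSeqLen].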
__global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, SizeType32 const* sequenceLengths,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen)
{
auto index = static_cast<SizeType>(blockIdx.x * blockDim.x + threadIdx.x);
auto const batchIdx = index / (beamWidth * maxSeqLen);
auto const tmpIdx = index % (beamWidth * maxSeqLen);
auto const beamIdx = tmpIdx / maxSeqLen;
auto const pos = tmpIdx % maxSeqLen;
if (batchIdx >= batchSize)
{
return;
}
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
if (pos < sequenceLengths[batchSlot])
{
auto const batchBeamIdx = batchSlot * beamWidth * maxSeqLen + beamIdx * maxSeqLen + pos;
outputLogProbs[batchBeamIdx]
= outputLogProbsTiled[pos * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx];
}
}
void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, SizeType32 const* sequenceLengths,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen,
cudaStream_t stream)
{
dim3 block(256);
dim3 grid(divUp(batchSize * beamWidth * maxSeqLen, block.x));
transposeLogProbs<<<grid, block, 0, stream>>>(outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots,
batchSize, maxBatchSize, beamWidth, maxSeqLen);
}
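// Speculative-decoding acceptance by token ids: compare each draft token with the token
// the target model produced at the same position and truncate the sequence at the first
// mismatch (the mismatching target token itself is kept).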
__global__ void acceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
SizeType batchSize, SizeType maxBatchSize, SizeType maxSeqLen, SizeType maxDraftTokens)
{
for (auto batchIdx = static_cast<SizeType>(threadIdx.x); batchIdx < batchSize; batchIdx += blockDim.x)
{
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const numDraftTokens = numsDraftTokens[batchSlot];
auto const contextLength = contextLengths[batchSlot];
auto& sequenceLength = sequenceLengths[batchSlot];
SizeType32 finishedDraftIdx = 0;
for (auto ti = contextLength; ti < min(sequenceLength, contextLength + numDraftTokens);
++ti, ++finishedDraftIdx)
{
auto const draftIdx = ti - contextLength;
auto const targetTokenIdx = batchSlot * maxSeqLen + ti;
auto const draftTokenIdx = batchSlot * maxDraftTokens + draftIdx;
// Check if draft tokens are the same as target tokens
bool const accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
if (!accepted)
{
// Set sequence length to the numAcceptedTokens + 1
sequenceLength = min(ti + 1, maxSeqLen);
// FIXME(nkorobov): do we need to set endIds here?
break;
}
}
FinishedState finishState = finished[finishedDraftIdx * maxBatchSize + batchSlot];
finishedFinal[batchSlot] = finishState;
if (finishedSum)
{
finishedSum[batchSlot] = static_cast<int>(finishState.isFinished());
}
}
}
void invokeAcceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxDraftTokens,
cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
dim3 block(min(1024, batchSize));
dim3 grid(1);
acceptDraftTokensByIds<<<grid, block, 0, stream>>>(draftIds, targetIds, contextLengths, numsDraftTokens,
sequenceLengths, finished, finishedFinal, finishedSum, batchSlots, batchSize, maxBatchSize, maxSeqLen,
maxDraftTokens);
}
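// Speculative-decoding acceptance by logits: one block per (batch, beam) pair and one
// grid.y slot per draft token. The whole target distribution is compared against
// threshold * draft distribution; any vocabulary entry that violates the bound rejects
// the token (see the FIXME below about comparing sampled-token probabilities instead).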
template <typename T>
__global__ void acceptDraftTokensByLogitsKernel(T const* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
FinishedState* finished, curandState_t* curandState, SizeType32 const* batchSlots, SizeType batchSize,
SizeType maxBatchSize, SizeType maxDraftTokens, SizeType beamWidth, SizeType vocabSize, bool randomThreshold,
float constantThreshold)
{
auto const bid = blockIdx.x;
auto const draftTokenIdx = blockIdx.y;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
if (draftTokenIdx >= numDraftTokens)
{
return;
}
auto const logitsOffset = (batchSlot * maxDraftTokens + draftTokenIdx) * beamWidth * vocabSize;
auto const draftProbsBatch = draftProbs + logitsOffset;
auto const targetProbsBatch = targetProbs + logitsOffset;
SizeType32 rejected = 0;
auto vocabSizePadded = static_cast<SizeType32>((vocabSize + blockDim.x - 1) / blockDim.x) * blockDim.x;
for (auto vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSizePadded;
vIdx += static_cast<SizeType32>(blockDim.x))
{
if (rejected > 0)
{
break;
}
// FIXME(nkorobov): We compare probability distributions, but it might make sense to compare probabilities of
// the selected tokens based on the https://arxiv.org/pdf/2302.01318.pdf
bool const pred = vIdx < vocabSize;
auto const threshold
= pred ? (randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold) : 0.f;
auto const targetProb = pred ? static_cast<float>(targetProbsBatch[vIdx]) : 1.f;
auto const draftProb = pred ? static_cast<float>(draftProbsBatch[vIdx]) : 0.f;
rejected = __syncthreads_count(targetProb < threshold * draftProb);
}
if (threadIdx.x == 0)
{
finished[draftTokenIdx * maxBatchSize * beamWidth + batchSlotBeamWidth]
= rejected > 0 ? FinishedState::skipDecoding() : FinishedState::empty();
}
}
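// For the first rejected position, replace the target distribution with the corrected
// residual max(p_target - p_draft, 0), renormalize it, and write it back to the target
// logits as log-odds so the position can be re-sampled.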
template <typename T>
__global__ void correctAcceptedStatesAndLogits(T const* draftProbs, T* targetProbs, T** targetLogits,
SizeType32 const* numsDraftTokens, FinishedState* finished, SizeType32 const* batchSlots, SizeType batchSize,
SizeType maxBatchSize, SizeType maxDraftTokens, SizeType beamWidth, SizeType vocabSize)
{
auto const bid = blockIdx.x;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
__shared__ SizeType32 numAcceptedTokens;
if (threadIdx.x == 0)
{
numAcceptedTokens = numDraftTokens;
bool cumulativeSkipDecoding = false;
for (SizeType32 ti = 0; ti < numDraftTokens + 1; ++ti)
{
auto& finishedState = finished[ti * maxBatchSize * beamWidth + batchSlotBeamWidth];
bool const localSkipDecoding = finishedState.isSkipDecoding();
if (!cumulativeSkipDecoding && localSkipDecoding)
{
numAcceptedTokens = ti;
}
finishedState = cumulativeSkipDecoding ? FinishedState::skipDecoding() : FinishedState::empty();
cumulativeSkipDecoding |= localSkipDecoding;
}
}
__syncthreads();
if (numAcceptedTokens < numDraftTokens)
{
auto const logitsIdx = (batchSlot * maxDraftTokens + numAcceptedTokens) * beamWidth * vocabSize;
auto const draftProbBatch = draftProbs + logitsIdx;
auto targetProbBatch = targetProbs + logitsIdx;
auto targetLogitsBatch = targetLogits[bid] + numAcceptedTokens * beamWidth * vocabSize;
float sumProbs = 0.f;
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
vIdx += static_cast<SizeType32>(blockDim.x))
{
auto const correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
sumProbs += correctedProb;
targetProbBatch[vIdx] = correctedProb;
}
__shared__ float sumProbsShared;
sumProbs = blockReduceSum<float>((float) sumProbs);
if (threadIdx.x == 0)
{
sumProbsShared = max(sumProbs, 1e-6f);
}
__syncthreads();
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
vIdx += static_cast<SizeType32>(blockDim.x))
{
auto const correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
targetLogitsBatch[vIdx] = __logf(correctedNormProb / (1.f - correctedNormProb));
}
}
}
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType vocabSize,
SizeType vocabSizePadded, SizeType maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
{
invokeAddBiasSoftMax(draftLogits, static_cast<T**>(nullptr), draftProbs, static_cast<T*>(nullptr), nullptr,
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
/* skip softmax */ false,
/* batchSlotLogits */ true, stream);
invokeAddBiasSoftMax(static_cast<T*>(nullptr), targetLogits, targetProbs, static_cast<T*>(nullptr), nullptr,
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
/* skip softmax */ false,
/* batchSlotLogits */ true, stream);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth, maxDraftTokens);
acceptDraftTokensByLogitsKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens, finished,
curandState, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded,
randomThreshold, constantThreshold);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(draftProbs, targetProbs, targetLogits,
numsDraftTokens, finished, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded);
}
}
template void acceptDraftTokensByLogits(float* draftLogits, float** targetLogits, float* draftProbs, float* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType vocabSize,
SizeType vocabSizePadded, SizeType maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, half* draftProbs, half* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxBatchSize, SizeType beamWidth, SizeType vocabSize,
SizeType vocabSizePadded, SizeType maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
__device__ __forceinline__ int4 reduceMaxInt4(int4 const& a, int4 const& b)
{
return a.x >= b.x ? a : b;
}
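// Medusa tree acceptance: each thread walks one candidate path through the token tree
// (paths), counting how many draft tokens match the target tokens; a block-wide reduce
// selects the longest accepted path. Accepted tokens are appended to outputIds, and
// logitsPtrs is filled with pointers to the Medusa-head logits of the chosen branch for
// the subsequent all-top-K sampling kernel.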
template <typename T, SizeType BLOCK_SIZE>
__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
T const** medusaLogits, T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep,
SizeType32* bestPathIds, SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens,
SizeType maxSeqLen, SizeType maxNumHeads, SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const inputLength = sequenceLengths[batchSlot];
auto const endId = endIds[batchSlot];
auto const numTokensPerStep = curTokensPerStep[batchSlot];
auto const maxNumDraftTokens = maxNumHeads + 1;
int4 partialMax{-1, -1, 0, 0};
// Go over different paths and construct implicit sequences
for (auto pathIdx = static_cast<SizeType32>(threadIdx.x); pathIdx < maxTokensPerStep;
pathIdx += static_cast<SizeType32>(blockDim.x))
{
auto acceptedLength = maxNumDraftTokens;
auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
bool hasEnd = false;
auto const tokenId = paths[pathOffset];
// Continue if path does not exist
if (tokenId == -1)
{
continue;
}
auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto targetToken = targetIds[targetTokenIdx];
auto nextIdx = tokenId;
// Go along the path
for (SizeType ti = 1; ti < maxNumDraftTokens; ++ti)
{
auto const tokenId = paths[pathOffset + ti];
// Break if path terminates
if (tokenId == -1)
{
acceptedLength = ti;
break;
}
auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const draftTokenIdx = batchSlot * (maxDraftTokens - 1) + tokenId - 1;
// In context phase, no draft tokens are given. Set draft token to -1 to get guaranteed rejection
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
// Check if draft tokens are the same as target tokens
bool const accepted = draftToken == targetToken;
hasEnd = targetToken == endId;
if (!accepted || hasEnd)
{
acceptedLength = hasEnd ? ti - 1 : ti;
break;
}
targetToken = targetIds[targetTokenIdx];
nextIdx = tokenId;
}
// Get longest path of the thread
if (partialMax.x < acceptedLength)
{
partialMax.x = acceptedLength;
partialMax.y = pathIdx;
partialMax.z = hasEnd;
partialMax.w = nextIdx;
}
}
// Get the longest path of the block (request)
typedef cub::BlockReduce<int4, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage tempStorage;
int4 total = BlockReduce(tempStorage).Reduce(partialMax, reduceMaxInt4);
__shared__ int4 totalShared;
if (threadIdx.x == 0)
{
totalShared = total;
}
__syncthreads();
auto const acceptedLength = totalShared.x;
auto const bestPathIdx = totalShared.y;
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < acceptedLength; ti += static_cast<SizeType32>(blockDim.x))
{
auto const tokenId = paths[pathOffset + ti];
auto const targetSrcTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const outputTokenIdx = batchSlot * maxSeqLen + inputLength + ti;
auto const targetToken = targetIds[targetSrcTokenIdx];
// Copy accepted tokens into the output sequence that also holds the draft tokens
outputIds[outputTokenIdx] = targetToken;
}
// Leading thread reconstructs winning path and sets new data
if (threadIdx.x == 0)
{
auto const hasEnd = totalShared.z;
// Set end condition
if (hasEnd)
{
finishedFinal[batchSlot].setFinishedEOS();
}
// Make correction to the sequence length
sequenceLengths[batchSlot] += acceptedLength;
acceptedLengths[batchSlot] = acceptedLength;
// In the context phase (numTokensPerStep == 1) no draft tokens were used; set the draft-token count for the next steps
if (numTokensPerStep == 1)
{
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
}
bestPathIds[batchSlot] = bestPathIdx;
}
// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
for (auto hi = static_cast<SizeType>(threadIdx.x); hi < maxNumHeads; hi += static_cast<SizeType>(blockDim.x))
{
logitsPtrs[batchIdx * maxNumHeads + hi]
= medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
}
}
template <typename T>
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32* sequenceLengths, SizeType32* acceptedLengths, FinishedState* finishedFinal,
SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds, T const** medusaLogits,
T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
dim3 block(BLOCK_SIZE);
dim3 grid(batchSize);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, draftIds, targetIds,
sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxDraftTokens,
maxSeqLen, maxNumHeads, maxTokensPerStep);
}
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
float const** medusaLogits, float const** logitsPtrs, SizeType32* curTokensPerStep,
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType batchSize, SizeType vocabSize,
SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen, SizeType maxNumHeads, SizeType maxTokensPerStep,
cudaStream_t stream);
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
half const** medusaLogits, half const** logitsPtrs, SizeType32* curTokensPerStep,
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType batchSize, SizeType vocabSize,
SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen, SizeType maxNumHeads, SizeType maxTokensPerStep,
cudaStream_t stream);
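// Rearrange the sampled draft tokens (sourceDraftIds) into linear tree order using the
// precomputed treeIds mapping. Note that the last kernel argument is maxTokensPerStep;
// the kernel derives maxDraftTokens = maxTokensPerStep - 1 from it.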
__global__ void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds,
SizeType32 const* treeIds, SizeType32 const* tokensPerStepData, SizeType32 const* batchSlots,
SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const tokensPerStep = tokensPerStepData[batchSlot];
auto const maxDraftTokens = maxTokensPerStep - 1;
for (auto index = static_cast<SizeType32>(threadIdx.x); index < tokensPerStep - 1;
index += static_cast<SizeType32>(blockDim.x))
{
auto const indexInTree = treeIds[batchSlot * maxDraftTokens + index];
auto const treeDraftIdx = batchSlot * maxDraftTokens + index;
auto const sourceDraftIdx = batchSlot * maxTokensPerStep + indexInTree;
treeDraftIds[treeDraftIdx] = sourceDraftIds[sourceDraftIdx];
}
}
void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds, SizeType32 const* treeIds,
SizeType32 const* tokensPerStep, SizeType32 const* batchSlots, SizeType maxTokensPerStep, SizeType batchSize,
cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
scatterMedusaDraftTokens<<<batchSize, BLOCK_SIZE, 0, stream>>>(
treeDraftIds, sourceDraftIds, treeIds, tokensPerStep, batchSlots, maxTokensPerStep);
}
}
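// Pack the accepted draft-token path offsets of all requests into contiguous buffers:
// acceptedLengthsCumSum receives the exclusive prefix sum of the accepted lengths and
// pathsOffsets receives, for every accepted token, its offset among the draft tokens.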
template <int32_t BLOCK_SIZE>
__global__ void packAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32* pathsOffsets,
SizeType const* acceptedLengths, SizeType32 const* bestPathIds, SizeType32 const* paths,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxTokensPerStep, SizeType maxNumDraftTokens)
{
// Specialize BlockScan for a 1D block of BLOCK_SIZE threads on SizeType items
typedef cub::BlockScan<SizeType, BLOCK_SIZE> BlockScan;
// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage tempStorage;
auto const batchSizeRounded = ((batchSize + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
__shared__ SizeType currentCumSum;
if (threadIdx.x == 0)
{
currentCumSum = 0;
}
__syncthreads();
for (auto bi = static_cast<SizeType>(threadIdx.x); bi < batchSizeRounded; bi += static_cast<SizeType>(blockDim.x))
{
auto const valid = bi < batchSize;
auto const batchSlot = valid ? batchSlots[bi] : 0;
auto const acceptedLen = valid ? acceptedLengths[batchSlot] - 1 : 0;
SizeType32 cumSum;
// Exclusive prefix sum of the accepted lengths within the block, offset by the running
// total carried over from previous loop iterations (adding currentCumSum to every
// thread's input before the scan would count the offset once per preceding thread).
BlockScan(tempStorage).ExclusiveSum(acceptedLen, cumSum);
cumSum += currentCumSum;
__syncthreads();
if (threadIdx.x == blockDim.x - 1)
{
// Carry the inclusive running total into the next iteration and the final write below.
currentCumSum = cumSum + acceptedLen;
}
__syncthreads();
if (valid)
{
acceptedLengthsCumSum[bi] = cumSum;
auto const bestPathIdx = bestPathIds[batchSlot];
auto const pathIdx = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (SizeType32 ti = 0; ti < acceptedLen; ++ti)
{
pathsOffsets[cumSum + ti] = paths[pathIdx + ti + 1] - 1;
}
}
}
if (threadIdx.x == 0)
{
acceptedLengthsCumSum[batchSize] = currentCumSum;
}
}
void invokePackAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32* pathsOffsets,
SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds, SizeType32 const* paths,
SizeType32 const* batchSlots, SizeType batchSize, SizeType maxTokensPerStep, SizeType maxNumDraftTokens,
cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 1024;
packAcceptedPaths<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(acceptedLengthsCumSum, pathsOffsets, acceptedLengths,
bestPathIds, paths, batchSlots, batchSize, maxTokensPerStep, maxNumDraftTokens);
}
} // namespace kernels
} // namespace tensorrt_llm