TensorRT-LLMs/cpp/tensorrt_llm/kernels/decodingKernels.cu
Robin Kobus 45134d7095
refactor: Improve decoder finalize function (#3077)
* refactor: Update gatherTree function to accept CUDA stream parameter

This commit modifies the gatherTree function signature to include a runtime::CudaStream parameter, enhancing flexibility in stream management. Additionally, it removes unnecessary buffer manager parameters and stream handling from the function, streamlining the code. The finalize method in GptDecoderBatched is also updated to reflect these changes, improving clarity and maintainability in the decoding process.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Update GptDecoderBatched finalize

This commit refactors the GptDecoderBatched class to improve method signatures and reduce code complexity:

- Modified finalize method to accept DecoderState as a parameter
- Updated method signatures to work with the new DecoderState approach
- Improved code organization and readability

The changes continue the ongoing refactoring to centralize decoder state management and simplify the decoder implementation.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-03-28 14:33:59 +08:00
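
A rough sketch of the interfaces described above, for orientation only. The free-function gatherTree signature matches the definition further down in this file; the DecoderState-based finalize is an assumption inferred from this commit message, not an exact declaration:

    // Sketch only; DecoderState and the exact finalize parameters are assumptions.
    namespace tensorrt_llm::runtime::kernels
    {
    // Takes the CUDA stream directly; the BufferManager parameter was removed.
    void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
        SamplingConfig const& samplingConfig, runtime::CudaStream const& stream);
    } // namespace tensorrt_llm::runtime::kernels

    // Hypothetical: GptDecoderBatched::finalize now receives the decoder state from the caller.
    runtime::CudaEvent GptDecoderBatched::finalize(decoder::DecoderState const& decoderState,
        SizeType32 batchSlot, SamplingConfig const& samplingConfig, bool streaming) const;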

/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decodingKernels.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace kernels
{
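// Plain parameter struct gathering source/destination pointers and element counts for every
// beam-hypotheses (CBA, candidate beam) buffer, so that a single kernel launch can copy them all.
// Shape legend: BS = batchSize, BM = beamWidth, MSL = maxSeqLen.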
class CopyBeamHypothesesStruct
{
public:
TokenIdType const* srcOutputIdsCBA; // [BS, BM*2, MSL]
TokenIdType* dstOutputIdsCBA; // [BS, BM*2, MSL]
SizeType32 outputIdsNumElts;
float const* srcLogProbsCBA; // [BS, BM*2, MSL]
float* dstLogProbsCBA; // [BS, BM*2, MSL]
SizeType32 logProbsNumElts;
SizeType32 const* srcSequenceLengthsCBA; // [BS, BM*2]
SizeType32* dstSequenceLengthsCBA; // [BS, BM*2]
SizeType32 sequenceLengthsNumElts;
float const* srcCumLogProbsCBA; // [BS, BM*2]
float* dstCumLogProbsCBA; // [BS, BM*2]
SizeType32 cumLogProbsCBANumElts;
float const* srcNormedScoresCBA; // [BS, BM*2]
float* dstNormedScoresCBA; // [BS, BM*2]
SizeType32 normedScoresNumElts;
SizeType32 const* srcNumBeamsCBA; // [BS]
SizeType32* dstNumBeamsCBA; // [BS]
SizeType32 numBeamsNumElts;
float const* srcMinNormedScoresCBA; // [BS]
float* dstMinNormedScoresCBA; // [BS]
SizeType32 minNormedScoresNumElts;
bool const* srcBatchDones; // [BS]
bool* dstBatchDones; // [BS]
SizeType32 batchDonesNumElts;
float const* srcCumLogProbs; // [BS, BM]
float* dstCumLogProbs; // [BS, BM]
SizeType32 cumLogProbsNumElts;
};
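// One thread per (batch, beam) pair: extends the recorded sequence lengths, then walks backwards
// from the last generated step following parentIds to reconstruct the full token sequence of each
// beam into outputIds, pads the tail with the batch's end token, and (unless a broken parent chain
// was found) also overwrites everything after the first end token with end tokens.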
__global__ void gatherTree(gatherTreeParam param)
{
for (int batchbeamIdx = blockIdx.x * blockDim.x + threadIdx.x; batchbeamIdx < param.batchSize * param.beamWidth;
batchbeamIdx += gridDim.x * blockDim.x)
{
int const batch = batchbeamIdx / param.beamWidth;
int const beam = batchbeamIdx % param.beamWidth;
int const inputLen = param.inputLengths == nullptr ? 0 : param.inputLengths[batchbeamIdx];
int const* parentIds = param.parentIds;
int const* stepIds = param.stepIds;
// TODO optimize the reduce_max operation for large beamWidth
int maxLen = -1;
bool updateResponseInputLength = param.responseInputLengths != nullptr;
// int selected_beam_index = 0;
for (int beamIdx = 0; beamIdx < param.beamWidth; beamIdx++)
{
int tmpLen
= param.sequenceLengths[batch * param.beamWidth + beamIdx] + param.maxSequenceLengthFinalStep - 1;
param.sequenceLengths[batch * param.beamWidth + beamIdx] = tmpLen;
if (updateResponseInputLength)
{
param.responseInputLengths[batch * param.beamWidth + beamIdx] = inputLen;
}
if (tmpLen > maxLen)
{
maxLen = tmpLen;
}
}
int const maxSeqLenB = min(param.maxSeqLen, maxLen);
if (maxSeqLenB <= 0)
{
continue;
}
int const initialTgtIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1;
int const initialParentIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1;
param.outputIds[initialTgtIx] = __ldg(stepIds + initialParentIx);
int parent = parentIds == nullptr ? 0 : __ldg(parentIds + initialParentIx) % param.beamWidth;
bool foundBad = false;
for (int level = maxSeqLenB - 2; level >= 0; --level)
{
int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + level;
int const levelParentIx = batch * param.beamWidth * param.maxSeqLen + parent * param.maxSeqLen + level;
if (parent < 0 || parent > param.beamWidth)
{
param.outputIds[levelBeamIx] = param.endTokens[batch];
parent = -1;
foundBad = true;
}
else
{
param.outputIds[levelBeamIx] = __ldg(stepIds + levelParentIx);
parent = parentIds == nullptr ? 0 : __ldg(parentIds + levelParentIx) % param.beamWidth;
}
}
// set the padded part to end_token
for (int index = maxLen; index < param.maxSeqLen; ++index)
{
param.outputIds[batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + index]
= param.endTokens[batch];
}
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in a possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
if (!foundBad)
{
bool finished = false;
// skip step 0 because it is often the start token
int startStep = 1;
for (int time = startStep; time < maxSeqLenB; ++time)
{
int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + time;
if (finished)
{
param.outputIds[levelBeamIx] = param.endTokens[batch];
}
else if (param.outputIds[levelBeamIx] == param.endTokens[batch])
{
finished = true;
}
}
}
}
}
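// Helpers for the in-warp bitonic sort used below: `swap` exchanges (rank, norm) pairs between
// lanes via shuffle and keeps the larger or smaller element depending on `dir`, while `bfe`
// extracts a single bit of the lane index to derive the compare direction at each sorting stage.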
struct RankNorm
{
int rank;
float norm;
};
inline __device__ RankNorm swap(RankNorm const& rankNorm, int mask, int dir)
{
// Exchange RankNorm data inside the warp
RankNorm other;
other.rank = __shfl_xor_sync(unsigned(-1), rankNorm.rank, mask);
other.norm = __shfl_xor_sync(unsigned(-1), rankNorm.norm, mask);
// dir == 0 -> return larger one
// dir == 1 -> return smaller one
bool doSwap = (rankNorm.norm != other.norm) && ((rankNorm.norm > other.norm) == dir);
return doSwap ? other : rankNorm;
}
inline __device__ uint32_t bfe(uint32_t a, uint32_t start, uint32_t len = 1)
{
uint32_t d;
asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(start), "r"(len));
return d;
}
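// One block per batch slot: stages sequence lengths, cumulative log probs and output ids in shared
// memory, computes length-penalized scores, sorts the beams by these scores with the warp-level
// bitonic network (only beamWidth <= 32 is supported), and writes the reordered results back.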
__global__ void finalized(gatherTreeParam param)
{
int const beamIdx = static_cast<int>(threadIdx.x);
int const beamWidth{param.beamWidth};
extern __shared__ char array[];
int* sRank = (int*) (array);
int* sLength = (int*) (sRank + beamWidth);
float* sScores = (float*) (sLength + beamWidth);
float* sNormedScores = (float*) (sScores + beamWidth);
int* sIds = (int*) (sNormedScores + beamWidth);
if (beamIdx < beamWidth)
{
int const idx = blockIdx.x * param.beamWidth + beamIdx;
int const numGeneratedToken{param.sequenceLengths[idx] - param.inputLengths[idx]};
sNormedScores[beamIdx] = applyLengthPenalty(param.cumLogProbs[idx], numGeneratedToken, param.lengthPenalty);
sLength[beamIdx] = param.sequenceLengths[idx];
sScores[beamIdx] = param.cumLogProbs[idx];
}
for (int idx = beamIdx; idx < beamWidth * param.maxSeqLen; idx += blockDim.x)
{
sIds[idx] = param.outputIds[blockIdx.x * param.beamWidth * param.maxSeqLen + idx];
}
__syncthreads();
RankNorm rankNorm;
rankNorm.rank = beamIdx;
rankNorm.norm = beamIdx < beamWidth ? sNormedScores[beamIdx] : -FLT_MAX;
if (beamWidth < 32)
{
int warpid = threadIdx.x / 32;
int laneid = threadIdx.x % 32;
if (warpid == 0 && beamWidth > 1)
{
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 1) ^ bfe(laneid, 0)); // 2
}
if (warpid == 0 && beamWidth > 2)
{
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 2) ^ bfe(laneid, 1)); // 3~4
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 2) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 4)
{
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 3) ^ bfe(laneid, 2)); // 5~8
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 3) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 3) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 8)
{
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3)); // 9~16
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
if (warpid == 0 && beamWidth > 16)
{
rankNorm = swap(rankNorm, 0x10, bfe(laneid, 5) ^ bfe(laneid, 4)); // 17~32
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 5) ^ bfe(laneid, 3));
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 5) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 5) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 5) ^ bfe(laneid, 0));
}
}
else
{
// Not supported! There must be a check before calling this kernel.
}
if (beamIdx < beamWidth)
{
sRank[beamIdx] = rankNorm.rank;
}
__syncthreads();
if (beamIdx < beamWidth)
{
auto srcIdx{rankNorm.rank};
auto tgtIdx{blockIdx.x * param.beamWidth + beamIdx};
param.sequenceLengths[tgtIdx] = sLength[srcIdx];
param.cumLogProbs[tgtIdx] = sScores[srcIdx];
}
for (int beamIdx = 0; beamIdx < beamWidth; beamIdx++)
{
for (int i = threadIdx.x; i < sLength[sRank[beamIdx]]; i += blockDim.x)
{
param.outputIds[blockIdx.x * beamWidth * param.maxSeqLen + beamIdx * param.maxSeqLen + i]
= sIds[sRank[beamIdx] * param.maxSeqLen + i];
}
}
}
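// Launches gatherTree over all (batch, beam) pairs; for beamWidth > 1, additionally sorts the
// reconstructed beams of each batch entry by normalized cumulative log probability.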
void invokeGatherTree(gatherTreeParam param)
{
int batchbeam = param.batchSize * param.beamWidth;
dim3 grid(1), block(batchbeam);
// though the decoder does not support more than 1024 threads for now
if (batchbeam > 1024)
{
grid.x = ceil(param.batchSize * param.beamWidth / 1024.);
block.x = 1024;
}
gatherTree<<<grid, block, 0, param.stream>>>(param);
sync_check_cuda_error(param.stream);
if (param.beamWidth > 1)
{
TLLM_CHECK_WITH_INFO(param.beamWidth <= 32, "TRT-LLM does not support beam width > 32 now");
// sort results by normalized cumLogProbs
dim3 grid(param.batchSize);
dim3 block(divUp(param.beamWidth, 32) * 32);
auto shm_size = param.beamWidth * (sizeof(float) * 2 + sizeof(int) * 2 + sizeof(int) * param.maxSeqLen);
finalized<<<grid, block, shm_size, param.stream>>>(param);
}
}
__global__ void insertUnfinishedPathKernel(BeamHypotheses bh)
{
// Move ALL unfinished beams from bh.outputIdsUnfinish to bh.outputIdsCBA
// So there may be more than `nBM` beams in bh.outputIdsCBA after this kernel
// Data movement:
// bh.outputIdsUnfinish -> bh.outputIdsCBA
// bh.sequenceLengths -> bh.sequenceLengthsCBA
// bh.cumLogProbs -> bh.cumLogProbsCBA
// bh.logProbsTiled -> bh.logProbsCBA
// update bh.normedScoresCBA
// update bh.numBeamsCBA
size_t const bid = blockIdx.x; // Index of Batch
size_t const nBM{bh.nBeamWidth};
size_t const nMBS{bh.nMaxBatchSize}; // Only for bh.logProbsTiled
size_t const nMSL{bh.nMaxSeqLen};
bool const bOutputLogProbs{bh.logProbsCBA != nullptr && bh.logProbsTiled != nullptr};
int const indexDstStart{bh.numBeamsCBA[bid]};
if (bh.batchDones[bid])
{
return;
}
for (int i = 0; i < nBM; ++i)
{
int const srcBeam = bid * nBM + i;
int const dstBeam = bid * nBM * 2 + i + indexDstStart;
int const step = bh.sequenceLengths[srcBeam] - 1;
// The last token
int const srcId = srcBeam * nMSL + step;
int const dstId = dstBeam * nMSL + step;
bh.outputIdsCBA[dstId] = bh.outputIdsUnfinish[srcId];
if (bOutputLogProbs)
{
bh.logProbsCBA[dstId] = bh.logProbsTiled[step * nMBS * nBM + srcBeam];
}
// Previous tokens
int prevId = bh.parentIdsUnfinish[srcId];
for (int j = step - 1; j >= 0; --j)
{
int const index = bid * nBM * nMSL + prevId * nMSL + j;
bh.outputIdsCBA[dstBeam * nMSL + j] = bh.outputIdsUnfinish[index];
prevId = bh.parentIdsUnfinish[index];
}
if (bOutputLogProbs)
{
prevId = bh.parentIdsUnfinish[srcId];
for (int j = step - 1; j >= 0; --j)
{
int const index = bid * nBM * nMSL + prevId * nMSL + j;
bh.logProbsCBA[dstBeam * nMSL + j] = bh.logProbsTiled[j * nMBS * nBM + bid * nBM + prevId];
prevId = bh.parentIdsUnfinish[index];
}
}
// Other parameters
bh.sequenceLengthsCBA[dstBeam] = bh.sequenceLengths[srcBeam];
bh.normedScoresCBA[dstBeam]
= applyLengthPenalty(bh.cumLogProbs[srcBeam], step - bh.inputLengths[srcBeam] + 1, bh.lengthPenalties[bid]);
bh.cumLogProbsCBA[dstBeam] = bh.cumLogProbs[srcBeam];
bh.numBeamsCBA[bid]++;
}
}
void invokeInsertUnfinishedPath(BeamHypotheses& bh, cudaStream_t stream)
{
insertUnfinishedPathKernel<<<bh.nBatchSize, 1, 0, stream>>>(bh);
}
__global__ void finalizeKernel(BeamHypotheses bh)
{
// Do index sort on bh.normedScoresCBA, then move buffers from CBA to output by the order of index
// Data movement:
// bh.outputIdsCBA -> bh.outputIds
// bh.sequenceLengthsCBA -> bh.sequenceLengths
// bh.cumLogProbsCBA -> bh.cumLogProbs
// bh.logProbsCBA -> bh.logProbs
int const bid = blockIdx.x; // Index of Batch
int const tid = threadIdx.x; // Index of Beam
size_t const nBM{bh.nBeamWidth};
size_t const nMSL{bh.nMaxSeqLen};
int const nCBA{bh.numBeamsCBA[bid]}; // Count of candidates in CBA, nBM <= nCBA <= 2*nBM
extern __shared__ char smem[];
int* smemRank = (int*) (smem); // [nBM]
float* smemScore = (float*) (smemRank + nBM); // [2*nBM]
int* smemSL = (int*) (smemScore + nBM * 2); // [nBM]
// Sort
for (int i = tid; i < nCBA; i += blockDim.x)
{
smemScore[i] = bh.normedScoresCBA[bid * nBM * 2 + i];
}
__syncthreads();
if (nCBA <= 32)
{
int const warpid = tid / 32;
int const laneid = tid % 32;
RankNorm rankNorm{tid, tid < nCBA ? smemScore[tid] : -FLT_MAX};
if (warpid == 0 && nCBA > 1)
{
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 1) ^ bfe(laneid, 0)); // 2
}
if (warpid == 0 && nCBA > 2)
{
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 2) ^ bfe(laneid, 1)); // 3~4
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 2) ^ bfe(laneid, 0));
}
if (warpid == 0 && nCBA > 4)
{
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 3) ^ bfe(laneid, 2)); // 5~8
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 3) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 3) ^ bfe(laneid, 0));
}
if (warpid == 0 && nCBA > 8)
{
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 4) ^ bfe(laneid, 3)); // 9~16
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 4) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 4) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 4) ^ bfe(laneid, 0));
}
if (warpid == 0 && nCBA > 16)
{
rankNorm = swap(rankNorm, 0x10, bfe(laneid, 5) ^ bfe(laneid, 4)); // 17~32
rankNorm = swap(rankNorm, 0x08, bfe(laneid, 5) ^ bfe(laneid, 3));
rankNorm = swap(rankNorm, 0x04, bfe(laneid, 5) ^ bfe(laneid, 2));
rankNorm = swap(rankNorm, 0x02, bfe(laneid, 5) ^ bfe(laneid, 1));
rankNorm = swap(rankNorm, 0x01, bfe(laneid, 5) ^ bfe(laneid, 0));
}
if (tid < nBM)
{
smemRank[tid] = rankNorm.rank;
}
__syncthreads();
}
else
{
for (int i = 0; i < nBM; ++i)
{
float maxScore = -FLT_MAX;
for (int j = 0; j < (nCBA + 1024 - 1) / 1024; ++j)
{
int const index = tid + 1024 * j;
float const score = (index < bh.numBeamsCBA[bid]) ? smemScore[index] : -FLT_MAX;
float const maxScore1 = blockReduceMax<float>(score);
maxScore = max(maxScore, maxScore1);
}
if (tid == 0)
{
for (int j = 0; j < nCBA; ++j)
{
if (smemScore[j] == maxScore)
{
smemRank[i] = j;
smemScore[j] = -FLT_MAX;
break;
}
}
}
__syncthreads();
}
}
// Move bh.sequenceLengths, bh.cumLogProbs
if (tid < nBM)
{
smemSL[tid] = bh.sequenceLengthsCBA[bid * nBM * 2 + smemRank[tid]];
bh.sequenceLengths[bid * nBM + tid] = smemSL[tid];
if (bh.cumLogProbs != nullptr)
{
bh.cumLogProbs[bid * nBM + tid] = bh.cumLogProbsCBA[bid * nBM * 2 + smemRank[tid]];
}
}
__syncthreads();
// Move bh.outputIds, bh.logProbs
for (int beamIdx = 0; beamIdx < nBM; beamIdx++)
{
for (int i = tid; i < smemSL[beamIdx]; i += blockDim.x)
{
int const dst = bid * nBM * nMSL + beamIdx * nMSL + i;
int const src = bid * nBM * 2 * nMSL + smemRank[beamIdx] * nMSL + i;
bh.outputIds[dst] = bh.outputIdsCBA[src];
}
if (bh.logProbs != nullptr)
{
for (int i = tid; i < smemSL[beamIdx]; i += blockDim.x)
{
if (int const inputLength = bh.inputLengths[bid * nBM + beamIdx]; i >= inputLength)
{
int const dst = bid * nBM * nMSL + beamIdx * nMSL + i;
int const src = bid * nBM * 2 * nMSL + smemRank[beamIdx] * nMSL + i;
bh.logProbs[dst - inputLength] = bh.logProbsCBA[src];
}
}
}
}
}
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__);
int const nBM = bh.nBeamWidth;
int const nThread = min(roundUp(nBM * 2, 32), 1024);
size_t const nByteSharedMemory = (sizeof(int) + sizeof(float)) * nBM * 2;
finalizeKernel<<<bh.nBatchSize, nThread, nByteSharedMemory, stream>>>(bh);
sync_check_cuda_error(stream);
TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__);
}
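// Grid-stride copy of all buffers described by CopyBeamHypothesesStruct; every thread loops over
// each buffer independently, so one launch copies the complete beam-hypotheses state.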
__global__ void copyBeamHypotheses(CopyBeamHypothesesStruct copyStruct)
{
auto const idx = static_cast<SizeType32>(threadIdx.x + blockIdx.x * blockDim.x);
auto const stride = static_cast<SizeType32>(blockDim.x * gridDim.x);
for (SizeType32 ii = idx; ii < copyStruct.outputIdsNumElts; ii += stride)
{
copyStruct.dstOutputIdsCBA[ii] = copyStruct.srcOutputIdsCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.logProbsNumElts; ii += stride)
{
copyStruct.dstLogProbsCBA[ii] = copyStruct.srcLogProbsCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.cumLogProbsNumElts; ii += stride)
{
copyStruct.dstCumLogProbs[ii] = copyStruct.srcCumLogProbs[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.sequenceLengthsNumElts; ii += stride)
{
copyStruct.dstSequenceLengthsCBA[ii] = copyStruct.srcSequenceLengthsCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.cumLogProbsCBANumElts; ii += stride)
{
copyStruct.dstCumLogProbsCBA[ii] = copyStruct.srcCumLogProbsCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.normedScoresNumElts; ii += stride)
{
copyStruct.dstNormedScoresCBA[ii] = copyStruct.srcNormedScoresCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.numBeamsNumElts; ii += stride)
{
copyStruct.dstNumBeamsCBA[ii] = copyStruct.srcNumBeamsCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.minNormedScoresNumElts; ii += stride)
{
copyStruct.dstMinNormedScoresCBA[ii] = copyStruct.srcMinNormedScoresCBA[ii];
}
for (SizeType32 ii = idx; ii < copyStruct.batchDonesNumElts; ii += stride)
{
copyStruct.dstBatchDones[ii] = copyStruct.srcBatchDones[ii];
}
}
void invokeCopyBeamHypotheses(DecodingOutput::BeamHypotheses const& src, DecodingOutput::BeamHypotheses const& dst,
ITensor& srcCumLogProbs, ITensor& dstCumLogProbs, runtime::CudaStream const& stream, SizeType32 numSMs)
{
CopyBeamHypothesesStruct copyStruct = {};
copyStruct.srcOutputIdsCBA = bufferCast<TokenIdType>(*(src.outputIdsCBA));
copyStruct.dstOutputIdsCBA = bufferCast<TokenIdType>(*(dst.outputIdsCBA));
copyStruct.outputIdsNumElts = dst.outputIdsCBA->getSize();
copyStruct.srcLogProbsCBA = bufferCast<float>(*(src.logProbsCBA));
copyStruct.dstLogProbsCBA = bufferCast<float>(*(dst.logProbsCBA));
copyStruct.logProbsNumElts = dst.logProbsCBA->getSize();
copyStruct.srcSequenceLengthsCBA = bufferCast<SizeType32>(*(src.sequenceLengthsCBA));
copyStruct.dstSequenceLengthsCBA = bufferCast<SizeType32>(*(dst.sequenceLengthsCBA));
copyStruct.sequenceLengthsNumElts = dst.sequenceLengthsCBA->getSize();
copyStruct.srcCumLogProbsCBA = bufferCast<float>(*(src.cumLogProbsCBA));
copyStruct.dstCumLogProbsCBA = bufferCast<float>(*(dst.cumLogProbsCBA));
copyStruct.cumLogProbsCBANumElts = dst.cumLogProbsCBA->getSize();
copyStruct.srcNormedScoresCBA = bufferCast<float>(*(src.normedScoresCBA));
copyStruct.dstNormedScoresCBA = bufferCast<float>(*(dst.normedScoresCBA));
copyStruct.normedScoresNumElts = dst.normedScoresCBA->getSize();
copyStruct.srcNumBeamsCBA = bufferCast<SizeType32>(*(src.numBeamsCBA));
copyStruct.dstNumBeamsCBA = bufferCast<SizeType32>(*(dst.numBeamsCBA));
copyStruct.numBeamsNumElts = dst.numBeamsCBA->getSize();
copyStruct.srcMinNormedScoresCBA = bufferCast<float>(*(src.minNormedScoresCBA));
copyStruct.dstMinNormedScoresCBA = bufferCast<float>(*(dst.minNormedScoresCBA));
copyStruct.minNormedScoresNumElts = dst.minNormedScoresCBA->getSize();
copyStruct.srcBatchDones = bufferCast<bool>(*(src.batchDones));
copyStruct.dstBatchDones = bufferCast<bool>(*(dst.batchDones));
copyStruct.batchDonesNumElts = dst.batchDones->getSize();
copyStruct.srcCumLogProbs = bufferCast<float>(srcCumLogProbs);
copyStruct.dstCumLogProbs = bufferCast<float>(dstCumLogProbs);
copyStruct.cumLogProbsNumElts = srcCumLogProbs.getSize();
copyBeamHypotheses<<<numSMs, 256, 0, stream.get()>>>(copyStruct);
}
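// One block per (batch, beam) row: fills the entire row of finalOutputIds with the end token of
// the corresponding batch entry.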
__global__ void initializeOutput(
TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const beam, SizeType32 const nMaxSeqLen)
{
for (int i = threadIdx.x; i < nMaxSeqLen; i += blockDim.x)
{
finalOutputIds[blockIdx.x * nMaxSeqLen + i] = endIds[blockIdx.x / beam];
}
}
void invokeInitializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const batch,
SizeType32 const beam, SizeType32 const nMaxSeqLen, cudaStream_t stream)
{
initializeOutput<<<batch * beam, 256, 0, stream>>>(finalOutputIds, endIds, beam, nMaxSeqLen);
}
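// Gathers the most recently generated tokens of every active request (addressed through
// batchSlots) from the per-request output buffers into the contiguous nextStepIds tensor,
// handling up to maxTokensPerStep new tokens per step (more than one e.g. with Medusa).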
__global__ void copyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr,
SizeType32 const* sequenceLengths, SizeType32 const* numNewTokens, SizeType32 const* batchSlots,
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth, SizeType32 maxSeqLen,
SizeType32 maxTokensPerStep)
{
for (auto index = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
index < batchSize * beamWidth * maxTokensPerStep; index += static_cast<SizeType32>(blockDim.x * gridDim.x))
{
// numNewTokens == nullptr when Medusa is disabled
auto const batchIdx{index / (beamWidth * maxTokensPerStep)};
auto const batchSlot{batchSlots[batchIdx]};
auto const remainder{index % (beamWidth * maxTokensPerStep)};
auto const beamIdx{remainder / maxTokensPerStep};
auto const tokenIdx{remainder % maxTokensPerStep};
auto const newTokens{numNewTokens == nullptr ? 1 : numNewTokens[batchSlot]};
auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
auto const tokenBatchBeamIdx = tokenIdx * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx;
auto const indexSrc = sequenceLengths[batchBeamIdx] - newTokens + tokenIdx;
if (tokenIdx >= newTokens || indexSrc < 0)
{
continue;
}
nextStepIds[tokenBatchBeamIdx] = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + indexSrc];
}
}
void invokeCopyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr,
SizeType32 const* sequenceLengths, SizeType32 const* numNewTokens, SizeType32 const* batchSlots,
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth, SizeType32 maxSeqLen,
SizeType32 maxTokensPerStep, cudaStream_t stream)
{
int const numElems = batchSize * beamWidth * maxTokensPerStep;
dim3 block(min(256, numElems));
dim3 grid(divUp(numElems, block.x));
copyNextStepIds<<<grid, block, 0, stream>>>(nextStepIds, outputIdsPtr, sequenceLengths, numNewTokens, batchSlots,
batchSize, maxBatchSize, beamWidth, maxSeqLen, maxTokensPerStep);
}
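// Transposes log probs from the tiled [step, maxBatchSize, beamWidth] layout written during
// decoding into the [batchSlot, beamWidth, maxSeqLen] layout of outputLogProbs, copying only
// positions below the current sequence length of each slot.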
__global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, SizeType32 const* sequenceLengths,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 maxSeqLen)
{
auto index = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
auto const batchIdx = index / (beamWidth * maxSeqLen);
auto const tmpIdx = index % (beamWidth * maxSeqLen);
auto const beamIdx = tmpIdx / maxSeqLen;
auto const pos = tmpIdx % maxSeqLen;
if (batchIdx >= batchSize)
{
return;
}
auto const batchSlot = batchSlots[batchIdx];
if (pos < sequenceLengths[batchSlot])
{
auto const batchBeamIdx = batchSlot * beamWidth * maxSeqLen + beamIdx * maxSeqLen + pos;
outputLogProbs[batchBeamIdx]
= outputLogProbsTiled[pos * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx];
}
}
void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, SizeType32 const* sequenceLengths,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 maxSeqLen, cudaStream_t stream)
{
dim3 block(256);
dim3 grid(divUp(batchSize * beamWidth * maxSeqLen, block.x));
transposeLogProbs<<<grid, block, 0, stream>>>(outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots,
batchSize, maxBatchSize, beamWidth, maxSeqLen);
}
} // namespace kernels
namespace runtime::kernels
{
// Must be similar to [cpp/tensorrt_llm/thop/gatherTreeOp.cpp] gatherTree
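// Gathers the final beams of a finished request: prefills decodingOutput.gatheredIds with end
// tokens, inserts the still-unfinished beams into the candidate beam arrays, then sorts all
// candidates by length-penalized cumulative log probability and copies the best beamWidth of
// them into the gathered output.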
void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
SamplingConfig const& samplingConfig, runtime::CudaStream const& cudaStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& stream = cudaStream.get();
BufferManager manager{std::make_shared<CudaStream>(stream)};
auto& finalOutputIds = *decodingOutput.gatheredIds;
auto const& finalOutputIdsShape = finalOutputIds.getShape();
auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
auto const batchSize = finalOutputIdsShape.d[0];
auto const beamWidth = finalOutputIdsShape.d[1];
auto const maxSeqLength = finalOutputIdsShape.d[2];
TLLM_CHECK_WITH_INFO(beamWidth > 1, "gatherTree is only needed for beam search.");
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[0] == batchSize,
common::fmtstr("Decoder batch size (" FMT_DIM ") does not match final batch size (" FMT_DIM ")",
decodingOutputIdsShape.d[0], batchSize));
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[1] == beamWidth,
common::fmtstr("Decoder beam width (" FMT_DIM ") does not match final beam width (" FMT_DIM ")",
decodingOutputIdsShape.d[1], beamWidth));
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[2] <= maxSeqLength,
common::fmtstr("Decoder seq length size (" FMT_DIM ") is too large for final seq length (" FMT_DIM ")",
decodingOutputIdsShape.d[2], maxSeqLength));
// prefill finalOutputIds with the EOS tokens from decodingInput.endIds
tensorrt_llm::kernels::invokeInitializeOutput(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<TokenIdType>(*decodingInput.endIds), batchSize, beamWidth, maxSeqLength, stream);
sync_check_cuda_error(stream);
std::vector<float> lengthPenaltyVec;
auto lengthPenaltyPtr = std::shared_ptr(manager.gpu(ITensor::makeShape({batchSize}), TRTDataType<float>::value));
if (!samplingConfig.lengthPenalty.has_value() || samplingConfig.lengthPenalty.value().size() == 0)
{
lengthPenaltyVec = std::vector<float>(batchSize, 1.0f);
}
else if (long int const size = samplingConfig.lengthPenalty.value().size(); size == 1)
{
lengthPenaltyVec = std::vector<float>(batchSize, samplingConfig.lengthPenalty.value()[0]);
}
else
{
TLLM_CHECK_WITH_INFO(size == batchSize,
common::fmtstr("Size of lengthPenalty in SamplingConfig (" FMT_DIM ") is different from batchSize (" FMT_DIM
")",
size, batchSize));
lengthPenaltyVec = samplingConfig.lengthPenalty.value();
}
lengthPenaltyPtr = manager.copyFrom(lengthPenaltyVec, ITensor::makeShape({batchSize}), runtime::MemoryType::kGPU);
tensorrt_llm::kernels::BeamHypotheses bh;
bh.nMaxBatchSize = batchSize;
bh.nBatchSize = batchSize;
bh.nBeamWidth = beamWidth;
bh.nMaxSeqLen = maxSeqLength;
bh.lengthPenalties = bufferCast<float>(*lengthPenaltyPtr);
bh.inputLengths = bufferCast<SizeType32>(*decodingInput.lengths);
bh.outputIds = bufferCast<TokenIdType>(finalOutputIds);
bh.logProbs = bufferCastOrNull<float>(decodingOutput.logProbs);
bh.logProbsTiled = bufferCast<float>(*decodingOutput.logProbsTiled);
bh.sequenceLengths = bufferCast<SizeType32>(*decodingOutput.lengths);
bh.cumLogProbs = bufferCast<float>(*decodingOutput.cumLogProbs);
bh.outputIdsCBA = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsCBA);
bh.logProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.logProbsCBA);
bh.sequenceLengthsCBA = bufferCast<SizeType32>(*decodingOutput.beamHypotheses.sequenceLengthsCBA);
bh.cumLogProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbsCBA);
bh.normedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.normedScoresCBA);
bh.numBeamsCBA = bufferCast<SizeType32>(*decodingOutput.beamHypotheses.numBeamsCBA);
bh.minNormedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.minNormedScoresCBA);
bh.batchDones = bufferCast<bool>(*decodingOutput.beamHypotheses.batchDones);
bh.finished = bufferCast<tensorrt_llm::kernels::FinishedState>(*decodingOutput.finishReasons);
bh.outputIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.ids);
bh.parentIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.parentIds);
// This is where transpose is done
tensorrt_llm::kernels::invokeInsertUnfinishedPath(bh, stream);
sync_check_cuda_error(stream);
tensorrt_llm::kernels::invokeFinalize(bh, stream);
sync_check_cuda_error(stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace runtime::kernels
} // namespace tensorrt_llm