/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/kernels/decodingCommon.h"

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

#define TOPK_FP16_STORAGE 0

#pragma nv_diag_suppress static_var_with_dynamic_init

template <typename T, int PAD_2K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beamStage3Kernel(int const* __restrict pTempId, T const* __restrict pTempVal, BeamHypotheses bh)
{
    int const bid = blockIdx.x; // Index of the batch
    int const tid = threadIdx.x;
    auto const slot = bh.batchSlots[bid];
    int const nMBS{bh.nMaxBatchSize}; // Only for bh.logProbsTiled
    int const nBM{bh.nBeamWidth};
    int const nCandidate{nBM * nBM * 2}; // Keep top 2K candidates from each beam output
    int const nV{bh.nVocabSize};
    float const diversityRate{bh.diversityRates[slot]};
    float const lengthPenalty{bh.lengthPenalties[slot]};
    int const earlyStopping{bh.earlyStoppings[slot]};
    T const MAX_T_VAL = std::is_same_v<T, half> ? HALF_FLT_MAX : FLT_MAX;

    __shared__ int nBeamForNextStep; // Only used by the thread with tid == 0
    __shared__ float smemCumLogProbs[PAD_2K / 2];

    if (tid == 0)
    {
        nBeamForNextStep = 0;
    }
    if (tid < nBM)
    {
        smemCumLogProbs[tid] = bh.cumLogProbs[slot * nBM + tid];
    }
    __syncthreads();

    if (bh.numBeamsCBA != nullptr)
    {
        // Beam search is enabled
        if (bh.numBeamsCBA[slot] == 0 && tid == 0)
        {
            // Initialize the worst score in the first call
            bh.minNormedScoresCBA[slot] = FLT_MAX;
        }
        else if (earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM
            || earlyStopping != 1 && bh.finished[slot * nBM].isFinished())
        {
            // Conditions for early return:
            // 1. In EarlyStopping mode, we have already got enough beams
            // 2. In NonEarlyStopping mode, this batch has been marked as done
            // TODO: improve the condition like below
            // earlyStopping == 1 && bh.numBeamsCBA[slot] == nBM || earlyStopping != 1 && bh.batchDones[slot]
            return;
        }
    }

    // Get top 2K tokens from the candidates
    pTempId += bid * nCandidate;
    pTempVal += bid * nCandidate;

    using KVPair = cub::KeyValuePair<int, T>;
    KVPair topKVPairPartial{nCandidate - 1, -MAX_T_VAL};
    cub::ArgMax argmax;
    extern __shared__ char smem[];
    T* smemVal = reinterpret_cast<T*>(smem);

    for (int i = tid; i < nCandidate; i += THREADBLOCK_SIZE)
    {
        int const index = bh.numBeamsCBA == nullptr ? i % nBM : i / 2 / nBM;
        T const val = pTempVal[i] + static_cast<T>(diversityRate * index);
        topKVPairPartial = argmax(topKVPairPartial, {i, val});
        smemVal[i] = val;
    }
    __syncthreads();
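
    // Illustrative note (not executed): candidate ids in pTempId pack the source beam and the
    // token id as `id = beam * nV + token`, which this kernel unpacks below via `% nV` and
    // `/ nV % nBM`. For example, with nV = 32000 and nBM = 4, id 65003 decodes to
    // beam (65003 / 32000) % 4 = 2 and token 65003 % 32000 = 1003.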
    using BlockReduce = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage smemReduceBuffer;
    __shared__ KVPair smemTopKV[PAD_2K];
    __shared__ int threadToUpdate;

    for (int i = 0; i < 2 * nBM; ++i)
    {
        KVPair topKVPair = BlockReduce(smemReduceBuffer).Reduce(topKVPairPartial, argmax);
        if (tid == 0)
        {
            smemTopKV[i] = topKVPair;
            smemVal[topKVPair.key] = -MAX_T_VAL;
            threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
        }
        __syncthreads();

        // Only one thread needs to update the old partial before the next block reduce.
        // No need to do this in the last iteration.
        if (tid == threadToUpdate && i < 2 * nBM - 1)
        {
            topKVPairPartial.key = nCandidate - 1;
            topKVPairPartial.value = -MAX_T_VAL;
            for (int index = tid; index < nCandidate; index += THREADBLOCK_SIZE)
            {
                topKVPairPartial = argmax(topKVPairPartial, {index, smemVal[index]});
            }
        }
    }
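
    // Note on the selection loop above: each of the 2 * nBM iterations pops the current maximum
    // via one BlockReduce, then a single thread "pollutes" the popped slot with -MAX_T_VAL and
    // rebuilds only its own partial maximum. For small nBM this is cheaper than fully sorting
    // all nCandidate values, since only the top 2 * nBM entries are ever needed.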
    if (tid == 0)
    {
        // Select finished beams into the CBA or select tokens for the next step sequentially
        // Reference (might change along with HF in the future):
        // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272
        for (int i = 0; i < 2 * nBM; ++i)
        {
            int const topKey = smemTopKV[i].key;
            T const topValue = smemTopKV[i].value;
            bool const isEndToken = pTempId[topKey] % nV == bh.endIds[slot];
            if (i < nBM && bh.numBeamsCBA != nullptr && isEndToken)
            {
                // Condition of this branch:
                // This token is an end-token and belongs to the top nBM range in beam search mode
                int const nSeqLen = bh.sequenceLengths[slot * nBM + i] + 1 - bh.inputLengths[slot * nBM + i];
                float const score = applyLengthPenalty(topValue, nSeqLen, lengthPenalty);
                int nCBA = bh.numBeamsCBA[slot];
                if (nCBA == nBM)
                {
                    // There are already nBM beams
                    if (score < bh.minNormedScoresCBA[slot])
                    {
                        // Current score is worse than the worst one in the candidate beams
                        if (earlyStopping)
                        {
                            // Stop since we have got enough beams
                            break;
                        }
                        else
                        {
                            // Continue since there might be longer but better beams
                            continue;
                        }
                    }
                    else
                    {
                        // Current score is better than the worst one in the candidate beams
                        // Find the candidate beam index with the worst score and erase it
                        for (int j = 0; j < nBM; j++)
                        {
                            if (bh.normedScoresCBA[slot * (nBM * 2) + j] == bh.minNormedScoresCBA[slot])
                            {
                                nCBA = j;
                                bh.numBeamsCBA[slot]--;
                                bh.minNormedScoresCBA[slot] = FLT_MAX;
                                bh.normedScoresCBA[slot * (nBM * 2) + j] = score;
                                for (int l = 0; l < nBM; l++)
                                {
                                    bh.minNormedScoresCBA[slot]
                                        = min(bh.minNormedScoresCBA[slot], bh.normedScoresCBA[slot * (nBM * 2) + l]);
                                }
                                break;
                            }
                        }
                    }
                }
                // Copy the finished beam from the work tree to the CBA
                // The last token
                int indexPrev = (pTempId[topKey] / nV) % nBM;
                int const step = bh.sequenceLengths[slot * nBM + indexPrev];
                int const offsetCBA = (slot * nBM * 2 + nCBA) * bh.nMaxSeqLen;
                bh.outputIdsCBA[offsetCBA + step] = bh.endIds[slot];
                if (bh.logProbsCBA != nullptr)
                {
                    bh.logProbsCBA[offsetCBA + step]
                        = (float) pTempVal[topKey] - smemCumLogProbs[(pTempId[topKey] / nV) % nBM];
                }
                // Previous tokens
                for (int j = step - 1; j >= 0; j--)
                {
                    bh.outputIdsCBA[offsetCBA + j] = bh.outputIdsPtr[slot][indexPrev * bh.nMaxSeqLen + j];
                    indexPrev = bh.parentIdsPtr[slot][indexPrev * bh.nMaxSeqLen + j];
                }
                if (bh.logProbsCBA != nullptr && bh.logProbsTiled != nullptr)
                {
                    indexPrev = (pTempId[topKey] / nV) % nBM;
                    for (int j = step - 1; j >= 0; j--)
                    {
                        int const index = (j * nMBS + slot) * nBM + indexPrev;
                        bh.logProbsCBA[offsetCBA + j] = bh.logProbsTiled[index];
                        indexPrev = bh.parentIdsPtr[slot][indexPrev * bh.nMaxSeqLen + j];
                    }
                }
                // Other parameters
                int const index = slot * (nBM * 2) + nCBA;
                bh.sequenceLengthsCBA[index] = step;
                bh.normedScoresCBA[index] = score;
                bh.minNormedScoresCBA[slot] = min(bh.minNormedScoresCBA[slot], bh.normedScoresCBA[index]);
                bh.numBeamsCBA[slot]++;
                bh.cumLogProbsCBA[index] = (float) pTempVal[topKey];
            }
            else if (i < nBM || bh.numBeamsCBA != nullptr && !isEndToken)
            {
                // Conditions of this branch:
                // 1. bh.numBeamsCBA == nullptr && i < nBM, i.e., beam search is disabled
                // 2. bh.numBeamsCBA != nullptr && i < nBM && isEndToken == false, i.e., add the token at the end
                // 3. bh.numBeamsCBA != nullptr && i >= nBM && isEndToken == false, i.e., add the token at the end
                int const step = bh.sequenceLengths[slot * nBM + nBeamForNextStep];
                // Copy the selected token to the work tree
                bh.outputIdsPtr[slot][nBeamForNextStep * bh.nMaxSeqLen + step] = pTempId[topKey];
                if (bh.logProbsTiled != nullptr)
                {
                    int const index = step * nMBS * nBM + slot * nBM + nBeamForNextStep;
                    int const indexBeam = pTempId[topKey] / nV % nBM;
                    bh.logProbsTiled[index] = (float) pTempVal[topKey] - smemCumLogProbs[indexBeam];
                }
                bh.cumLogProbs[slot * nBM + nBeamForNextStep] = (float) pTempVal[topKey];
                nBeamForNextStep++;
            }
            else
            {
                // Conditions of this branch, for which we do nothing:
                // 1. bh.numBeamsCBA == nullptr && i >= nBM, i.e., beam search is disabled
                // 2. bh.numBeamsCBA != nullptr && i >= nBM && isEndToken == true, i.e., ignore the worse beams
            }

            if (nBeamForNextStep >= nBM)
            {
                // Conditions of this branch:
                // 1. In EarlyStopping mode, we have got enough candidate beams
                // 2. In EarlyStopping mode, we have got enough tokens for the next generation step
                // 3. In NonEarlyStopping mode, we have got enough tokens for the next generation step
                // TODO: improve the condition like below
                // earlyStopping == 1 && bh.numBeamsCBA[slot] >= nBM || nBeamForNextStep >= nBM
                break;
            }
        }
    }

    // Update bh.batchDones
    if (tid == 0 && bh.numBeamsCBA != nullptr)
    {
        if (bh.numBeamsCBA[slot] < nBM)
        {
            // Not enough beams
            bh.batchDones[slot] = false;
        }
        else if (earlyStopping == 1)
        {
            // Enough candidate beams in EarlyStopping mode
            bh.batchDones[slot] = true;
        }
        else
        {
            // Enough beams in NonEarlyStopping mode
            int nSeqLen = bh.sequenceLengths[slot * nBM] + 1 - bh.inputLengths[slot * nBM];
            float const bestCumLogProbs = smemTopKV[0].value;
            // According to the semantics of HF, smemTopKV[0].value is used as bestCumLogProbs
            // But maybe bh.cumLogProbs[slot * nBM + i] is more suitable?
            // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L307
            if (earlyStopping != 0 && lengthPenalty > 0.0f)
            {
                // Specialization for earlyStopping == "never" and lengthPenalty > 0 in HF
                nSeqLen = bh.nMaxSeqLen - bh.inputLengths[slot * nBM];
            }
            float const bestAttainableScore = applyLengthPenalty(bestCumLogProbs, nSeqLen, lengthPenalty);
            bh.batchDones[slot] = bh.minNormedScoresCBA[slot] >= bestAttainableScore;
        }
    }
    __syncthreads();

    // Update sequenceLengths, parentIdsPtr, outputIdsPtr and finished
    __shared__ int smemSeqLen[PAD_2K / 2];
    if (tid < nBM)
    {
        smemSeqLen[tid] = bh.sequenceLengths[slot * nBM + tid];
    }
    __syncthreads();

    if (tid < nBM)
    {
        int const indexBatchBeam = slot * nBM + tid;
        int const step = smemSeqLen[tid];
        if (!bh.finished[indexBatchBeam].isFinished())
        {
            smemSeqLen[tid]++;
        }
        int const newId = bh.outputIdsPtr[slot][tid * bh.nMaxSeqLen + step];
        int const newBeamId = (newId / nV) % nBM;
        int const newTokenId = newId % nV;
        bh.sequenceLengths[indexBatchBeam] = smemSeqLen[newBeamId];
        if (newTokenId == bh.endIds[slot])
        {
            bh.finished[indexBatchBeam].setFinishedEOS();
        }
        bh.parentIdsPtr[slot][tid * bh.nMaxSeqLen + step] = newBeamId;
        bh.outputIdsPtr[slot][tid * bh.nMaxSeqLen + step] = newTokenId;
        if ((earlyStopping == 1) && (bh.numBeamsCBA != nullptr && bh.numBeamsCBA[slot] == nBM)
            || (earlyStopping != 1) && bh.batchDones[slot])
        {
            bh.batchDones[slot] = true;
            bh.finished[indexBatchBeam].setFinished();
        }
    }
}
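
// MD holds the running state of an online softmax (reference: https://github.com/NVIDIA/online-softmax,
// also cited in topKSoftMaxKernelLauncher below): m is the running maximum and d the running
// denominator sum(exp(x_i - m)). reduce_md_op merges two such states by rescaling the smaller
// state's denominator into the larger state's frame, so the softmax normalizer can be built in
// one pass and combined across ThreadBlocks.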
struct __align__(8) MD
{
    float m;
    float d;
};

__device__ __forceinline__ MD reduce_md_op(MD a, MD b)
{
    bool const isABigger = a.m > b.m;
    MD const bigger = isABigger ? a : b;
    MD const smaller = isABigger ? b : a;
    MD res{bigger.m, bigger.d + smaller.d * __expf(smaller.m - bigger.m)};
    return res;
}
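
// Worked example (illustrative, host-side math with expf in place of __expf):
//   MD a{2.0f, 1.5f}, b{0.0f, 2.0f};
//   MD c = reduce_md_op(a, b); // c = {2.0f, 1.5f + 2.0f * expf(0.0f - 2.0f)} ~= {2.0f, 1.77f}
// This equals the maximum and sum(exp(x_i - m)) computed over the union of both inputs in a
// single pass, which is what lets stages 1 and 2 merge partial softmax states across chunks.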
template <typename T, int PAD_2K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beamStage1Kernel(T const* __restrict logits,
    T const* __restrict bias, float* __restrict pTemp, int const* __restrict endIds,
    FinishedState const* __restrict finished, int const nV, int const nVLocal,
    runtime::SizeType32 const* batchSlots, int dyn_smem_size)
{
    constexpr auto PACKED_TOP_KMD_SIZE = 2 * PAD_2K + 2;
    int const nBM = gridDim.y;
    int const tid = threadIdx.x;
    int const slot = batchSlots[blockIdx.x];
    int const section_start = nVLocal * blockIdx.z;
    int const section_end = std::min(section_start + nVLocal, nV);
    auto const nVOffset = (blockIdx.x * nBM + blockIdx.y) * nV;
    int const valid_smem_length = section_end - section_start;
    T const MAX_T_VAL = std::is_same_v<T, half> ? HALF_FLT_MAX : FLT_MAX;

    // Load elements from logits into smemLogProbs, doing reduce_md and argmax meanwhile
    // Each thread is responsible for `nVLocal / THREADBLOCK_SIZE` elements
    extern __shared__ char smem[];
    T* smemLogProbs = reinterpret_cast<T*>(smem);

    MD partial_md{-MAX_T_VAL, 0.0f};

#if TOPK_FP16_STORAGE == 1
    using KVPair = cub::KeyValuePair<int, __half>;
#else
    using KVPair = cub::KeyValuePair<int, T>;
#endif
    KVPair topKVPairPartial{-1, -MAX_T_VAL};
    cub::ArgMax argmax;

    if (finished[slot * nBM + blockIdx.y].isFinished())
    {
        for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE)
        {
            float const val = (i == endIds[slot]) ? MAX_T_VAL : -MAX_T_VAL;
            int const smem_index = i - section_start;
            smemLogProbs[smem_index] = val;
            MD const new_elem_md{val, 1.0F};
            partial_md = reduce_md_op(partial_md, new_elem_md);
            KVPair const new_elem_topk{smem_index, val};
            topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
        }
    }
    else
    {
        for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE)
        {
            T const b = bias == nullptr ? (T) 0.0f : bias[i];
            T const val = logits[nVOffset + i] + b;
            int const smem_index = i - section_start;
            smemLogProbs[smem_index] = val;
            MD new_elem_md{val, 1.0F};
            partial_md = reduce_md_op(partial_md, new_elem_md);
            KVPair new_elem_topk{smem_index, val};
            topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
        }
    }
    __syncthreads();

    // Search the top 2K elements among the `nVLocal` elements of this ThreadBlock and write them into smemOutput
    __shared__ float smemOutput[PACKED_TOP_KMD_SIZE];
    __shared__ int threadToUpdate;

    using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
    using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;

    __shared__ union
    {
        typename BlockReduceTopK::TempStorage topk;
        typename BlockReduceMD::TempStorage md;
    } smemReduceBuffer;

    for (int i = 0; i < 2 * nBM; ++i)
    {
        // Pop the element with the largest value into "smemOutput" per iteration
        KVPair topKVPair = BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
        if (tid == 0)
        {
            int const index = nVOffset + section_start + topKVPair.key;
            reinterpret_cast<int*>(smemOutput)[i] = index;
            smemOutput[PAD_2K + i] = topKVPair.value;
            smemLogProbs[topKVPair.key] = -MAX_T_VAL; // pollute the value of the popped element
            threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
        }
        __syncthreads();

        if (tid == threadToUpdate && i < 2 * nBM - 1)
        {
            // The thread that popped the element needs to update its topKVPairPartial
            // No need to do this in the last iteration
            topKVPairPartial.key = nV - 1;
            topKVPairPartial.value = -MAX_T_VAL;
            for (int index = tid; index < valid_smem_length; index += THREADBLOCK_SIZE)
            {
                topKVPairPartial = argmax(topKVPairPartial, {index, smemLogProbs[index]});
            }
        }
        // Sync due to the RAW dependency on threadToUpdate
        __syncthreads();
    }

    // Do reduce_md among the top 2K elements in smemOutput and write the result into the tail of smemOutput
    MD total_md = BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_op);
    if (tid == 0)
    {
        smemOutput[2 * PAD_2K] = total_md.d;
        smemOutput[2 * PAD_2K + 1] = total_md.m;
    }
    __syncthreads();

    // Write smemOutput into pTemp
    float* local_temp_buffer
        = pTemp + (blockIdx.x * nBM + blockIdx.y) * PACKED_TOP_KMD_SIZE * gridDim.z + blockIdx.z * PACKED_TOP_KMD_SIZE;
    for (int i = tid; i < PACKED_TOP_KMD_SIZE; i += THREADBLOCK_SIZE)
    {
        local_temp_buffer[i] = smemOutput[i];
    }
}
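
// Illustrative layout of one stage-1 output "Tile" in pTemp (see the workflow comment in
// topKSoftMaxKernelLauncher): with BM = 5 (so PAD_2K = 16), PACKED_TOP_KMD_SIZE = 2 * 16 + 2 = 34
// floats per Tile. Elements [0, 16) hold top-K ids (stored bit-cast as int), [16, 32) hold the
// matching values (only the first 2 * BM = 10 of each are meaningful), and [32, 34) hold {d, m}.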
template <typename T, int PAD_2K, int THREADBLOCK_SIZE, bool IS_FAST_KERNEL>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beamStage2Kernel(int* __restrict pTempId,
    T* __restrict pTempVal, float* __restrict pTemp, float const* __restrict cumLogProbs, int const nV,
    int const nVPart, runtime::SizeType32 const* batchSlots)
{
    constexpr int PACKED_TOP_KMD_SIZE = 2 * PAD_2K + 2;
    auto const nBM = gridDim.y;
    auto const gbid = blockIdx.x * gridDim.y + blockIdx.y;
    int const tid = threadIdx.x;
    auto const slot = batchSlots[blockIdx.x];
    T const MAX_T_VAL = std::is_same_v<T, half> ? HALF_FLT_MAX : FLT_MAX;

    using KVPair = cub::KeyValuePair<int, T>;
    using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
    using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;

    __shared__ KVPair buf_smem_kv[PAD_2K];

    __shared__ union
    {
        typename BlockReduceTopK::TempStorage topk;
        typename BlockReduceMD::TempStorage md;
    } smemReduceBuffer;

    cub::ArgMax argmax;
    MD partial_md{-MAX_T_VAL, 0.0f};
    KVPair topKVPair{nV - 1, -MAX_T_VAL};

    // Load and unpack into registers through smem
    float* localTempBuffer = pTemp + PACKED_TOP_KMD_SIZE * gbid * nVPart;
    if constexpr (IS_FAST_KERNEL) // Use shared memory instead of global memory
    {
        extern __shared__ char smem[];
        float* smemVal = reinterpret_cast<float*>(smem);
        for (int idx = tid; idx < PACKED_TOP_KMD_SIZE * nVPart; idx += THREADBLOCK_SIZE)
        {
            smemVal[idx] = localTempBuffer[idx];
        }
        localTempBuffer = smemVal;
        __syncthreads();
    }

    // Find the top 2K across all nVPart
    for (int k = 0; k < 2 * nBM; ++k)
    {
        KVPair topKVPairPartial{nV - 1, -MAX_T_VAL};
        // Only threads responsible for a chunk will do the computation
        if (tid < nVPart)
        {
            for (int i = 0; i < 2 * nBM; ++i)
            {
                int const current_index = tid * PACKED_TOP_KMD_SIZE + i;
                T topValue = localTempBuffer[current_index + PAD_2K];
                topKVPairPartial = argmax(topKVPairPartial, {current_index, topValue});
            }
        }

        KVPair topKVPair = BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
        __syncthreads();

        if (tid == 0)
        {
            // Store the k-v pairs in the shared memory buffer
            int temp_offset = topKVPair.key;
            int global_offset = reinterpret_cast<int*>(localTempBuffer)[temp_offset];
            topKVPair.key = global_offset;
            buf_smem_kv[k] = topKVPair;

            // Invalidate the maximum value within the chunk
            reinterpret_cast<int*>(localTempBuffer)[temp_offset] = nV - 1; // id in shared memory
            localTempBuffer[temp_offset + PAD_2K] = -MAX_T_VAL;            // value in shared memory
        }
        __syncthreads();
    }

    // Extract and reduce MD values across the chunks
    if (tid < nVPart)
    {
        partial_md.d = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * PAD_2K];
        partial_md.m = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * PAD_2K + 1];
    }
    __syncthreads();

    MD total_md = BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_op);

    if (tid == 0)
    {
        float d_total_log = logf(total_md.d);
        auto const cumLogProbsValue = cumLogProbs[slot * nBM + blockIdx.y];

        for (int i = 0; i < 2 * nBM; ++i)
        {
            float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log;
            pTempId[gbid * 2 * nBM + i] = buf_smem_kv[i].key;
            pTempVal[gbid * 2 * nBM + i] = val + cumLogProbsValue;
        }
    }
}

#define BEAM_STAGE2_KERNEL(N_VOCAB_PART, IS_FAST_KERNEL)                                                               \
    {                                                                                                                  \
        if (IS_FAST_KERNEL && nShareMemory >= (48 << 10))                                                              \
        {                                                                                                              \
            TLLM_CUDA_CHECK(cudaFuncSetAttribute(beamStage2Kernel<T, PAD_2K, N_VOCAB_PART, IS_FAST_KERNEL>,            \
                cudaFuncAttributeMaxDynamicSharedMemorySize, nShareMemory));                                           \
        }                                                                                                              \
        beamStage2Kernel<T, PAD_2K, N_VOCAB_PART, IS_FAST_KERNEL>                                                      \
            <<<dim3(nBS, nBM), N_VOCAB_PART, IS_FAST_KERNEL * nShareMemory, stream>>>(                                 \
                pTempId, pTempVal, pTemp, cumLogProbs, nV, nVPart, batchSlots);                                        \
    }                                                                                                                  \
    return;

template <typename T, int PAD_2K>
__inline__ void beamStage2KernelLauncher(float* pTemp, float const* cumLogProbs, int* pTempId, T* pTempVal,
    int const nBS, int const nBM, int const nVPart, int const nV, int const max_smem_per_block, cudaStream_t stream,
    runtime::SizeType32 const* batchSlots)
{
    // TODO: rewrite the kernel to remove the dependence on a constant block size to reduce compilation time
    size_t const nShareMemory
        = sizeof(float) * nVPart * (2 * PAD_2K + 2) + sizeof(cub::KeyValuePair<int, T>) * PAD_2K;
    if (nShareMemory < max_smem_per_block) // IS_FAST_KERNEL must be a compile-time constant
    {
        if (nVPart <= 32)
        {
            BEAM_STAGE2_KERNEL(32, true)
        }
        if (nVPart <= 64)
        {
            BEAM_STAGE2_KERNEL(64, true)
        }
        BEAM_STAGE2_KERNEL(128, true) // No branch with a larger size since nVPart <= nMaxVocabPartForStage1FastKernel
    }
    BEAM_STAGE2_KERNEL(128, false)
}
template <typename T, int PAD_K>
void topKSoftMaxKernelLauncher(T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream)
{
    // Workflow of this function (reference: https://github.com/NVIDIA/online-softmax)
    // Using batch_size (BS) = 2, beam_width (BM) = 5, vocab_size (V) = vocab_size_padded (VP) = 32000 as an example:
    // nPaddedBeamWidth (PAD_K) = 2 ^ ceil(log2(BM)) = 8, PAD_2K = 2 * PAD_K = 16
    // logits.shape = [BS, BM, V]
    // nBlockSize = 128, nVPart = 13, nVocabChunk = 2462 = ceil(32000 / 13)
    //
    // The content of workspace (length aligned to 4):
    //              | allocated size                          | used size              | data type |
    // ┏━━━━━━━━━━┓ -------------------------------------------------------------------------------
    // ┃ pTempId  ┃ | BS * PAD_K * PAD_K * 2                  |                        | int       |
    // ┣━━━━━━━━━━┫ | --------------------------------------- | Change "PAD_K" into   | --------- |
    // ┃ pTempVal ┃ | BS * PAD_K * PAD_K * 2                  | "BM" in the left       | float     |
    // ┣━━━━━━━━━━┫ | --------------------------------------- | formulas              | --------- |
    // ┃ pTemp    ┃ | BS * PAD_K * VP * (2 * (PAD_K * 2) + 2) |                        | float     |
    // ┗━━━━━━━━━━┛ -------------------------------------------------------------------------------
    //
    // Stage 1: gridDim(BS * BM, nVPart, 1), blockDim(nBlockSize, 1, 1)
    // Each ThreadBlock takes `nVocabChunk` contiguous elements in logits to do TopK and reduce_md,
    // then writes the output into pTemp.
    // At the end of this kernel, each ThreadBlock holds the indices and values of its top 2 * BM elements,
    // as well as the m(x) and l(x) of those elements (see the Flash Attention paper, arXiv:2205.14135)
    // pTemp.shape = [BS * BM, nVPart, 2 * PAD_2K + 2]
    // The content of the last dimension of pTemp (updated by each ThreadBlock, we call it a "Tile"):
    //                  ┏━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━┓
    //                  ┃ topk_id ┃ topk_val ┃ md    ┃
    //                  ┗━━━━━━━━━┻━━━━━━━━━━┻━━━━━━━┛
    // | allocated size | PAD_2K  | PAD_2K   | 2     |
    // | used size      | BM * 2  | BM * 2   | 2     |
    // | data type      | int     | float    | float |
    //
    // Stage 2: gridDim(BS * BM, 1, 1), blockDim(32/64/128, 1, 1)
    // Each ThreadBlock takes `nVPart` contiguous Tiles in pTemp to do reduce_topk and reduce_md,
    // writes the output topk_id into pTempId, and writes topk_value + cumLogProbs into pTempVal.
    //
    // beamStage3Kernel: gridDim(BS, 1, 1), blockDim(nBlockSizeStage3, 1, 1)
    // Each ThreadBlock is responsible for one batch, doing the work below:
    // + moves a beam into the candidate-beam-array if it is finished (generated end_id in this step).
    // + selects BM elements for the next generation step if it is not.
    // + maintains the related score arrays, min_normed_score / batchDones / finished, etc.

    int constexpr items_per_thread = 1;
    int constexpr nBlockSize = (PAD_K < 16) ? ((PAD_K < 8) ? nBlockSizeForSmallBeamWidth : 128) : 64;

    int const nBS{bh.nBatchSize};
    int const nBM{bh.nBeamWidth};
    int const nV{bh.nVocabSize};
    int const* endIds{bh.endIds};
    runtime::SizeType32 const* batchSlots{bh.batchSlots};
    FinishedState const* finished{bh.finished};

    int const offset = roundUp(nBS * nBM * nBM * 2, 4);
    int* pTempId = reinterpret_cast<int*>(workspace);
    T* pTempVal = reinterpret_cast<T*>(pTempId + offset);
    float* pTemp = reinterpret_cast<float*>(pTempVal + offset);
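
    // Illustrative sizing with the example above (BS = 2, BM = 5): offset = roundUp(2 * 5 * 5 * 2, 4)
    // = 100 elements, so pTempId occupies the first 100 ints of the workspace, pTempVal the next
    // 100 T elements, and pTemp the rest (the workspace itself is allocated with the larger
    // PAD_K-based sizes, as the table above notes).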
    // Upper limit of the ThreadBlock count, obtained by assuming no shared memory is used
    int max_active_blocks = -1;
    TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, beamStage1Kernel<T, PAD_K * 2, nBlockSize>, nBlockSize, 0));

    // Find the max smem on the device and use that to determine the vocab parts in the best case.
    int max_smem_per_sm = -1;
    int max_smem_per_block = -1;
    int const device = tensorrt_llm::common::getDevice();
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device));
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    cudaFuncAttributes attr;
    TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage1Kernel<T, PAD_K * 2, nBlockSize>));

    // Each ThreadBlock needs at least `sizeof(T) * nV / nMaxVocabPartForStage1FastKernel` bytes of shared memory
    int const static_smem = attr.sharedSizeBytes;
    int const max_dyn_smem_per_block = max_smem_per_block - static_smem;
    TLLM_CHECK_WITH_INFO(sizeof(T) * nV <= max_dyn_smem_per_block * nMaxVocabPartForStage1FastKernel,
        "Vocab size is too large for split-k TopK beam search fast path.");

    // Find the maximum ThreadBlock count (i.e., the maximum nVPart and the minimum smem per block),
    // satisfying nVPart <= nMaxVocabPartForStage1FastKernel && dyn_smem_size * nVPart >= sizeof(T) * nV
    int const driver_smem_per_block = max_smem_per_sm - max_smem_per_block;
    int const extra_smem = driver_smem_per_block + static_smem;

    int nVPart = nMaxVocabPartForStage1FastKernel + 1;
    for (int n_block = max_active_blocks - 1; n_block > 0 && nVPart > nMaxVocabPartForStage1FastKernel; --n_block)
    {
        int dyn_smem_size = max_smem_per_sm / n_block - extra_smem;
        dyn_smem_size -= dyn_smem_size % sizeof(T);
        nVPart = ceilDiv(sizeof(T) * nV, dyn_smem_size);
    }

    int const nVocabChunk = (nV + nVPart - 1) / nVPart;
    int const dyn_smem_size = sizeof(T) * nVocabChunk;
    if (dyn_smem_size >= (48 << 10))
    {
        TLLM_CUDA_CHECK(cudaFuncSetAttribute(beamStage1Kernel<T, PAD_K * 2, nBlockSize>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, dyn_smem_size));
    }

    dim3 gridSize(nBS, nBM, nVPart);
    beamStage1Kernel<T, PAD_K * 2, nBlockSize><<<gridSize, nBlockSize, dyn_smem_size, stream>>>(
        logits, bias, pTemp, endIds, finished, nV, nVocabChunk, batchSlots, dyn_smem_size);
    sync_check_cuda_error();

    beamStage2KernelLauncher<T, PAD_K * 2>(
        pTemp, bh.cumLogProbs, pTempId, pTempVal, nBS, nBM, nVPart, nV, max_smem_per_block, stream, batchSlots);
    sync_check_cuda_error();

    // Keep top 2K candidates in case K candidates finish in one iteration
    size_t const nShareMemory = sizeof(T) * nBM * nBM * 2;
    size_t constexpr nBlockSizeStage3 = (PAD_K + 31) / 32 * 32; // can not use `roundUp()`
    if (nShareMemory >= (48 << 10))
    {
        TLLM_CUDA_CHECK(cudaFuncSetAttribute(beamStage3Kernel<T, PAD_K * 2, nBlockSizeStage3>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, nShareMemory));
    }
    beamStage3Kernel<T, PAD_K * 2, nBlockSizeStage3>
        <<<nBS, nBlockSizeStage3, nShareMemory, stream>>>(pTempId, pTempVal, bh);
    sync_check_cuda_error();
}

#define INSTANTIATE_BEAMSEARCH_K(T, PAD_K)                                                                             \
    template void topKSoftMaxKernelLauncher<T, PAD_K>(                                                                 \
        T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
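
// Example instantiation (illustrative; the actual INSTANTIATE_BEAMSEARCH_K invocations live in
// separate translation units): INSTANTIATE_BEAMSEARCH_K(float, 8) emits
// topKSoftMaxKernelLauncher<float, 8>, covering beam widths 5..8 since PAD_K is the beam width
// padded up to a power of two.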