/* * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #ifndef CUDART_VERSION #error CUDART_VERSION Undefined! #elif (CUDART_VERSION >= 11050) #include #else #include "3rdparty/cub/cub.cuh" #endif #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/kernels/samplingTopKKernels.h" using namespace tensorrt_llm::common; using namespace tensorrt_llm::runtime; namespace tensorrt_llm { namespace kernels { template __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restrict logProbsPtrs, T* tmpLogProbs, SizeType32* topKTmpIdBuf, T* topKTmpValBuf, FinishedState const* finished, SizeType32 maxTopK, SizeType32 const* topKs, SizeType32 vocabSize, TokenIdType const* endIds, bool const* skipDecode, SizeType32 const* batchSlots, SizeType32 const* tokensPerStep, SizeType32 maxTokensPerStep) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage tempStorage; auto const tid = static_cast(threadIdx.x); auto const bid = static_cast(blockIdx.x); auto const tokenIdx = static_cast(blockIdx.y); auto const batchId = bid / BLOCKS_PER_BEAM_; // row id for logProbs auto const batchSlot = batchSlots != nullptr ? batchSlots[batchId] : batchId; if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot]) { return; } FinishedState const finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty(); if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding())) { return; } auto const logBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize; auto logProbsSlot = logProbsPtrs == nullptr ? logProbs + logBufIndex : logProbsPtrs[batchId * maxTokensPerStep + tokenIdx]; auto const blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index auto const tmpLogBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize; auto const tmpTopKBufIndex = batchId * maxTokensPerStep * BLOCKS_PER_BEAM_ * maxTopK + tokenIdx * BLOCKS_PER_BEAM_ * maxTopK + blockLane * k; TopK_2 partial; bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; if (finished != nullptr && finishState.isFinished()) { if (tid < k) { auto const index = tmpTopKBufIndex + tid; if (blockLane == 0 && tid == 0) { auto const endId = endIds[batchSlot]; topKTmpIdBuf[index] = tmpLogBufIndex + endId; topKTmpValBuf[index] = logProbsSlot[endId]; } else { topKTmpIdBuf[index] = -1; topKTmpValBuf[index] = -MAX_T_VAL; } } return; } for (auto elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { auto localIndex = elemId + tmpLogBufIndex; tmpLogProbs[localIndex] = logProbsSlot[elemId]; } for (SizeType32 ite = 0; ite < k; ite++) { partial.init(); #pragma unroll for (auto elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { auto index = elemId + tmpLogBufIndex; partial.insert(tmpLogProbs[index], index); } TopK_2 total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { auto const index = tmpTopKBufIndex + ite; topKTmpIdBuf[index] = total.p; topKTmpValBuf[index] = total.u; if (total.p >= 0) { tmpLogProbs[total.p] = -MAX_T_VAL; } } __syncthreads(); } } template __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T* topKTmpValBuf, TokenIdType** idsPtrs, TokenIdType* ids, SizeType32* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, SizeType32 maxTopK, SizeType32 const* topKs, float topP, float const* topPs, curandState_t* curandState, TokenIdType const* endIds, SizeType32 vocabSize, bool const* skipDecode, SizeType32 const* batchSlots, SizeType32 maxBatchSize, bool normalizeLogProbs, bool logitHasProbs, SizeType32 const* tokensPerStep, SizeType32 maxTokensPerStep, SizeType32 maxSeqLen, bool returnAllTopK) { bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; auto const tid = static_cast(threadIdx.x); auto const batchIdx = static_cast(blockIdx.x); auto const tokenIdx = static_cast(blockIdx.y); auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx; FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty(); if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding())) { return; } if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot]) { return; } auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; auto const probThreshold = (topPs != nullptr) ? topPs[batchSlot] : topP; auto const size = k * BLOCKS_PER_BEAM_; auto const stride = maxTopK * BLOCKS_PER_BEAM_; typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage tempStorage; extern __shared__ char array[]; __shared__ float sSum; T* sVal = topKTmpValBuf + (batchIdx * maxTokensPerStep + tokenIdx) * stride; auto* sId = reinterpret_cast(array); if (tid == 0) { sSum = 0.0f; } TopK_2 partial; if (finishState.isFinished()) { if (finishedOutput != nullptr) { finishedOutput[batchSlot] = finishState; } return; } auto sVal2 = reinterpret_cast(sId + k); float maxLogit; for (SizeType32 ite = 0; ite < k; ite++) { partial.init(); #pragma unroll for (SizeType32 i = tid; i < size; i += BLOCK_SIZE_) { partial.insert((float) sVal[i], i); } TopK_2 total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { if (ite == 0) { maxLogit = total.u; } sId[ite] = total.p; sVal[total.p] = -MAX_T_VAL; // when cumLogProbs are computed, topKTmpValBuf (logits_buf_) are // already pre-processed by softmax_kernel if (!logitHasProbs) { total.u = __expf(total.u - maxLogit); } sVal2[ite] = total.u; sSum += total.u; } __syncthreads(); } if (tid == 0) { auto randNum = static_cast(curand_uniform(curandState + batchSlot) * probThreshold * sSum); auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot]; for (SizeType32 ki = 0; ki < k; ki++) { auto expLogit = sVal2[ki]; randNum = randNum - expLogit; if (randNum <= 0.0f || ki == k - 1 || returnAllTopK) { auto idx = sId[ki]; // If sId is -1 here we force output token to the last from vocabulary to get vivid indicator of smth // going wrong for the debug auto outputId = idx != -1 ? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize : vocabSize - 1; auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; auto const outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx; outputIdsRequestPtr[outIdx] = outputId; // cum log prob is not supported with returnAllTopK if (!returnAllTopK) { if (cumLogProbs != nullptr || outputLogProbs != nullptr) { auto logProb = logf(expLogit); if (cumLogProbs != nullptr) { cumLogProbs[batchSlot] += logProb; } if (outputLogProbs != nullptr) { // 'outputLogProbs' is the probability induced by the top-k sampling: // NOT normalized (same way as OpenAI does): // log_prob = log P(i | i is in vocab) = log(expLogit) // normalized: // log_prob = log P(i | i is in top-k) = log(expLogit / sum) outputLogProbs[curSeqLen * maxBatchSize + batchSlot] = normalizeLogProbs ? logProb - logf(sSum) : logProb; } } break; } } } if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr && endIds != nullptr) { auto const seqLen = sequenceLengths[batchSlot]; if (outputIdsRequestPtr[seqLen] == endIds[batchSlot]) { finishedOutput[batchSlot].setFinishedEOS(); // Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be // outputted } else { // We don't need to set output finished state as it is assumed to be in non finished state sequenceLengths[batchSlot] += 1; } } } } #define CASE_K(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_, normalizeLogProbs) \ do \ { \ { \ dim3 grid(params.batchSize* BLOCKS_PER_BEAM_, params.maxTokensPerStep); \ dim3 block(BLOCK_SIZE_1_); \ topKStage1<<>>(params.logProbs, \ params.logProbsPtrs, tempLogProbs, topKTmpIdBuf, topKTmpValBuf, params.finishedInput, params.maxTopK, \ params.topKs, params.vocabSizePadded, params.endIds, params.skipDecode, params.batchSlots, \ params.tokensPerStep, params.maxTokensPerStep); \ } \ { \ dim3 grid(params.batchSize, params.maxTokensPerStep); \ dim3 block(BLOCK_SIZE_2_); \ topKStage2Sampling \ <<>>(topKTmpIdBuf, \ topKTmpValBuf, params.outputIdsPtrs, params.outputIds, params.sequenceLengths, \ params.finishedInput, params.finishedOutput, params.cumLogProbs, params.outputLogProbs, \ params.maxTopK, params.topKs, params.maxTopP, params.topPs, params.curandState, params.endIds, \ params.vocabSizePadded, params.skipDecode, params.batchSlots, params.maxBatchSize, \ normalizeLogProbs, params.logitsHasProbs, params.tokensPerStep, params.maxTokensPerStep, \ params.maxSeqLen, params.returnAllTopK); \ } \ } while (0) template void invokeBatchTopKSampling(TopKSamplingKernelParams const& params, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); params.checkParams(); // Not allow an ambiguous inputs topP and topPs. auto const workspaceSizes = getTopKWorkspaceSizes(params.batchSize, params.maxTokensPerStep, params.maxTopK, params.vocabSizePadded); if (params.maxTopK == 0) { return; } std::vector alignedPointers; calcAlignedPointers(alignedPointers, params.workspace, workspaceSizes); auto tempLogProbs = static_cast(alignedPointers[0]); auto topKTmpIdBuf = static_cast(alignedPointers[1]); auto topKTmpValBuf = static_cast(alignedPointers[2]); SizeType32 logMaxTopK{0}; SizeType32 recursor{params.maxTopK - 1}; while (recursor >>= 1) { ++logMaxTopK; } switch (logMaxTopK) { case 0: case 1: case 2: case 3: // 0 < maxTopK <= 16 CASE_K(16, 128, 128, 8, params.normalizeLogProbs); break; case 4: // 16 < maxTopK <= 32 CASE_K(32, 256, 128, 8, params.normalizeLogProbs); break; case 5: // 32 < maxTopK <= 64 CASE_K(64, 256, 256, 8, params.normalizeLogProbs); break; case 6: case 7: case 8: case 9: // 64 < maxTopK <= 1024 CASE_K(1024, 256, 256, 8, params.normalizeLogProbs); break; default: TLLM_CHECK_WITH_INFO(false, "TopK kernel supports 1 <= k <= 1024 but got k=%d", params.maxTopK); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } #undef CASE_K template void invokeBatchTopKSampling(TopKSamplingKernelParams const& params, cudaStream_t stream); template void invokeBatchTopKSampling(TopKSamplingKernelParams const& params, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm