/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "beamSearchLayer.h"
#include "tensorrt_llm/kernels/beamSearchKernels/beamSearchKernelsTemplate.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <limits>

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::kernels;

namespace tensorrt_llm::layers
{

#define GET_INFO_STAGE1(paddedBeamWidth)                                                                               \
    {                                                                                                                  \
        int constexpr nBlock = (paddedBeamWidth < 16) ? ((paddedBeamWidth < 8) ? kThreadForSmallBeamWidth : 128) : 64; \
        TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(                                                 \
            &nMaxActiveBlock, beamStage1Kernel, nBlock, 0));                                                           \
        TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage1Kernel));                                               \
        break;                                                                                                         \
    }

#define GET_INFO_STAGE2(paddedBeamWidth)                                                                               \
    {                                                                                                                  \
        if (nByteDynamicSharedMemoryStage2 > nByteMaxSharedMemoryPerBlock)                                             \
        {                                                                                                              \
            TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage2Kernel));                                           \
        }                                                                                                              \
        else if (nVPart <= 32)                                                                                         \
        {                                                                                                              \
            TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage2Kernel));                                           \
        }                                                                                                              \
        else if (nVPart <= 64)                                                                                         \
        {                                                                                                              \
            TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage2Kernel));                                           \
        }                                                                                                              \
        else                                                                                                           \
        {                                                                                                              \
            TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage2Kernel));                                           \
        }                                                                                                              \
        break;                                                                                                         \
    }

#define GET_INFO_STAGE3(paddedBeamWidth, isV2)                                                                         \
    {                                                                                                                  \
        int constexpr nThreadStage3 = (paddedBeamWidth + 31) / 32 * 32;                                                \
        TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage3Kernel));                                               \
        break;                                                                                                         \
    }

template <typename T>
BeamSearchLayer<T>::BeamSearchLayer(DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    SizeType32 const batchSize{mDecoderDomain.getBatchSize()};
    SizeType32 const beamWidth{mDecoderDomain.getBeamWidth()};
    SizeType32 const vocabSize{mDecoderDomain.getVocabSize()};
    TLLM_CHECK_WITH_INFO(beamWidth <= kMaxBeamWidth, "Beam width is larger than the maximum supported (%d > %d)",
        int(beamWidth), int(kMaxBeamWidth));
    this->mVBWS = decoderDomain.getUseVariableBeamWidthSearch();

    allocateBuffer();
    configureBeamSearchLayer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
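
// Allocate parameter buffers. Each per-request parameter keeps a pinned-host staging copy and a
// device copy; the beam width array stores kMaxBeamWidthArrayLength entries per request.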
template <typename T>
void BeamSearchLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    SizeType32 const batchSize{mDecoderDomain.getBatchSize()};
    auto const batchSizeShape{ITensor::makeShape({batchSize})};
    auto const batchSizeXBeamWidthArraySizeShape{
        ITensor::makeShape({batchSize * static_cast<SizeType32>(kMaxBeamWidthArrayLength)})};

    mBeamSearchDiversityRateHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mBeamSearchDiversityRateDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
    mLengthPenaltyHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mLengthPenaltyDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
    mEarlyStoppingHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
    mEarlyStoppingDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mBeamWidthArrayHost = mBufferManager->pinnedPool(batchSizeXBeamWidthArraySizeShape, TRTDataType<SizeType32>::value);
    mBeamWidthArrayDevice = mBufferManager->gpu(batchSizeXBeamWidthArraySizeShape, TRTDataType<SizeType32>::value);
    mBeamWidthIn = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
    mBeamWidthOut = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
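
// Choose the beam search kernel workflow (V1 for small beam width without variable-beam-width search,
// V2 otherwise), and precompute shared memory sizes, the vocabulary partition count and the workspace size.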
template <typename T>
void BeamSearchLayer<T>::configureBeamSearchLayer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    SizeType32 const batchSize{mDecoderDomain.getBatchSize()};
    SizeType32 const beamWidth{mDecoderDomain.getBeamWidth()};
    SizeType32 const vocabSize{mDecoderDomain.getVocabSize()};
    SizeType32 const paddedBeamWidth{padToNextPowerOfTwo(beamWidth)};

    cudaFuncAttributes attr;

    // Find device information to determine `nVPart`.
    int const nByteMaxSharedMemoryPerSM = getMaxSharedMemoryPerSM();
    int const nByteMaxSharedMemoryPerBlock = getMaxSharedMemoryPerBlockOptin();
    int const nByteReservedSharedMemoryPerBlock = nByteMaxSharedMemoryPerSM - nByteMaxSharedMemoryPerBlock;
    this->mByteMaxSharedMemoryPerBlock = nByteMaxSharedMemoryPerBlock;

    if (beamWidth <= kMaxBeamWidthForV1 && !(this->mVBWS))
    {
        // V1 workflow for small beam width and non-VBWS
        // Stage 1
        int nMaxActiveBlock = -1;
        switch (paddedBeamWidth)
        {
        case 1: GET_INFO_STAGE1(1);
        case 2: GET_INFO_STAGE1(2);
        case 4: GET_INFO_STAGE1(4);
        case 8: GET_INFO_STAGE1(8);
        default: break;
        }
        int nByteStaticSharedMemory = attr.sharedSizeBytes;
        int nByteMaxDynamicSharedMemoryPerBlock = nByteMaxSharedMemoryPerBlock - nByteStaticSharedMemory;

        // Find the maximum of `nBlock` (maximum of `nVPart`, minimum of `nByteDynamicSharedMemoryStage1`), s.t.
        // `nVPart <= kMaxVPartStage1 && nByteDynamicSharedMemoryStage1 * nVPart >= sizeof(T) * vocabSize`
        TLLM_CHECK_WITH_INFO(nByteMaxDynamicSharedMemoryPerBlock * kMaxVPartStage1 >= sizeof(T) * vocabSize,
            "vocab_size is too large for Beam search.");
        int nByteExtraSharedMemory = nByteReservedSharedMemoryPerBlock + nByteStaticSharedMemory;
        int nBlock = nMaxActiveBlock;
        int nVPart = kMaxVPartStage1 + 1;
        for (; nBlock > 0 && nVPart > kMaxVPartStage1; --nBlock)
        {
            int nByteDynamicSharedMemoryStage1 = nByteMaxSharedMemoryPerSM / nBlock - nByteExtraSharedMemory;
            nByteDynamicSharedMemoryStage1 -= nByteDynamicSharedMemoryStage1 % sizeof(T);
            nVPart = ceilDiv(sizeof(T) * vocabSize, nByteDynamicSharedMemoryStage1);
        }
        TLLM_CHECK_WITH_INFO(nBlock >= 0, "Not enough active blocks for Beam Search stage 1 kernel.");

        int const nByteDynamicSharedMemoryStage1 = sizeof(T) * ceilDiv(vocabSize, nVPart);
        this->mVPart = nVPart;
        this->mByteSharedMemoryStage1 = nByteDynamicSharedMemoryStage1; // Only dynamic shared memory

        // Stage 2
        TLLM_CHECK_WITH_INFO(batchSize * beamWidth * paddedBeamWidth < (1 << 21),
            "max_batch_size or max_beam_width of TRT-LLM engine is too large for Beam search, try to decrease the "
            "parameters while building.");
        size_t const nByteDynamicSharedMemoryStage2 = common::roundUp(
            sizeof(float) * nVPart * (paddedBeamWidth * 4) + sizeof(cub::KeyValuePair<int, T>) * paddedBeamWidth * 2,
            4);
        switch (paddedBeamWidth)
        {
        case 1: GET_INFO_STAGE2(1);
        case 2: GET_INFO_STAGE2(2);
        case 4: GET_INFO_STAGE2(4);
        case 8: GET_INFO_STAGE2(8);
        default: break;
        }
        nByteStaticSharedMemory = attr.sharedSizeBytes;
        nByteMaxDynamicSharedMemoryPerBlock = nByteMaxSharedMemoryPerBlock - nByteStaticSharedMemory;
        nByteExtraSharedMemory = nByteReservedSharedMemoryPerBlock + nByteStaticSharedMemory;
        bool const bUseGlobalMemoryStage2 = (nByteDynamicSharedMemoryStage2 > nByteMaxDynamicSharedMemoryPerBlock);

        // Stage 3
        // Keep top 2K candidates in case K candidates finish in one iteration
        size_t const nByteDynamicSharedMemoryStage3
            = common::roundUp(sizeof(T) * paddedBeamWidth * paddedBeamWidth * 2, 4);
        switch (paddedBeamWidth)
        {
        case 1: GET_INFO_STAGE3(1, false);
        case 2: GET_INFO_STAGE3(2, false);
        case 4: GET_INFO_STAGE3(4, false);
        case 8: GET_INFO_STAGE3(8, false);
        }
        nByteStaticSharedMemory = attr.sharedSizeBytes;
        nByteMaxDynamicSharedMemoryPerBlock = nByteMaxSharedMemoryPerBlock - nByteStaticSharedMemory;
        nByteExtraSharedMemory = nByteReservedSharedMemoryPerBlock + nByteStaticSharedMemory;
        bool const bUseGlobalMemoryStage3 = (nByteDynamicSharedMemoryStage3 > nByteMaxDynamicSharedMemoryPerBlock);
        this->mByteSharedMemoryStage3 = nByteStaticSharedMemory; // Only static shared memory

        // Compute workspace size, see `beamSearchKernelsTemplate.h` for detailed information
        // |<----- Workspace ----->|
        // |<- A ->|<- B ->|<- C ->|
        //         |<---- D ---->|
        // A for data exchange between stage 2 and 3
        // B for data exchange between stage 1 and 2, can be reused for stage 3
        // C for stage 2 if `bUseGlobalMemoryStage2 == true`, can be reused for stage 3
        // D for stage 3 if `bUseGlobalMemoryStage3 == true`
        size_t const nByteA = common::roundUp(sizeof(T) * batchSize * paddedBeamWidth * paddedBeamWidth * 4, 4);
        size_t const nByteB
            = common::roundUp(sizeof(T) * batchSize * paddedBeamWidth * kMaxVPartStage1 * paddedBeamWidth * 4, 4);
        size_t const nByteC = (bUseGlobalMemoryStage2) ? nByteDynamicSharedMemoryStage2 : 0;
        size_t const nByteD = (bUseGlobalMemoryStage3) ? nByteDynamicSharedMemoryStage3 : 0;
        this->mWorkspaceSize = nByteA + std::max(nByteB + nByteC, nByteD);
    }
    else
    {
        // V2 workflow for large beam width or VBWS
        this->mV2 = true;
        switch (paddedBeamWidth)
        {
        case 1: GET_INFO_STAGE3(1, true);
        case 2: GET_INFO_STAGE3(2, true);
        case 4: GET_INFO_STAGE3(4, true);
        case 8: GET_INFO_STAGE3(8, true);
        case 16: GET_INFO_STAGE3(16, true);
        case 32: GET_INFO_STAGE3(32, true);
        case 64: GET_INFO_STAGE3(64, true);
        case 128: GET_INFO_STAGE3(128, true);
        case 256: GET_INFO_STAGE3(256, true);
        case 512: GET_INFO_STAGE3(512, true);
        case 1024: GET_INFO_STAGE3(1024, true);
        }
        // Compute shared memory size for stage 3
        this->mByteSharedMemoryStage3 = attr.sharedSizeBytes; // Only static shared memory

        // Compute workspace size, see `beamSearchKernelsTemplate.h` for detailed information
        // |<----------------------------------------- Workspace ------------------------------------------>|
        // |<- Stage2Ids ->|<- Stage2LogProbs ->|<- Stage1Ids ->|<- Stage1LogProbs ->|<---- Stage1TopK ---->|
        //                                                                           |<- stage2TopK ->|
        //                                      |<------------------ Stage3 ------------------>|
        SizeType32 const batchSize{mDecoderDomain.getBatchSize()};
        SizeType32 const beamWidth{mDecoderDomain.getBeamWidth()};
        SizeType32 const vocabSize{mDecoderDomain.getVocabSize()};
        SizeType32 const paddedBeamWidth{padToNextPowerOfTwo(beamWidth)};
        size_t const nByteStage1LogProbs = roundUp(sizeof(T) * batchSize * paddedBeamWidth * paddedBeamWidth * 2, 4);
        size_t const nByteStage1Ids = roundUp(sizeof(int) * batchSize * paddedBeamWidth * paddedBeamWidth * 2, 4);
        size_t const nByteStage2LogProbs = roundUp(sizeof(T) * batchSize * paddedBeamWidth * 2, 4);
        size_t const nByteStage2Ids = roundUp(sizeof(int) * batchSize * paddedBeamWidth * 2, 4);
        size_t const nByteStage1TopK
            = invokeComputeTopkLastDimWorkspaceSize<T>(batchSize * beamWidth, vocabSize, paddedBeamWidth * 2, true);
        size_t const nByteStage2TopK = invokeComputeTopkLastDimWorkspaceSize<T>(
            batchSize, paddedBeamWidth * paddedBeamWidth * 2, beamWidth * 2, true);
        size_t const nByteStage3 = sizeof(T) * beamWidth * beamWidth * 2;
        this->mWorkspaceSize = nByteStage2LogProbs + nByteStage2Ids
            + std::max(nByteStage1LogProbs + nByteStage1Ids + std::max(nByteStage1TopK, nByteStage2TopK), nByteStage3);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t BeamSearchLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceSize;
}
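
// Validate the per-request runtime parameters (diversity rate, length penalty, early stopping and
// beam width array) and copy them into the host/device buffers, indexed by batch slot.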
template <typename T>
void BeamSearchLayer<T>::setup(SizeType32 const batchSize, SizeType32 const beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    SizeType32 const maxBeamWidth{mDecoderDomain.getBeamWidth()};
    TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth,
        "Beam width is larger than the one the layer was constructed for (%d > %d).", int(beamWidth),
        int(maxBeamWidth));

    auto setupParams = std::dynamic_pointer_cast<BeamSearchSetupParams>(baseSetupParams);

    auto constexpr fltMax = std::numeric_limits<float>::max();
    auto constexpr fltMin = std::numeric_limits<float>::lowest();
    auto constexpr fltEpsilon = std::numeric_limits<float>::epsilon();
    auto constexpr int32Max = std::numeric_limits<int32_t>::max();

    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};
    fillBuffers(setupParams->beamSearchDiversityRate, DefaultDecodingParams::getBeamSearchDiversity(),
        mBeamSearchDiversityRateHost, mBeamSearchDiversityRateDevice, batchSlots, std::make_pair(-fltEpsilon, fltMax),
        "diversity rate");
    fillBuffers(setupParams->lengthPenalty, DefaultDecodingParams::getLengthPenalty(), mLengthPenaltyHost,
        mLengthPenaltyDevice, batchSlots, std::make_pair(fltMin, fltMax), "length penalty");
    fillBuffers(setupParams->earlyStopping, DefaultDecodingParams::getEarlyStopping(), mEarlyStoppingHost,
        mEarlyStoppingDevice, batchSlots, std::make_pair(-fltEpsilon, int32Max), "early stopping");
    fillBuffers(setupParams->beamWidthArray, DefaultDecodingParams::getBeamWidthArray(), mBeamWidthArrayHost,
        mBeamWidthArrayDevice, batchSlots, std::make_pair(-fltEpsilon, kMaxBeamWidth), "beam width array");

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BeamSearchLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto ip = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
    auto op = std::dynamic_pointer_cast<BeamSearchOutputs>(baseOutputs);
    auto const localDecoderDomain = getLocalDecoderDomain(ip, mDecoderDomain);

    TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() > 1,
        "beamWidth must be > 1 in Beam Search mode (got %d).", localDecoderDomain.getBeamWidth());
    TLLM_CHECK_WITH_INFO(ip->srcCacheIndirection.has_value(), "srcCacheIndirection is mandatory in beam search.");
    TLLM_CHECK_WITH_INFO(op->parentIds.has_value(), "parentIds tensor is mandatory in beam search.");
    TLLM_CHECK_WITH_INFO(op->finished.has_value(), "finished tensor is mandatory in beam search.");
    TLLM_CHECK_WITH_INFO(op->cumLogProbs.has_value(), "cumLogProbs tensor is mandatory in beam search.");
    TLLM_CHECK_WITH_INFO(op->beamHypotheses, "Output BeamHypotheses is not set.");
    TLLM_CHECK_WITH_INFO(
        bufferCastOrNull<SizeType32>(*op->sequenceLength) != nullptr || mLengthPenaltyDevice == nullptr,
        "Current sequence lengths must be set for length penalty computation.");
    TLLM_CHECK_WITH_INFO(ip->ite == 0, "Pipeline Parallelism is not supported yet!");

    BeamHypotheses bh;
    // bh's members not used in this function: outputIds, logProbs, outputIdsUnfinish, parentIdsUnfinish
    bh.bVBWS = this->mVBWS;
    bh.nMaxBatchSize = static_cast<SizeType32>(op->outputIdsPtr->getDimension<0>());
    bh.nBatchSize = ip->localBatchSize;
    bh.nBeamWidth = op->outputIds->getDimension<1>();
    bh.nMaxSeqLen = op->outputIds->getDimension<2>();
    bh.nVocabSize = mDecoderDomain.getVocabSizePadded();
    bh.nVPart = this->mVPart;
    bh.nByteMaxSharedMemoryPerBlock = this->mByteMaxSharedMemoryPerBlock;
    bh.nByteSharedMemoryStage1 = this->mByteSharedMemoryStage1;
    bh.nByteSharedMemoryStage3 = this->mByteSharedMemoryStage3;
    bh.diversityRates = bufferCast<float>(*mBeamSearchDiversityRateDevice);
    bh.lengthPenalties = bufferCast<float>(*mLengthPenaltyDevice);
    bh.earlyStoppings = bufferCast<SizeType32>(*mEarlyStoppingDevice);
    bh.beamWidthArraysHost = bufferCast<SizeType32>(*mBeamWidthArrayHost);
    bh.beamWidthArraysDevice = bufferCast<SizeType32>(*mBeamWidthArrayDevice);
    bh.nBeamWidthInHost = bufferCast<SizeType32>(*mBeamWidthIn);
    bh.nBeamWidthOutHost = bufferCast<SizeType32>(*mBeamWidthOut);

    if (this->mVBWS)
    {
        int const* batchSlotsHost = bufferCast<SizeType32>(*ip->batchSlots);
        for (int i = 0; i < ip->localBatchSize; ++i)
        {
            int const slot = batchSlotsHost[i];
            int const step = ip->beamSearchSteps.value()[slot];
            // Clamp `step` to [0, kMaxBeamWidthArrayLength - 1], and set `indexInput = 0` when step = 0 or 1
            int const indexInput = std::min(std::max((int) step - 1, 0), (int) kMaxBeamWidthArrayLength - 1);
            int const indexOutput = std::min((int) step, (int) kMaxBeamWidthArrayLength - 1);
            bh.nBeamWidthInHost[i] = bh.beamWidthArraysHost[slot * kMaxBeamWidthArrayLength + indexInput];
            bh.nBeamWidthOutHost[i] = bh.beamWidthArraysHost[slot * kMaxBeamWidthArrayLength + indexOutput];
        }
    }
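
    // Wire the remaining input/output tensors and beam hypotheses buffers into `bh` before launching the kernels.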
    bh.inputLengths = bufferCast<SizeType32>(*ip->inputLengths.value());
    bh.endIds = bufferCast<SizeType32>(*ip->endIds);
    bh.batchSlots = workspace->getDeviceBatchSlotsPtr(); // Device copy of `ip->batchSlots`
    bh.logProbsTiled = bufferCastOrNull<float>(op->outputLogProbsTiled);
    bh.sequenceLengths = bufferCast<SizeType32>(*op->sequenceLength.value());
    bh.cumLogProbs = bufferCast<float>(*op->cumLogProbs.value());
    bh.outputIdsCBA = op->beamHypotheses->outputIdsCBA;
    bh.logProbsCBA = op->beamHypotheses->logProbsCBA;
    bh.sequenceLengthsCBA = op->beamHypotheses->sequenceLengthsCBA;
    bh.cumLogProbsCBA = op->beamHypotheses->cumLogProbsCBA;
    bh.normedScoresCBA = op->beamHypotheses->normedScoresCBA;
    bh.numBeamsCBA = op->beamHypotheses->numBeamsCBA;
    bh.minNormedScoresCBA = op->beamHypotheses->minNormedScoresCBA;
    bh.batchDones = op->beamHypotheses->batchDones;
    bh.finished = reinterpret_cast<FinishedState*>(bufferCast<FinishedState::UnderlyingType>(*op->finished.value()));
    bh.outputIdsPtr = bufferCast<int*>(*op->outputIdsPtr);
    bh.parentIdsPtr = bufferCast<int*>(*op->parentIdsPtr);

    T const* logProbs = bufferCast<T>(*workspace->getDeviceRuntimeLogits());
    T const* bias = static_cast<T const*>(nullptr);
    TLLM_CHECK_WITH_INFO(getWorkspaceSize() >= 2 * bh.nBatchSize * bh.nBeamWidth * bh.nBeamWidth * 2,
        "Workspace size (%lu) is not enough for topk softmax required (%lu).", (uint64_t) getWorkspaceSize(),
        (uint64_t) (2 * bh.nBatchSize * bh.nBeamWidth * bh.nBeamWidth * 2));

    if (this->mV2 || this->mVBWS)
    {
        invokeTopkBeamSearch(logProbs, bias, workspace->getRawWorkspaceDevicePtr(), bh, getStream());
    }
    else
    {
        invokeTopkBeamSearch(logProbs, bias, workspace->getRawWorkspaceDevicePtr(), bh, getStream());
    }

    int* tgtCI = bufferCast<SizeType32>(*op->tgtCacheIndirection);
    int* srcCI = bufferCast<SizeType32>(*ip->srcCacheIndirection.value());
    invokeUpdateCacheIndirection(tgtCI, srcCI, bh, ip->maxAttentionWindow, ip->sinkTokenLength, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class BeamSearchLayer<float>;
template class BeamSearchLayer<half>;

} // namespace tensorrt_llm::layers