/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptDecoder.h"

#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include "tensorrt_llm/runtime/decodingLayerWorkspace.h"

// NOTE(review): the two standard include targets were stripped by text
// extraction (the source showed two bare "#include" tokens). <NvInferRuntime.h>
// and <memory> match this file's usage (TRTDataType, nvinfer1 types,
// make_shared/make_unique) -- confirm against the upstream tree.
#include <NvInferRuntime.h>

#include <memory>

namespace tle = tensorrt_llm::executor;
namespace tl = tensorrt_llm::layers;
namespace tksd = tensorrt_llm::kernels::speculative_decoding;

using namespace tensorrt_llm::runtime;

using BufferConstPtr = IBuffer::SharedConstPtr;
using BufferPtr = IBuffer::SharedPtr;
using TensorConstPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;

// NOTE(review): every template-argument list in this file was stripped by the
// extraction that produced the reviewed text ("std::make_shared(...)",
// "GptDecoder::", "std::vector>", ...). The angle-bracket contents below are
// reconstructed from member usage, the tl/tle namespace aliases, and the
// TRTDataType<T>::value pattern; verify against the original source.

//! Builds a GptDecoder: creates the buffer manager on the given stream, the
//! dynamic decode layer for decoding mode \p mode, and a scratch workspace
//! sized by the layer. \p maxSequenceLength is part of the public signature
//! but is not stored here.
template <typename T>
GptDecoder<T>::GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth,
    size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
    std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule)
    : mManager{std::make_shared<BufferManager>(stream)}
    , mMaxBatchSize(maxBatchSize)
    , mDecodingMode{mode}
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(
        maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, speculativeDecodingModule);
    mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(mode, decodingDomain, mManager);
    mDecodingLayerWorkspace = std::make_unique<DecodingLayerWorkspace>(
        mManager, decodingDomain, TRTDataType<T>::value, mDynamicDecodeLayer->getWorkspaceSize());
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Configures the decoder for a new batch: validates the sampling config and
//! translates it into the per-mode setup params consumed by the dynamic decode
//! layer. \p output and \p requestsOpt are only required for the speculative
//! decoding modes (ExplicitDraftTokens, Lookahead) that read buffers/requests.
template <typename T>
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
    std::optional<DecodingOutput> const& output,
    std::optional<std::vector<decoder_batch::Request> const> const& requestsOpt)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    mSamplingConfig = samplingConfig;
    auto setupParams = std::make_shared<tl::DynamicDecodeSetupParams>();

    TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Sampling config is invalid");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots are mandatory to set up the decoder.");

    // Penalties are common to all decoding modes.
    auto penaltyParams = std::make_shared<tl::PenaltySetupParams>();
    penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
    penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
    penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
    penaltyParams->temperature = mSamplingConfig.temperature;
    penaltyParams->minLength = mSamplingConfig.minLength;
    setupParams->penaltyParams = std::move(penaltyParams);

    auto banWordsParams = std::make_shared<tl::BanWordsSetupParams>();
    banWordsParams->noRepeatNgramSize = mSamplingConfig.noRepeatNgramSize;
    setupParams->banWordsParams = std::move(banWordsParams);

    if (mDecodingMode.isTopKorTopP())
    {
        auto samplingParams = std::make_shared<tl::SamplingSetupParams>();
        samplingParams->normalizeLogProbs = mSamplingConfig.normalizeLogProbs;
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            samplingParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }

        samplingParams->runtimeTopP = mSamplingConfig.topP;
        samplingParams->topPDecay = mSamplingConfig.topPDecay;
        samplingParams->topPMin = mSamplingConfig.topPMin;
        samplingParams->topPResetIds = mSamplingConfig.topPResetIds;
        samplingParams->outputLogProbs = mSamplingConfig.outputLogProbs;
        samplingParams->cumLogProbs = mSamplingConfig.cumLogProbs;

        setupParams->decodingParams = std::move(samplingParams);
    }
    else if (mDecodingMode.isBeamSearch())
    {
        auto beamSearchParams = std::make_shared<tl::BeamSearchSetupParams>();
        beamSearchParams->beamSearchDiversityRate = mSamplingConfig.beamSearchDiversityRate;
        beamSearchParams->lengthPenalty = mSamplingConfig.lengthPenalty;
        beamSearchParams->earlyStopping = mSamplingConfig.earlyStopping;

        setupParams->decodingParams = std::move(beamSearchParams);
    }
    else if (mDecodingMode.isMedusa())
    {
        auto medusaParams = std::make_shared<tl::MedusaSetupParams>();
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            medusaParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }
        medusaParams->runtimeHeadsTopK = mSamplingConfig.topKMedusaHeads;

        setupParams->decodingParams = std::move(medusaParams);
    }
    else if (mDecodingMode.isExplicitDraftTokens())
    {
        TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for ExplicitDraftTokens");
        auto explicitDraftTokensParams = std::make_shared<tl::ExplicitDraftTokensSetupParams>();
        explicitDraftTokensParams->temperature = mSamplingConfig.temperature;
        explicitDraftTokensParams->randomDataSample = output->explicitDraftTokensBuffers->randomDataSample;
        explicitDraftTokensParams->temperatures = output->explicitDraftTokensBuffers->temperatures;
        TLLM_CHECK(requestsOpt);
        // Ignore the dtype from all other requests assuming that it is the same for all.
        explicitDraftTokensParams->dtype = requestsOpt.value()[0].dtype;

        setupParams->decodingParams = explicitDraftTokensParams;
    }
    else if (mDecodingMode.isLookahead())
    {
        TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for Lookahead decoding");
        TLLM_LOG_DEBUG("gptDecoder setup lookahead, batchSize=%d", batchSize);
        auto lookaheadParams = std::make_shared<tl::LookaheadSetupParams>();
        TLLM_CHECK(requestsOpt);
        auto& requests = requestsOpt.value();
        lookaheadParams->prompt.resize(0);
        lookaheadParams->prompt.reserve(batchSize);
        lookaheadParams->algoConfigs.resize(0);
        lookaheadParams->algoConfigs.reserve(batchSize);
        for (size_t bi = 0; bi < batchSize; bi++)
        {
            lookaheadParams->prompt.emplace_back(ITensor::slice(requests[bi].ids, 0, requests[bi].inputLen));
            TLLM_CHECK(requests[bi].lookaheadRuntimeConfig);
            lookaheadParams->algoConfigs.emplace_back(requests[bi].lookaheadRuntimeConfig.value());
        }
        lookaheadParams->generationLengths = output->lookaheadOutputs->generationLengths;
        lookaheadParams->positionOffsets = output->lookaheadOutputs->positionOffsets;
        lookaheadParams->attentionPackedMasks = output->lookaheadOutputs->packedMasks;

        setupParams->decodingParams = std::move(lookaheadParams);
    }
    else if (mDecodingMode.isExternalDraftTokens())
    {
        auto externalDraftTokensParams = std::make_shared<tl::ExternalDraftTokensSetupParams>();
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            externalDraftTokensParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }

        externalDraftTokensParams->runtimeTopP = mSamplingConfig.topP;

        setupParams->decodingParams = std::move(externalDraftTokensParams);
    }
    setupParams->decodingParams->randomSeed = mSamplingConfig.randomSeed;

    mDecodingLayerWorkspace->setDeviceBatchSlots(batchSlots);
    mDynamicDecodeLayer->setup(batchSize, mSamplingConfig.beamWidth, batchSlots, setupParams, mDecodingLayerWorkspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace
{

//! Copies the bad-words tensors (if present) into layer inputs.
std::shared_ptr<tl::BanWordsDecodingInputs> prepareBanWordsInputs(DecodingInput const& input)
{
    auto banWordsParams = std::make_shared<tl::BanWordsDecodingInputs>(input.batchSize);
    if (input.badWordsPtrs)
    {
        // BUGFIX: the original asserted input.badWordsPtrs again (a tautology
        // inside this branch); the message and the parallel stop-words path
        // show the intent is to validate the lengths tensor.
        TLLM_CHECK_WITH_INFO(input.badWordsLens, "Bad word lengths must be provided when badWordsPtrs is given");
        banWordsParams->badWordsPtr = input.badWordsPtrs;
        banWordsParams->badWordsLengths = input.badWordsLens;
        banWordsParams->maxBadWordsLen = input.maxBadWordsLen;
    }

    return banWordsParams;
}

//! Copies the stop-words tensors and the optional sequence-length limit into
//! layer inputs.
std::shared_ptr<tl::StopCriteriaDecodingInputs> prepareStopCriteriaInputs(DecodingInput const& input)
{
    auto stopCriteriaParams = std::make_shared<tl::StopCriteriaDecodingInputs>(input.batchSize);
    if (input.stopWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.stopWordsLens, "Stop word lengths must be provided when stopWordsPtrs is given");
        stopCriteriaParams->stopWordsPtr = input.stopWordsPtrs;
        stopCriteriaParams->stopWordsLengths = input.stopWordsLens;
        stopCriteriaParams->maxStopWordsLen = input.maxStopWordsLen;
    }

    if (input.sequenceLimitLength)
    {
        stopCriteriaParams->sequenceLimitLength = input.sequenceLimitLength;
    }

    return stopCriteriaParams;
}

//! Fills Medusa-specific layer inputs. Medusa head logits are re-indexed from
//! dense batch order into slot order using inputs.batchSlots.
void prepareMedusaInputs(
    DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::BaseDecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::MedusaDecodingInputs>(baseInputs);

    auto const& medusaInputs = inputs.medusaInputs.value();

    inputParams->curTokensPerStep = medusaInputs.medusaCurTokensPerStep;
    inputParams->targetTokensPerStep = medusaInputs.medusaTargetTokensPerStep;
    inputParams->paths = medusaInputs.medusaPaths;
    inputParams->treeIds = medusaInputs.medusaTreeIds;
    auto const batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);
    if (medusaInputs.medusaLogits.size())
    {
        std::vector<std::vector<TensorConstPtr>> medusaLogits;
        auto const batchSize = medusaInputs.medusaLogits.size();
        medusaLogits.resize(maxBatchSize);
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlots[bi];
            auto const& logitsHeads = medusaInputs.medusaLogits.at(slot);
            auto const medusaHeads = logitsHeads.size();
            medusaLogits[slot].resize(medusaHeads);
            for (size_t hi = 0; hi < medusaHeads; ++hi)
            {
                if (logitsHeads[hi])
                {
                    medusaLogits[slot][hi] = logitsHeads[hi];
                }
            }
        }
        inputParams->medusaLogits = medusaLogits;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Fills the external-draft-tokens (draft-target acceptance) layer inputs.
void prepareExternalDraftTokensInputs(
    DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::BaseDecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::ExternalDraftTokensInputs>(baseInputs);

    auto const& externalDraftTokensInputs = inputs.externalDraftTokensInputs.value();

    inputParams->draftLogits = externalDraftTokensInputs.draftLogits;
    inputParams->draftProbs = externalDraftTokensInputs.draftProbs;
    inputParams->targetProbs = externalDraftTokensInputs.targetProbs;
    inputParams->numDraftTokens = externalDraftTokensInputs.numDraftTokens;
    inputParams->draftTokenIds = externalDraftTokensInputs.draftTokenIds;
    inputParams->constantThreshold = externalDraftTokensInputs.constantThreshold;
    inputParams->useRandomAcceptanceThreshold = externalDraftTokensInputs.useRandomAcceptanceThreshold;
    inputParams->step = externalDraftTokensInputs.step;
    inputParams->useDraftLogits = externalDraftTokensInputs.useDraftLogits;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Fills the explicit-draft-tokens layer inputs from the runtime buffers.
void prepareExplicitDraftTokensInput(DecodingInput const& inputs, std::shared_ptr<tl::BaseDecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::ExplicitDraftTokensInputs>(baseInputs);

    auto& explicitDraftTokensInputs = inputs.explicitDraftTokensInputs;
    TLLM_CHECK_WITH_INFO(explicitDraftTokensInputs.has_value(), "ExplicitDraftTokensInputs are not set");

    inputParams->nextDraftTokens = explicitDraftTokensInputs->nextDraftTokens;
    inputParams->nextFlatTokens = explicitDraftTokensInputs->nextFlatTokens;
    inputParams->nextDraftIndices = explicitDraftTokensInputs->nextDraftIndices;
    inputParams->nextDraftProbs = explicitDraftTokensInputs->nextDraftProbs;
    inputParams->lastDraftTokens = explicitDraftTokensInputs->lastDraftTokens;
    inputParams->lastDraftIndices = explicitDraftTokensInputs->lastDraftIndices;
    inputParams->masks = explicitDraftTokensInputs->masks;
    inputParams->packedPosIds = explicitDraftTokensInputs->packedPositionIds;
    inputParams->bestPathLengths = explicitDraftTokensInputs->bestPathLengths;
    inputParams->bestPathIndices = explicitDraftTokensInputs->bestPathIndices;
    inputParams->generationLengths = explicitDraftTokensInputs->nextGenerationLengths;
    inputParams->positionIdsBase = explicitDraftTokensInputs->lastPositionIdsBase;
    inputParams->lastGenerationLengths = explicitDraftTokensInputs->lastGenerationLengths;
    inputParams->maxGenLengthDevice = explicitDraftTokensInputs->maxGenLengthDevice;
    inputParams->seqSlots = explicitDraftTokensInputs->seqSlots;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Fills the lookahead-decoding layer inputs (tokens per step).
void prepareLookaheadInputs(
    DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::BaseDecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto inputParams = std::dynamic_pointer_cast<tl::LookaheadDecodingInputs>(baseInputs);
    auto const& lookaheadInputs = inputs.lookaheadInputs.value();
    inputParams->curTokensPerStep = lookaheadInputs.tokensPerStep;
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Converts the runtime DecodingInput into the mode-specific layer input
//! structure expected by DynamicDecodeLayer<T>.
template <typename T>
std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
    DecodingInput const& input, size_t maxBatchSize, tle::DecodingMode const& decodingMode)
{
    auto constexpr ite = 0;

    TLLM_CHECK_WITH_INFO(input.batchSlots != nullptr, "Batch slots are mandatory to call the decoder.");
    // Declared as DecodingInputs (not Base) so the common members below
    // (logits, cache indirection, ban/stop params, ...) are accessible.
    std::shared_ptr<tl::DecodingInputs> forwardParams;
    if (decodingMode.isTopKorTopP())
    {
        forwardParams
            = std::make_shared<tl::SamplingInputs>(input.endIds, input.batchSlots, input.step, ite, input.batchSize);
    }
    else if (decodingMode.isBeamSearch())
    {
        forwardParams = std::make_shared<tl::DecodingInputs>(input.endIds, input.batchSlots, input.step, ite,
            input.batchSize, input.maxAttentionWindow, input.sinkTokenLength);
    }
    else if (decodingMode.isMedusa())
    {
        forwardParams = std::make_shared<tl::MedusaDecodingInputs>(input.endIds, input.batchSlots, input.batchSize);
    }
    else if (decodingMode.isLookahead())
    {
        forwardParams = std::make_shared<tl::LookaheadDecodingInputs>(input.endIds, input.batchSlots);
    }
    else if (decodingMode.isExplicitDraftTokens())
    {
        forwardParams
            = std::make_shared<tl::ExplicitDraftTokensInputs>(input.endIds, input.batchSlots, input.batchSize);
    }
    else if (decodingMode.isExternalDraftTokens())
    {
        forwardParams = std::make_shared<tl::ExternalDraftTokensInputs>(
            input.endIds, input.batchSlots, input.step, ite, input.batchSize);
    }

    // No logits for explicit draft tokens
    if (!decodingMode.isExplicitDraftTokens())
    {
        if (input.logitsVec)
        {
            std::vector<TensorConstPtr> logitsVec;
            for (auto const& logits : input.logitsVec.value())
            {
                TLLM_CHECK(logits->getDataType() == TRTDataType<T>::value);
                logitsVec.push_back(logits);
            }
            forwardParams->logitsVec = logitsVec;
        }
        else if (input.logits)
        {
            TLLM_CHECK(input.logits->getDataType() == TRTDataType<T>::value);
            forwardParams->logits = input.logits;
        }
    }

    if (input.cacheIndirection)
    {
        forwardParams->srcCacheIndirection = input.cacheIndirection;
    }

    if (input.embeddingBias)
    {
        forwardParams->embeddingBias = input.embeddingBias;
    }

    if (input.lengths)
    {
        forwardParams->inputLengths = input.lengths;
    }

    forwardParams->banWordsInputs = prepareBanWordsInputs(input);

    forwardParams->stopCriteriaInputs = prepareStopCriteriaInputs(input);

    if (input.finishReasons)
    {
        forwardParams->finished = input.finishReasons;
    }

    // Medusa
    if (decodingMode.isMedusa())
    {
        prepareMedusaInputs(input, maxBatchSize, forwardParams);
    }

    // Explicit draft tokens
    if (decodingMode.isExplicitDraftTokens())
    {
        prepareExplicitDraftTokensInput(input, forwardParams);
    }

    if (input.lookaheadInputs)
    {
        prepareLookaheadInputs(input, maxBatchSize, forwardParams);
        forwardParams->localBatchSize = input.batchSize;
    }

    if (decodingMode.isExternalDraftTokens())
    {
        prepareExternalDraftTokensInputs(input, maxBatchSize, forwardParams);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

    return forwardParams;
}

//! Wires the beam-search hypotheses (CBA) buffers and the target cache
//! indirection into the beam-search layer outputs.
void prepareBeamSearchOutputs(DecodingOutput& output, std::shared_ptr<tl::BaseDecodingOutputs>& baseOutputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto outputParams = std::dynamic_pointer_cast<tl::BeamSearchOutputs>(baseOutputs);
    outputParams->beamHypotheses = std::make_unique<tensorrt_llm::kernels::BeamHypotheses>();
    if (output.beamHypotheses.outputIdsCBA)
    {
        outputParams->beamHypotheses->outputIdsCBA = bufferCast<int>(*output.beamHypotheses.outputIdsCBA);
    }
    if (output.beamHypotheses.logProbsCBA)
    {
        outputParams->beamHypotheses->logProbsCBA = bufferCast<float>(*output.beamHypotheses.logProbsCBA);
    }
    if (output.beamHypotheses.sequenceLengthsCBA)
    {
        outputParams->beamHypotheses->sequenceLengthsCBA = bufferCast<int>(*output.beamHypotheses.sequenceLengthsCBA);
    }
    if (output.beamHypotheses.cumLogProbsCBA)
    {
        outputParams->beamHypotheses->cumLogProbsCBA = bufferCast<float>(*output.beamHypotheses.cumLogProbsCBA);
    }
    if (output.beamHypotheses.normedScoresCBA)
    {
        outputParams->beamHypotheses->normedScoresCBA = bufferCast<float>(*output.beamHypotheses.normedScoresCBA);
    }
    if (output.beamHypotheses.numBeamsCBA)
    {
        outputParams->beamHypotheses->numBeamsCBA = bufferCast<int>(*output.beamHypotheses.numBeamsCBA);
    }
    if (output.beamHypotheses.minNormedScoresCBA)
    {
        outputParams->beamHypotheses->minNormedScoresCBA = bufferCast<float>(*output.beamHypotheses.minNormedScoresCBA);
    }
    if (output.beamHypotheses.batchDones)
    {
        outputParams->beamHypotheses->batchDones = bufferCast<bool>(*output.beamHypotheses.batchDones);
    }

    if (output.cacheIndirection)
    {
        outputParams->tgtCacheIndirection = output.cacheIndirection;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Wires the common speculative-decoding output buffers, plus the
//! ExplicitDraftTokens- or Lookahead-specific ones depending on the mode.
void prepareSpeculativeDecodingOutputs(
    DecodingOutput& output, std::shared_ptr<tl::BaseDecodingOutputs>& baseOutputs, tle::DecodingMode const& decodingMode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto outputParams = std::dynamic_pointer_cast<tl::SpeculativeDecodingOutputs>(baseOutputs);

    auto const& speculativeDecodingOutputs = output.speculativeDecodingOutputs;
    TLLM_CHECK_WITH_INFO(speculativeDecodingOutputs.has_value(), "speculativeDecodingOutputs is not set");
    outputParams->nextDraftTokens = speculativeDecodingOutputs->nextDraftTokens;
    outputParams->numNewTokens = speculativeDecodingOutputs->acceptedTokensLen;
    outputParams->numNewTokensCumSum = speculativeDecodingOutputs->acceptedLengthsCumSum;
    outputParams->pathsOffsets = speculativeDecodingOutputs->pathsOffsets;
    if (speculativeDecodingOutputs->nextDraftTokensLen)
    {
        outputParams->nextDraftLengths = speculativeDecodingOutputs->nextDraftTokensLen;
    }
    if (speculativeDecodingOutputs->prevDraftTokensLen)
    {
        outputParams->prevDraftLengths = speculativeDecodingOutputs->prevDraftTokensLen;
    }

    if (decodingMode.isExplicitDraftTokens())
    {
        // Shadows the outer outputParams on purpose: re-cast to the derived
        // ExplicitDraftTokensOutputs to reach its extra members.
        auto outputParams = std::dynamic_pointer_cast<tl::ExplicitDraftTokensOutputs>(baseOutputs);
        auto const& explicitDraftTokensBuffers = output.explicitDraftTokensBuffers;
        TLLM_CHECK_WITH_INFO(explicitDraftTokensBuffers.has_value(), "explicitDraftTokensBuffers is not set");

        outputParams->packedMasks = explicitDraftTokensBuffers->packedMasks;
        outputParams->nextDraftPosIds = explicitDraftTokensBuffers->positionIds;
        outputParams->unpackedNextDraftTokens = explicitDraftTokensBuffers->draftTokens;
        outputParams->unpackedNextDraftIndices = explicitDraftTokensBuffers->draftIndices;
        outputParams->nextDraftProbs = explicitDraftTokensBuffers->draftProbs;
        outputParams->positionIdsBase = explicitDraftTokensBuffers->positionIdsBase;
        outputParams->randomDataSample = explicitDraftTokensBuffers->randomDataSample;
        outputParams->randomDataValidation = explicitDraftTokensBuffers->randomDataValidation;
        outputParams->temperatures = explicitDraftTokensBuffers->temperatures;
        outputParams->generationLengths = explicitDraftTokensBuffers->generationLengths;
        outputParams->generationLengthsHost = explicitDraftTokensBuffers->generationLengthsHost;
        outputParams->maxGenLengthHost = explicitDraftTokensBuffers->maxGenLengthHost;
    }

    if (decodingMode.isLookahead())
    {
        TLLM_CHECK(output.lookaheadOutputs);
        // Shadows the outer outputParams on purpose (see above).
        auto outputParams = std::dynamic_pointer_cast<tl::LookaheadDecodingOutputs>(baseOutputs);
        outputParams->packedMasks = output.lookaheadOutputs->packedMasks;
        outputParams->positionIds = output.lookaheadOutputs->positionIds;
        outputParams->positionOffsets = output.lookaheadOutputs->positionOffsets;
        outputParams->generationLengths = output.lookaheadOutputs->generationLengths;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Converts the runtime DecodingOutput into the mode-specific layer output
//! structure expected by DynamicDecodeLayer<T>.
std::shared_ptr<tl::BaseDecodingOutputs> prepareOutputs(DecodingOutput& output, tle::DecodingMode const& decodingMode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    std::shared_ptr<tl::BaseDecodingOutputs> outputParams;
    if (decodingMode.isBeamSearch())
    {
        outputParams = std::make_shared<tl::BeamSearchOutputs>(output.ids);
    }
    else if (decodingMode.isMedusa())
    {
        outputParams = std::make_shared<tl::SpeculativeDecodingOutputs>(output.ids);
    }
    else if (decodingMode.isLookahead())
    {
        outputParams = std::make_shared<tl::LookaheadDecodingOutputs>(output.ids);
    }
    else if (decodingMode.isExplicitDraftTokens())
    {
        outputParams = std::make_shared<tl::ExplicitDraftTokensOutputs>(output.ids);
    }
    else
    {
        outputParams = std::make_shared<tl::BaseDecodingOutputs>(output.ids);
    }

    // Common outputs
    outputParams->newTokens = output.newTokens;

    if (output.cumLogProbs)
    {
        outputParams->cumLogProbs = output.cumLogProbs;
    }

    if (output.parentIds)
    {
        outputParams->parentIds = output.parentIds;
    }

    if (output.finishReasons)
    {
        outputParams->finished = output.finishReasons;
    }

    if (output.finishedSum)
    {
        outputParams->finishedSum = output.finishedSum;
    }

    if (output.lengths)
    {
        outputParams->sequenceLength = output.lengths;
    }

    if (output.logProbs)
    {
        outputParams->outputLogProbs = output.logProbs;
        outputParams->outputLogProbsTiled = output.logProbsTiled;
    }

    // Beam search outputs
    if (decodingMode.isBeamSearch())
    {
        prepareBeamSearchOutputs(output, outputParams);
    }

    // Speculative decoding outputs
    if (decodingMode.isMedusa() || decodingMode.isLookahead() || decodingMode.isExplicitDraftTokens())
    {
        prepareSpeculativeDecodingOutputs(output, outputParams, decodingMode);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

    return outputParams;
}

} // namespace

//! Runs one asynchronous decoding step on the decoder's stream.
template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxBatchSize, mDecodingMode);
    auto outputParams = prepareOutputs(output, mDecodingMode);

    mDynamicDecodeLayer->forwardAsync(outputParams, forwardParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Runs one synchronous decoding step (same preparation as forwardAsync).
template <typename T>
void GptDecoder<T>::forwardSync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxBatchSize, mDecodingMode);
    auto outputParams = prepareOutputs(output, mDecodingMode);

    mDynamicDecodeLayer->forwardSync(outputParams, forwardParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime