/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptDecoder.h"

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include "tensorrt_llm/runtime/decodingLayerWorkspace.h"

#include <NvInferRuntime.h>

#include <memory>

namespace tle = tensorrt_llm::executor;
namespace tl = tensorrt_llm::layers;

using namespace tensorrt_llm::runtime;

using BufferConstPtr = IBuffer::SharedConstPtr;
using BufferPtr = IBuffer::SharedPtr;
using TensorConstPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;

template <typename T>
GptDecoder<T>::GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth,
    size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream,
    std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule)
    : mManager{std::make_shared<BufferManager>(stream)}
    , mMaxNumSequences(maxNumSequences)
    , mVocabSize(vocabSize)
    , mVocabSizePadded(vocabSizePadded)
    , mDecodingMode{mode}
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(
        maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, speculativeDecodingModule);
    mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(mode, decodingDomain, mManager);
    mDecodingLayerWorkspace = std::make_unique<DecodingLayerWorkspace>(
        mManager, decodingDomain, TRTDataType<T>::value, mDynamicDecodeLayer->getWorkspaceSize());
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void GptDecoder<T>::disableLookahead(
    std::optional<SamplingConfig> const& samplingConfig, SizeType32 batchSize, TensorConstPtr batchSlots)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mDecodingMode = executor::DecodingMode::TopKTopP();
    auto const decodingDomain
        = tensorrt_llm::layers::DecoderDomain(mMaxNumSequences, 1, mVocabSize, mVocabSizePadded, nullptr);
    auto setupParams = std::make_shared<tl::DynamicDecodeSetupParams>();
    if (batchSize == 0)
    {
        mDynamicDecodeLayer->disableLookahead(
            decodingDomain, batchSize, batchSlots, setupParams, mDecodingLayerWorkspace);
        return;
    }
    mSamplingConfig = samplingConfig.value();
    TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Sampling config is invalid");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots are mandatory to set up the decoder.");

    // Penalty parameters
    auto penaltyParams = std::make_shared<tl::PenaltySetupParams>();
    penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
    penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
    penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
    penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
    penaltyParams->temperature = mSamplingConfig.temperature;
    penaltyParams->minLength = mSamplingConfig.minLength;

    // Ban-words parameters
    auto banWordsParams = std::make_shared<tl::BanWordsSetupParams>();
    banWordsParams->noRepeatNgramSize = mSamplingConfig.noRepeatNgramSize;

    // Sampling parameters
    auto samplingParams = std::make_shared<tl::SamplingSetupParams>();
    samplingParams->normalizeLogProbs = mSamplingConfig.normalizeLogProbs;
    if (mSamplingConfig.topK)
    {
        auto const& topK = mSamplingConfig.topK.value();
        samplingParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
    }
    samplingParams->runtimeTopP = mSamplingConfig.topP;
    samplingParams->topPDecay = mSamplingConfig.topPDecay;
    samplingParams->topPMin = mSamplingConfig.topPMin;
    samplingParams->topPResetIds = mSamplingConfig.topPResetIds;
    samplingParams->outputLogProbs = mSamplingConfig.outputLogProbs;
    samplingParams->cumLogProbs = mSamplingConfig.cumLogProbs;
    samplingParams->runtimeMinP = mSamplingConfig.minP;

    // Assemble the setup parameters
    setupParams->penaltyParams = std::move(penaltyParams);
    setupParams->banWordsParams = std::move(banWordsParams);
    setupParams->decodingParams = std::move(samplingParams);

    mDecodingLayerWorkspace->setDeviceBatchSlots(batchSlots);
    mDynamicDecodeLayer->disableLookahead(decodingDomain, batchSize, batchSlots, setupParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
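// setup() configures the dynamic decode layer for the requested decoding mode: penalty
// and ban-words parameters are always populated, while the mode-specific branch fills in
// sampling, beam-search, or speculative-decoding parameters.
//
// A minimal illustrative sketch of a caller (not part of this file; the stream, the
// batch-slot tensor, and the sizes are assumed placeholders, and the defaulted optional
// arguments of setup() are assumed from the header):
//
//   auto stream = std::make_shared<CudaStream>();
//   GptDecoder<float> decoder(executor::DecodingMode::TopKTopP(), /*maxNumSequences=*/8,
//       /*maxBeamWidth=*/1, /*vocabSize=*/32000, /*vocabSizePadded=*/32000, stream,
//       /*speculativeDecodingModule=*/nullptr);
//   SamplingConfig config{/*beamWidth=*/1};
//   config.topK = std::vector<SizeType32>{40};
//   config.temperature = std::vector<float>{0.8f};
//   decoder.setup(config, /*batchSize=*/8, batchSlots);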
template <typename T>
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
    std::optional<DecodingOutput> const& output, std::optional<nvinfer1::DataType> explicitDraftTokensDType,
    std::optional<std::vector<TensorConstPtr>> const& lookaheadPrompt,
    std::optional<std::vector<executor::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mSamplingConfig = samplingConfig;
    auto setupParams = std::make_shared<tl::DynamicDecodeSetupParams>();

    TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Sampling config is invalid");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots are mandatory to set up the decoder.");

    auto penaltyParams = std::make_shared<tl::PenaltySetupParams>();
    penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
    penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
    penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
    penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
    penaltyParams->temperature = mSamplingConfig.temperature;
    penaltyParams->minLength = mSamplingConfig.minLength;
    setupParams->penaltyParams = std::move(penaltyParams);

    auto banWordsParams = std::make_shared<tl::BanWordsSetupParams>();
    banWordsParams->noRepeatNgramSize = mSamplingConfig.noRepeatNgramSize;
    setupParams->banWordsParams = std::move(banWordsParams);

    if (mDecodingMode.isTopKorTopP())
    {
        auto samplingParams = std::make_shared<tl::SamplingSetupParams>();
        samplingParams->normalizeLogProbs = mSamplingConfig.normalizeLogProbs;
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            samplingParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }

        samplingParams->runtimeTopP = mSamplingConfig.topP;
        samplingParams->topPDecay = mSamplingConfig.topPDecay;
        samplingParams->topPMin = mSamplingConfig.topPMin;
        samplingParams->topPResetIds = mSamplingConfig.topPResetIds;
        samplingParams->outputLogProbs = mSamplingConfig.outputLogProbs;
        samplingParams->cumLogProbs = mSamplingConfig.cumLogProbs;
        samplingParams->runtimeMinP = mSamplingConfig.minP;

        setupParams->decodingParams = std::move(samplingParams);
    }
    else if (mDecodingMode.isBeamSearch())
    {
        auto beamSearchParams = std::make_shared<tl::BeamSearchSetupParams>();
        beamSearchParams->beamSearchDiversityRate = mSamplingConfig.beamSearchDiversityRate;
        beamSearchParams->lengthPenalty = mSamplingConfig.lengthPenalty;
        beamSearchParams->earlyStopping = mSamplingConfig.earlyStopping;
        beamSearchParams->beamWidthArray = mSamplingConfig.beamWidthArray;
        setupParams->decodingParams = std::move(beamSearchParams);
    }
    else if (mDecodingMode.isMedusa())
    {
        auto medusaParams = std::make_shared<tl::MedusaSetupParams>();
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            medusaParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }
        medusaParams->runtimeHeadsTopK = mSamplingConfig.topKMedusaHeads;
        setupParams->decodingParams = std::move(medusaParams);
    }
    else if (mDecodingMode.isExplicitDraftTokens())
    {
        TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for ExplicitDraftTokens");
        auto explicitDraftTokensParams = std::make_shared<tl::ExplicitDraftTokensSetupParams>();
        explicitDraftTokensParams->temperature = mSamplingConfig.temperature;
        explicitDraftTokensParams->randomDataSample = output->explicitDraftTokensBuffers->randomDataSample;
        explicitDraftTokensParams->temperatures = output->explicitDraftTokensBuffers->temperatures;
        TLLM_CHECK(explicitDraftTokensDType.has_value());
        explicitDraftTokensParams->dtype = explicitDraftTokensDType.value();

        setupParams->decodingParams = explicitDraftTokensParams;
    }
    else if (mDecodingMode.isLookahead())
    {
        TLLM_LOG_DEBUG("gptDecoder setup lookahead, batchSize=%d", batchSize);
        auto lookaheadParams = std::make_shared<tl::LookaheadSetupParams>();
        TLLM_CHECK_WITH_INFO(lookaheadPrompt.has_value(), "Lookahead prompt must be provided");
        lookaheadParams->prompt = lookaheadPrompt.value();
        TLLM_CHECK_WITH_INFO(lookaheadAlgoConfigs.has_value(), "Lookahead algo configs must be provided");
        lookaheadParams->algoConfigs = lookaheadAlgoConfigs.value();
        TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for Lookahead decoding");
        lookaheadParams->generationLengths = output->lookaheadOutputs->generationLengths;
        lookaheadParams->positionOffsets = output->lookaheadOutputs->positionOffsets;
        lookaheadParams->attentionPackedMasks = output->lookaheadOutputs->packedMasks;
        setupParams->decodingParams = std::move(lookaheadParams);
    }
    else if (mDecodingMode.isExternalDraftTokens())
    {
        auto externalDraftTokensParams = std::make_shared<tl::ExternalDraftTokensSetupParams>();
        // signed to unsigned
        if (mSamplingConfig.topK)
        {
            auto const& topK = mSamplingConfig.topK.value();
            externalDraftTokensParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
        }

        externalDraftTokensParams->runtimeTopP = mSamplingConfig.topP;
        setupParams->decodingParams = std::move(externalDraftTokensParams);
    }
    else if (mDecodingMode.isEagle())
    {
        TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for Eagle");
        auto eagleParams = std::make_shared<tl::EagleSetupParams>();
        eagleParams->temperature = mSamplingConfig.originalTemperature;
        eagleParams->randomDataSample = output->eagleBuffers->randomDataSample;
        eagleParams->temperatures = output->eagleBuffers->temperatures;

        setupParams->decodingParams = eagleParams;
    }

    setupParams->decodingParams->randomSeed = mSamplingConfig.randomSeed;

    mDynamicDecodeLayer->setup(batchSize, mSamplingConfig.beamWidth, batchSlots, setupParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
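// The helpers in this anonymous namespace translate the runtime-facing DecodingInput
// and DecodingOutput structures into the layer-level parameter objects consumed by
// DynamicDecodeLayer.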
namespace
{
std::shared_ptr<tl::BanWordsDecodingInputs> prepareBanWordsInputs(DecodingInput const& input)
{
    auto banWordsParams = std::make_shared<tl::BanWordsDecodingInputs>(input.batchSize);
    if (input.badWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.badWordsLens, "Bad word lengths must be provided when badWordsPtrs is given");
        banWordsParams->badWordsPtr = input.badWordsPtrs;
        banWordsParams->badWordsLengths = input.badWordsLens;
        banWordsParams->maxBadWordsLen = input.maxBadWordsLen;
    }
    return banWordsParams;
}

std::shared_ptr<tl::StopCriteriaDecodingInputs> prepareStopCriteriaInputs(DecodingInput const& input)
{
    auto stopCriteriaParams = std::make_shared<tl::StopCriteriaDecodingInputs>(input.batchSize);
    if (input.stopWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.stopWordsLens, "Stop word lengths must be provided when stopWordsPtrs is given");
        stopCriteriaParams->stopWordsPtr = input.stopWordsPtrs;
        stopCriteriaParams->stopWordsLengths = input.stopWordsLens;
        stopCriteriaParams->maxStopWordsLen = input.maxStopWordsLen;
    }

    if (input.sequenceLimitLength)
    {
        stopCriteriaParams->sequenceLimitLength = input.sequenceLimitLength;
    }
    return stopCriteriaParams;
}
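// Medusa logits arrive indexed by local batch position, but the decoding layers address
// sequences by their persistent batch slot; the loop below scatters the per-head logits
// tensors into a slot-indexed table sized for maxNumSequences.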
void prepareMedusaInputs(
    DecodingInput const& inputs, size_t maxNumSequences, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::MedusaDecodingInputs>(baseInputs);

    auto const& medusaInputs = inputs.medusaInputs.value();

    inputParams->curTokensPerStep = medusaInputs.medusaCurTokensPerStep;
    inputParams->targetTokensPerStep = medusaInputs.medusaTargetTokensPerStep;
    inputParams->paths = medusaInputs.medusaPaths;
    inputParams->treeIds = medusaInputs.medusaTreeIds;
    auto const batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);
    if (medusaInputs.medusaLogits.size())
    {
        std::vector<std::vector<TensorPtr>> medusaLogits;
        auto const batchSize = medusaInputs.medusaLogits.size();
        medusaLogits.resize(maxNumSequences);
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlots[bi];
            auto const& logitsHeads = medusaInputs.medusaLogits.at(slot);
            auto const medusaHeads = logitsHeads.size();
            medusaLogits[slot].resize(medusaHeads);
            for (size_t hi = 0; hi < medusaHeads; ++hi)
            {
                if (logitsHeads[hi])
                {
                    medusaLogits[slot][hi] = logitsHeads[hi];
                }
            }
        }
        inputParams->medusaLogits = medusaLogits;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void prepareExternalDraftTokensInputs(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::ExternalDraftTokensInputs>(baseInputs);
    auto const& externalDraftTokensInputs = inputs.externalDraftTokensInputs.value();

    inputParams->draftLogits = externalDraftTokensInputs.draftLogits;
    inputParams->draftProbs = externalDraftTokensInputs.draftProbs;
    inputParams->targetProbs = externalDraftTokensInputs.targetProbs;
    inputParams->numDraftTokens = externalDraftTokensInputs.numDraftTokens;
    inputParams->numDraftTokensHost = externalDraftTokensInputs.numDraftTokensHost;
    inputParams->draftTokenIds = externalDraftTokensInputs.draftTokenIds;
    inputParams->constantThreshold = externalDraftTokensInputs.constantThreshold;
    inputParams->useRandomAcceptanceThreshold = externalDraftTokensInputs.useRandomAcceptanceThreshold;
    inputParams->step = externalDraftTokensInputs.step;
    inputParams->useDraftLogits = externalDraftTokensInputs.useDraftLogits;
    inputParams->useDraftLogitsHost = externalDraftTokensInputs.useDraftLogitsHost;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void prepareExplicitDraftTokensInput(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::ExplicitDraftTokensInputs>(baseInputs);

    auto& explicitDraftTokensInputs = inputs.explicitDraftTokensInputs;
    TLLM_CHECK_WITH_INFO(explicitDraftTokensInputs.has_value(), "ExplicitDraftTokensInputs are not set");

    inputParams->nextDraftTokens = explicitDraftTokensInputs->nextDraftTokens;
    inputParams->nextFlatTokens = explicitDraftTokensInputs->nextFlatTokens;
    inputParams->nextDraftIndices = explicitDraftTokensInputs->nextDraftIndices;
    inputParams->nextDraftProbs = explicitDraftTokensInputs->nextDraftProbs;
    inputParams->lastDraftTokens = explicitDraftTokensInputs->lastDraftTokens;
    inputParams->lastDraftIndices = explicitDraftTokensInputs->lastDraftIndices;
    inputParams->masks = explicitDraftTokensInputs->masks;
    inputParams->packedPosIds = explicitDraftTokensInputs->packedPositionIds;
    inputParams->bestPathLengths = explicitDraftTokensInputs->bestPathLengths;
    inputParams->bestPathIndices = explicitDraftTokensInputs->bestPathIndices;
    inputParams->generationLengths = explicitDraftTokensInputs->nextGenerationLengths;
    inputParams->positionIdsBase = explicitDraftTokensInputs->lastPositionIdsBase;
    inputParams->lastGenerationLengths = explicitDraftTokensInputs->lastGenerationLengths;
    inputParams->maxGenLengthDevice = explicitDraftTokensInputs->maxGenLengthDevice;
    inputParams->seqSlots = explicitDraftTokensInputs->seqSlots;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
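// Lookahead decoding only needs the per-step token count at forward time; the prompt
// and algorithm configuration were already handed to the layer in setup().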
void prepareLookaheadInputs(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::LookaheadDecodingInputs>(baseInputs);
    auto const& lookaheadInputs = inputs.lookaheadInputs.value();
    inputParams->curTokensPerStep = lookaheadInputs.tokensPerStep;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void prepareEagleInput(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputParams = std::dynamic_pointer_cast<tl::EagleInputs>(baseInputs);

    auto& eagleInputs = inputs.eagleInputs;
    TLLM_CHECK_WITH_INFO(eagleInputs.has_value(), "EagleInputs are not set");

    inputParams->nextDraftTokens = eagleInputs->nextDraftTokens;
    inputParams->nextDraftLens = eagleInputs->nextDraftLens;
    inputParams->nextDraftPaths = eagleInputs->nextDraftPaths;
    inputParams->lastDraftTokens = eagleInputs->lastDraftTokens;
    inputParams->lastDraftLens = eagleInputs->lastDraftLens;
    inputParams->lastDraftPaths = eagleInputs->lastDraftPaths;
    inputParams->acceptedTokens = eagleInputs->acceptedTokens;
    inputParams->acceptedLens = eagleInputs->acceptedLens;
    inputParams->acceptedPathIds = eagleInputs->acceptedPathIds;
    inputParams->chunkedContextNextTokens = eagleInputs->chunkedContextNextTokens;
    inputParams->seqSlots = eagleInputs->seqSlots;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
    DecodingInput const& input, size_t maxNumSequences, tle::DecodingMode const& decodingMode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto constexpr ite = 0;

    TLLM_CHECK_WITH_INFO(input.batchSlots != nullptr, "Batch slots are mandatory to call the decoder.");
    std::shared_ptr<tl::DecodingInputs> forwardParams;
    if (decodingMode.isTopKorTopP())
    {
        forwardParams
            = std::make_shared<tl::SamplingInputs>(input.endIds, input.batchSlots, input.step, ite, input.batchSize);
    }
    else if (decodingMode.isBeamSearch())
    {
        forwardParams = std::make_shared<tl::DecodingInputs>(input.endIds, input.batchSlots, input.step, ite,
            input.batchSize, input.maxAttentionWindow, input.sinkTokenLength);

        if (input.cacheIndirection)
        {
            forwardParams->srcCacheIndirection = input.cacheIndirection;
        }
        forwardParams->beamSearchSteps = input.generationSteps;
    }
    else if (decodingMode.isMedusa())
    {
        forwardParams = std::make_shared<tl::MedusaDecodingInputs>(input.endIds, input.batchSlots, input.batchSize);
    }
    else if (decodingMode.isLookahead())
    {
        forwardParams = std::make_shared<tl::LookaheadDecodingInputs>(input.endIds, input.batchSlots);
    }
    else if (decodingMode.isExplicitDraftTokens())
    {
        forwardParams
            = std::make_shared<tl::ExplicitDraftTokensInputs>(input.endIds, input.batchSlots, input.batchSize);
    }
    else if (decodingMode.isExternalDraftTokens())
    {
        forwardParams = std::make_shared<tl::ExternalDraftTokensInputs>(
            input.endIds, input.batchSlots, input.step, ite, input.batchSize);
    }
    else if (decodingMode.isEagle())
    {
        auto& eagleInputs = input.eagleInputs;
        TLLM_CHECK_WITH_INFO(eagleInputs.has_value(), "EagleInputs are not set");
        forwardParams = std::make_shared<tl::EagleInputs>(input.endIds, input.batchSlots, input.batchSize,
            eagleInputs->nextDraftTokens, eagleInputs->nextDraftLens, eagleInputs->nextDraftPaths,
            eagleInputs->lastDraftTokens, eagleInputs->lastDraftLens, eagleInputs->lastDraftPaths,
            eagleInputs->acceptedTokens, eagleInputs->acceptedLens, eagleInputs->acceptedPathIds,
            eagleInputs->chunkedContextNextTokens, eagleInputs->seqSlots);
    }

    // No logits for explicit draft tokens and eagle
    if (!decodingMode.isExplicitDraftTokens() && !decodingMode.isEagle())
    {
        for (auto const& logits : input.logitsVec)
        {
            TLLM_CHECK(logits->getDataType() == TRTDataType<T>::value);
        }
        forwardParams->logitsVec = input.logitsVec;
    }

    if (input.embeddingBias)
    {
        forwardParams->embeddingBias = input.embeddingBias;
    }

    if (input.lengths)
    {
        forwardParams->inputLengths = input.lengths;
    }

    forwardParams->banWordsInputs = prepareBanWordsInputs(input);
    forwardParams->stopCriteriaInputs = prepareStopCriteriaInputs(input);

    if (input.finishReasons)
    {
        forwardParams->finished = input.finishReasons;
    }

    // Speculative decoding
    if (decodingMode.isMedusa())
    {
        prepareMedusaInputs(input, maxNumSequences, forwardParams);
    }
    else if (decodingMode.isExplicitDraftTokens())
    {
        prepareExplicitDraftTokensInput(input, forwardParams);
    }
    else if (decodingMode.isLookahead() && input.lookaheadInputs)
    {
        prepareLookaheadInputs(input, forwardParams);
        forwardParams->localBatchSize = input.batchSize;
    }
    else if (decodingMode.isExternalDraftTokens())
    {
        prepareExternalDraftTokensInputs(input, forwardParams);
    }
    else if (decodingMode.isEagle())
    {
        prepareEagleInput(input, forwardParams);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

    return forwardParams;
}
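// Beam search exposes its candidate-beam bookkeeping to the kernels as raw pointers, so
// the tensors in DecodingOutput::beamHypotheses are unwrapped with bufferCast into the
// kernels' BeamHypotheses view below.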
void prepareBeamSearchOutputs(DecodingOutput& output, std::shared_ptr<tl::BaseDecodingOutputs>& baseOutputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const& bhSrc = output.beamHypotheses;
    auto bhOutputs = std::dynamic_pointer_cast<tl::BeamSearchOutputs>(baseOutputs);
    bhOutputs->beamHypotheses = std::make_unique<tensorrt_llm::kernels::BeamHypotheses>();
    auto& bhDst = bhOutputs->beamHypotheses;
    if (bhSrc.outputIdsCBA)
    {
        bhDst->outputIdsCBA = bufferCast<TokenIdType>(*bhSrc.outputIdsCBA);
    }
    if (bhSrc.logProbsCBA)
    {
        bhDst->logProbsCBA = bufferCast<float>(*bhSrc.logProbsCBA);
    }
    if (bhSrc.sequenceLengthsCBA)
    {
        bhDst->sequenceLengthsCBA = bufferCast<SizeType32>(*bhSrc.sequenceLengthsCBA);
    }
    if (bhSrc.cumLogProbsCBA)
    {
        bhDst->cumLogProbsCBA = bufferCast<float>(*bhSrc.cumLogProbsCBA);
    }
    if (bhSrc.normedScoresCBA)
    {
        bhDst->normedScoresCBA = bufferCast<float>(*bhSrc.normedScoresCBA);
    }
    if (bhSrc.numBeamsCBA)
    {
        bhDst->numBeamsCBA = bufferCast<SizeType32>(*bhSrc.numBeamsCBA);
    }
    if (bhSrc.minNormedScoresCBA)
    {
        bhDst->minNormedScoresCBA = bufferCast<float>(*bhSrc.minNormedScoresCBA);
    }
    if (bhSrc.batchDones)
    {
        bhDst->batchDones = bufferCast<bool>(*bhSrc.batchDones);
    }
    if (output.cacheIndirection)
    {
        bhOutputs->tgtCacheIndirection = output.cacheIndirection;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void prepareSpeculativeDecodingOutputs(DecodingOutput& output, std::shared_ptr<tl::BaseDecodingOutputs>& baseOutputs,
    tle::DecodingMode const& decodingMode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto outputParams = std::dynamic_pointer_cast<tl::SpeculativeDecodingOutputs>(baseOutputs);

    auto const& speculativeDecodingOutputs = output.speculativeDecodingOutputs;
    TLLM_CHECK_WITH_INFO(speculativeDecodingOutputs.has_value(), "speculativeDecodingOutputs is not set");
    outputParams->nextDraftTokens = speculativeDecodingOutputs->nextDraftTokens;
    outputParams->numNewTokens = speculativeDecodingOutputs->acceptedTokensLen;
    outputParams->numNewTokensCumSum = speculativeDecodingOutputs->acceptedLengthsCumSum;
    outputParams->pathsOffsets = speculativeDecodingOutputs->pathsOffsets;
    if (speculativeDecodingOutputs->nextDraftTokensLen)
    {
        outputParams->nextDraftLengths = speculativeDecodingOutputs->nextDraftTokensLen;
    }
    if (speculativeDecodingOutputs->prevDraftTokensLen)
    {
        outputParams->prevDraftLengths = speculativeDecodingOutputs->prevDraftTokensLen;
    }

    if (decodingMode.isExplicitDraftTokens())
    {
        auto outputParams = std::dynamic_pointer_cast<tl::ExplicitDraftTokensOutputs>(baseOutputs);
        auto const& explicitDraftTokensBuffers = output.explicitDraftTokensBuffers;
        TLLM_CHECK_WITH_INFO(explicitDraftTokensBuffers.has_value(), "explicitDraftTokensBuffers is not set");

        outputParams->packedMasks = explicitDraftTokensBuffers->packedMasks;
        outputParams->nextDraftPosIds = explicitDraftTokensBuffers->positionIds;
        outputParams->unpackedNextDraftTokens = explicitDraftTokensBuffers->draftTokens;
        outputParams->unpackedNextDraftIndices = explicitDraftTokensBuffers->draftIndices;
        outputParams->nextDraftProbs = explicitDraftTokensBuffers->draftProbs;
        outputParams->positionIdsBase = explicitDraftTokensBuffers->positionIdsBase;
        outputParams->randomDataSample = explicitDraftTokensBuffers->randomDataSample;
        outputParams->randomDataValidation = explicitDraftTokensBuffers->randomDataValidation;
        outputParams->temperatures = explicitDraftTokensBuffers->temperatures;
        outputParams->generationLengths = explicitDraftTokensBuffers->generationLengths;
        outputParams->generationLengthsHost = explicitDraftTokensBuffers->generationLengthsHost;
        outputParams->maxGenLengthHost = explicitDraftTokensBuffers->maxGenLengthHost;
    }
    else if (decodingMode.isLookahead())
    {
        TLLM_CHECK(output.lookaheadOutputs);
        auto outputParams = std::dynamic_pointer_cast<tl::LookaheadDecodingOutputs>(baseOutputs);
        outputParams->packedMasks = output.lookaheadOutputs->packedMasks;
        outputParams->positionIds = output.lookaheadOutputs->positionIds;
        outputParams->positionOffsets = output.lookaheadOutputs->positionOffsets;
        outputParams->generationLengths = output.lookaheadOutputs->generationLengths;
    }
    else if (decodingMode.isEagle())
    {
        auto outputParams = std::dynamic_pointer_cast<tl::EagleOutputs>(baseOutputs);
        auto const& eagleBuffers = output.eagleBuffers;
        TLLM_CHECK_WITH_INFO(eagleBuffers.has_value(), "eagleBuffers is not set");

        outputParams->temperatures = eagleBuffers->temperatures;
        outputParams->unpackedNextDraftTokens = eagleBuffers->draftTokens;
        outputParams->nextDraftPaths = eagleBuffers->draftPaths;
        outputParams->generationLengths = eagleBuffers->specDecodingGenerationLengths;
        outputParams->generationLengthsHost = eagleBuffers->specDecodingGenerationLengthsHost;
        outputParams->nextDraftPosIds = eagleBuffers->specDecodingPositionOffsets;
        outputParams->packedMasks = eagleBuffers->specDecodingPackedMasks;
        outputParams->randomDataSample = eagleBuffers->randomDataSample;
        outputParams->randomDataValidation = eagleBuffers->randomDataValidation;
        outputParams->eagleNetCtxRequestTypesHost = eagleBuffers->eagleNetCtxRequestTypesHost;
        outputParams->eagleNetCtxContextLengthsHost = eagleBuffers->eagleNetCtxContextLengthsHost;
        outputParams->eagleNetCtxPastKeyValueLengthsHost = eagleBuffers->eagleNetCtxPastKeyValueLengthsHost;
        outputParams->eagleNetGenRequestTypesHost = eagleBuffers->eagleNetGenRequestTypesHost;
        outputParams->eagleNetGenContextLengthsHost = eagleBuffers->eagleNetGenContextLengthsHost;
        outputParams->eagleNetGenPastKeyValueLengthsHost = eagleBuffers->eagleNetGenPastKeyValueLengthsHost;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
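// prepareOutputs first picks the most-derived output-parameter class for the active
// decoding mode, then fills the fields common to all modes, and finally defers to the
// mode-specific helpers above.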
std::shared_ptr<tl::BaseDecodingOutputs> prepareOutputs(DecodingOutput& output, tle::DecodingMode const& decodingMode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    std::shared_ptr<tl::BaseDecodingOutputs> outputParams;
    if (decodingMode.isBeamSearch())
    {
        outputParams = std::make_shared<tl::BeamSearchOutputs>(output.ids);
    }
    else if (decodingMode.isMedusa())
    {
        outputParams = std::make_shared<tl::SpeculativeDecodingOutputs>(output.ids);
    }
    else if (decodingMode.isLookahead())
    {
        outputParams = std::make_shared<tl::LookaheadDecodingOutputs>(output.ids);
    }
    else if (decodingMode.isExplicitDraftTokens())
    {
        outputParams = std::make_shared<tl::ExplicitDraftTokensOutputs>(output.ids);
    }
    else if (decodingMode.isEagle())
    {
        outputParams = std::make_shared<tl::EagleOutputs>(output.ids);
    }
    else
    {
        outputParams = std::make_shared<tl::BaseDecodingOutputs>(output.ids);
    }

    // Common outputs
    outputParams->newTokens = output.newTokens;

    if (output.cumLogProbs)
    {
        outputParams->cumLogProbs = output.cumLogProbs;
    }

    if (output.parentIds)
    {
        outputParams->parentIds = output.parentIds;
    }

    if (output.finishReasons)
    {
        outputParams->finished = output.finishReasons;
    }

    if (output.finishedSum)
    {
        outputParams->finishedSum = output.finishedSum;
    }

    if (output.lengths)
    {
        outputParams->sequenceLength = output.lengths;
    }

    if (output.logProbs)
    {
        outputParams->outputLogProbs = output.logProbs;
        outputParams->outputLogProbsTiled = output.logProbsTiled;
    }

    // Beam search outputs
    if (decodingMode.isBeamSearch())
    {
        prepareBeamSearchOutputs(output, outputParams);
    }

    // Speculative decoding outputs
    if (decodingMode.isMedusa() || decodingMode.isLookahead() || decodingMode.isExplicitDraftTokens()
        || decodingMode.isEagle())
    {
        prepareSpeculativeDecodingOutputs(output, outputParams, decodingMode);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return outputParams;
}

} // namespace
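// Illustrative call pattern, as a sketch only (the real driver lives in the batched
// decoder; `decoder`, `input`, and `output` are assumed to be fully populated
// placeholders):
//
//   decoder.forwardAsync(output, input); // enqueues the decoding step on the CUDA stream
//   // ... or, when the caller must observe completed results before returning:
//   decoder.forwardSync(output, input);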
template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxNumSequences, mDecodingMode);
    auto outputParams = prepareOutputs(output, mDecodingMode);

    mDynamicDecodeLayer->forwardAsync(outputParams, forwardParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void GptDecoder<T>::forwardSync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxNumSequences, mDecodingMode);
    auto outputParams = prepareOutputs(output, mDecodingMode);

    mDynamicDecodeLayer->forwardSync(outputParams, forwardParams, mDecodingLayerWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace tensorrt_llm::runtime
{
// Explicit instantiations for the supported logits data types.
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime