/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/decoderState.h"

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

namespace tk = tensorrt_llm::kernels;

namespace tensorrt_llm::runtime::decoder
{

using TensorPtr = DecoderState::TensorPtr;

// Scratch buffers used while gathering beam-search hypotheses.
BeamSearchBuffers::BeamSearchBuffers(BufferManager const& bufferManager)
    : mOutputBeamHypotheses{}
    , mCumLogProbsTmp(bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT))
{
    mOutputBeamHypotheses.empty(bufferManager);
    // NOTE(review): the member initializer above already creates mCumLogProbsTmp;
    // a duplicate reassignment that used to follow here was redundant and removed.
    int device = 0;
    // Query the SM count of the current device; check CUDA API results instead of
    // silently ignoring failures.
    TLLM_CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp deviceProp;
    TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));
    mNumSMs = deviceProp.multiProcessorCount;
}

// Resize beam-search scratch buffers for a single request with the given limits.
void BeamSearchBuffers::reshape(SizeType32 maxBeamWidth, SizeType32 maxSequenceLength)
{
    mOutputBeamHypotheses.reshape(1, maxBeamWidth, maxSequenceLength);
    mCumLogProbsTmp->reshape(ITensor::makeShape({1, maxBeamWidth}));
}

DecoderState::DecoderState()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mJointDecodingInput = std::make_unique<DecodingInput>();
    mJointDecodingOutput = std::make_unique<DecodingOutput>();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Allocate and size all decoder buffers for the given batch/beam/sequence limits.
void DecoderState::setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
    SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    setupBuffers(dtype, bufferManager);
    reshapeBuffers(maxNumSequences, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength, modelConfig,
        worldConfig, bufferManager);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Create all (initially empty) input/output tensors of the joint decoding state.
void DecoderState::setupBuffers(nvinfer1::DataType dtype, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
    auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
    auto constexpr nvFloatType = TRTDataType<float>::value;

    auto& dInput = mJointDecodingInput;
    TLLM_CHECK(static_cast<bool>(dInput));
    dInput->endIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
    dInput->batchSlots = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
    dInput->sequenceLimitLength = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    dInput->lengths = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);

    auto& dOutput = mJointDecodingOutput;
    TLLM_CHECK(static_cast<bool>(dOutput));
    dOutput->ids = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
    dOutput->gatheredIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
    dOutput->newTokensSteps = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
    dOutput->parentIds = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    dOutput->lengths = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    dOutput->finishedSum = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    dOutput->cumLogProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
    dOutput->logProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
    dOutput->beamHypotheses.empty(bufferManager);
    dOutput->finishReasons
        = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
    // The finish-reason tensor is shared between input and output views of the state.
    dInput->finishReasons = dOutput->finishReasons;
    dOutput->logProbsTiled = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);

    dInput->stopWordsPtrs = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<TokenIdType*>::value);
    dInput->stopWordsLens = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
    dInput->badWordsPtrs = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<TokenIdType*>::value);
    dInput->badWordsLens = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
    dInput->embeddingBias = bufferManager.emptyTensor(MemoryType::kGPU, dtype);

    mBeamSearchBuffers = std::make_unique<BeamSearchBuffers>(bufferManager);

    setupCacheIndirectionBuffers(bufferManager);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Allocate and size the buffers required for the given speculative decoding mode.
void DecoderState::setupSpeculativeDecoding(SpeculativeDecodingMode const& speculativeDecodingMode,
    SizeType32 maxTokensPerEngineStep, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    setupSpeculativeDecodingBuffers(speculativeDecodingMode, dtype, bufferManager);
    reshapeSpeculativeDecodingBuffers(
        speculativeDecodingMode, maxTokensPerEngineStep, modelConfig, worldConfig, bufferManager);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Create the (initially empty) tensors for the selected speculative decoding mode.
void DecoderState::setupSpeculativeDecodingBuffers(
    SpeculativeDecodingMode const speculativeDecodingMode, nvinfer1::DataType dtype, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mSpeculativeDecodingMode = speculativeDecodingMode;

    auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
    auto constexpr nvSizeType = TRTDataType<SizeType32>::value;

    auto& dInput = mJointDecodingInput;
    auto& dOutput = mJointDecodingOutput;

    if (speculativeDecodingMode.isMedusa())
    {
        DecodingInput::MedusaInputs medusaInputs;
        medusaInputs.medusaPaths = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaTreeIds = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaCurTokensPerStep = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaTargetTokensPerStep = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        dInput->medusaInputs = medusaInputs;
    }

    DecodingOutput::SpeculativeDecodingOutputs speculativeDecodingOutputs;
    if (speculativeDecodingMode.predictsDraftTokens())
    {
        speculativeDecodingOutputs.nextDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        if (speculativeDecodingMode.variableDraftLength())
        {
            speculativeDecodingOutputs.nextDraftTokensLen
                = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
            speculativeDecodingOutputs.prevDraftTokensLen
                = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        }
    }
    if (speculativeDecodingMode.isLookaheadDecoding())
    {
        dInput->lookaheadInputs = DecodingInput::LookaheadInputs();
    }
    if (speculativeDecodingMode.needsKVCacheRewind())
    {
        speculativeDecodingOutputs.acceptedTokensLen
            = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        speculativeDecodingOutputs.acceptedLengthsCumSum
            = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        speculativeDecodingOutputs.pathsOffsets = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
    }
    dOutput->speculativeDecodingOutputs = speculativeDecodingOutputs;

    if (speculativeDecodingMode.isDraftTokensExternal())
    {
        DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;
        externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype);
        externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        externalDraftTokensInputs.numDraftTokensHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
        externalDraftTokensInputs.useDraftLogits = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
        externalDraftTokensInputs.useDraftLogitsHost
            = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<bool>::value);
        externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
        externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType);
        dInput->externalDraftTokensInputs = externalDraftTokensInputs;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Resize all core decoder buffers to the configured batch/beam/sequence limits
//! and zero/fill them as required.
void DecoderState::reshapeBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
    SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const& stream = bufferManager.getStream();

    TLLM_CHECK(maxNumSequences > 0);
    TLLM_CHECK(maxBeamWidth > 0);
    TLLM_CHECK(mMaxDecodingEngineTokens > 0);
    TLLM_CHECK(maxSequenceLength > 0);

    mMaxNumSequences = maxNumSequences;
    mMaxBeamWidth = maxBeamWidth;
    mMaxSequenceLength = maxSequenceLength;

    mNumDecodingEngineTokens.clear();
    mNumDecodingEngineTokens.resize(mMaxNumSequences, 0);

    // setup input
    auto& dInput = *mJointDecodingInput;
    dInput.maxLength = mMaxSequenceLength;
    dInput.maxAttentionWindow = maxAttentionWindow;
    dInput.sinkTokenLength = sinkTokenLength;
    dInput.stopWordsLists.resize(mMaxNumSequences);
    dInput.badWordsLists.resize(mMaxNumSequences);

    auto const maxNumSequencesShape = ITensor::makeShape({mMaxNumSequences});
    auto const maxNumSequencesXmaxBeamWidthShape = ITensor::makeShape({mMaxNumSequences, mMaxBeamWidth});

    // Input tensors are held as const views; const_cast is needed to reshape them in place.
    const_cast<ITensor&>(*dInput.endIds).reshape(maxNumSequencesShape);
    auto& sequenceLimitLength = const_cast<ITensor&>(*dInput.sequenceLimitLength);
    sequenceLimitLength.reshape(maxNumSequencesShape);
    // By default every sequence may run up to the maximum sequence length.
    kernels::invokeFill(sequenceLimitLength, mMaxSequenceLength, stream);
    auto& inputLengths = const_cast<ITensor&>(*dInput.lengths);
    inputLengths.reshape(maxNumSequencesXmaxBeamWidthShape);
    bufferManager.setZero(inputLengths);

    dInput.beamWidths.clear();
    dInput.beamWidths.resize(mMaxNumSequences, 0);

    auto const maxTotalTokensShape = ITensor::makeShape({mMaxNumSequences, mMaxBeamWidth, mMaxSequenceLength});

    // setup output
    auto& dOutput = *mJointDecodingOutput;
    dOutput.ids->reshape(maxTotalTokensShape);

    auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth});

    dOutput.finishReasons->reshape(maxNumSequencesXmaxBeamWidthShape);
    bufferManager.setZero(*dOutput.finishReasons);

    dOutput.parentIds->reshape(maxTotalTokensShape);
    dOutput.lengths->reshape(maxNumSequencesXmaxBeamWidthShape);
    bufferManager.setZero(*dOutput.lengths);
    dOutput.finishedSum->reshape(maxNumSequencesShape);
    bufferManager.setZero(*dOutput.finishedSum);
    dOutput.newTokensSteps->reshape(maxNewTokensShape);
    bufferManager.setZero(*dOutput.newTokensSteps);
    dOutput.cumLogProbs->reshape(maxNumSequencesXmaxBeamWidthShape);
    bufferManager.setZero(*dOutput.cumLogProbs);
    dOutput.logProbs->reshape(maxTotalTokensShape);
    bufferManager.setZero(*dOutput.logProbs);
    dOutput.logProbsTiled->reshape(ITensor::makeShape({mMaxSequenceLength, mMaxNumSequences, mMaxBeamWidth}));
    bufferManager.setZero(*dOutput.logProbsTiled);

    if (mMaxBeamWidth > 1)
    {
        dOutput.beamHypotheses.reshape(mMaxNumSequences, mMaxBeamWidth, mMaxSequenceLength);
        mBeamSearchBuffers->reshape(mMaxBeamWidth, mMaxSequenceLength);
        reshapeCacheIndirectionBuffers(mMaxNumSequences, mMaxBeamWidth, maxAttentionWindow);
        dOutput.gatheredIds->reshape(maxTotalTokensShape);
    }
    else
    {
        // Without beam search no gather step is needed; alias the output ids.
        dOutput.gatheredIds = dOutput.ids;
    }

    auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
    const_cast<ITensor&>(*dInput.embeddingBias)
        .reshape(ITensor::makeShape({mMaxNumSequences, static_cast<SizeType32>(vocabSizePadded)}));
    const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(maxNumSequencesShape);
    const_cast<ITensor&>(*dInput.badWordsLens).reshape(maxNumSequencesShape);
    const_cast<ITensor&>(*dInput.stopWordsPtrs).reshape(maxNumSequencesShape);
    const_cast<ITensor&>(*dInput.stopWordsLens).reshape(maxNumSequencesShape);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Allocate and size the cache indirection buffers used by beam search.
void DecoderState::setupCacheIndirection(
    SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    setupCacheIndirectionBuffers(bufferManager);
    reshapeCacheIndirectionBuffers(maxNumSequences, maxBeamWidth, maxAttentionWindow);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void DecoderState::setupCacheIndirectionBuffers(BufferManager const& bufferManager)
{
    auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
    mJointDecodingInput->cacheIndirection = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    mJointDecodingOutput->cacheIndirection = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
}

void DecoderState::reshapeCacheIndirectionBuffers(
    SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow)
{
    mJointDecodingInput->cacheIndirection->reshape(
        ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}));
    mJointDecodingOutput->cacheIndirection->reshape(
        ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}));
}

//! Resize the speculative-decoding buffers for the configured mode and token budget.
void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode const& speculativeDecodingMode,
    SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
    BufferManager const& bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto& dInput = *mJointDecodingInput;
    auto& dOutput = *mJointDecodingOutput;

    TLLM_CHECK(maxTokensPerEngineStep > 0);
    mMaxDecodingEngineTokens = maxTokensPerEngineStep;

    TLLM_CHECK_WITH_INFO((mMaxDecodingEngineTokens == 1 && speculativeDecodingMode.isNone())
            || (mMaxDecodingEngineTokens > 1 && !speculativeDecodingMode.isNone()),
        "Max tokens per engine step is %d, but must be equal to 1 when no speculative decoding is configured, "
        "or > 1 for any speculative decoding mode.",
        mMaxDecodingEngineTokens);

    auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth});
    dOutput.newTokensSteps->reshape(maxNewTokensShape);
    bufferManager.setZero(*dOutput.newTokensSteps);

    // Only modes that predict their own draft tokens run the decoder for multiple tokens per step.
    if (speculativeDecodingMode.predictsDraftTokens())
    {
        mMaxDecodingDecoderTokens = mMaxDecodingEngineTokens;
    }
    else
    {
        mMaxDecodingDecoderTokens = 1;
    }

    if (speculativeDecodingMode.isNone())
    {
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
        return;
    }

    auto const maxNumSequencesShape = ITensor::makeShape({mMaxNumSequences});

    if (speculativeDecodingMode.isDraftTokensExternal())
    {
        auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());

        auto const probsShape = ITensor::makeShape(
            {mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast<SizeType32>(vocabSizePadded)});
        dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape);
        dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape);
        auto const logitsShape = ITensor::makeShape(
            {mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)});
        dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape);
        dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape);
        auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens});
        dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape);
        dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape);
        dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape);
        dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape);
        dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape);
        dInput.externalDraftTokensInputs->useDraftLogitsHost->reshape(maxNumSequencesShape);
    }

    if (speculativeDecodingMode.isMedusa())
    {
        auto const speculativeDecodingModule = modelConfig.getSpeculativeDecodingModulePtr();
        auto& medusaPaths = const_cast<ITensor&>(*dInput.medusaInputs->medusaPaths);
        medusaPaths.reshape(ITensor::makeShape({mMaxNumSequences, speculativeDecodingModule->getMaxDecodingTokens(),
            speculativeDecodingModule->getMaxPathLen()}));
        // -1 marks unused path entries.
        bufferManager.setMem(medusaPaths, -1);
        auto& medusaTreeIds = const_cast<ITensor&>(*dInput.medusaInputs->medusaTreeIds);
        medusaTreeIds.reshape(
            ITensor::makeShape({mMaxNumSequences, speculativeDecodingModule->getMaxDecodingDraftTokens()}));
        bufferManager.setZero(medusaTreeIds);
        auto& curTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaCurTokensPerStep);
        auto& targetTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaTargetTokensPerStep);
        curTokensPerStep.reshape(maxNumSequencesShape);
        targetTokensPerStep.reshape(maxNumSequencesShape);
        bufferManager.setZero(curTokensPerStep);
        bufferManager.setZero(targetTokensPerStep);
    }

    if (speculativeDecodingMode.predictsDraftTokens())
    {
        // One of the engine tokens is the accepted token; the rest are draft tokens.
        dOutput.speculativeDecodingOutputs->nextDraftTokens->reshape(
            ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens - 1}));
        if (speculativeDecodingMode.variableDraftLength())
        {
            dOutput.speculativeDecodingOutputs->nextDraftTokensLen->reshape(maxNumSequencesShape);
            dOutput.speculativeDecodingOutputs->prevDraftTokensLen->reshape(maxNumSequencesShape);
        }
    }

    if (speculativeDecodingMode.needsKVCacheRewind())
    {
        auto const speculativeDecodingModule = modelConfig.getSpeculativeDecodingModulePtr();
        dOutput.speculativeDecodingOutputs->acceptedTokensLen->reshape(maxNumSequencesShape);
        dOutput.speculativeDecodingOutputs->acceptedLengthsCumSum->reshape(ITensor::makeShape({mMaxNumSequences + 1}));
        dOutput.speculativeDecodingOutputs->pathsOffsets->reshape(
            ITensor::makeShape({mMaxNumSequences * speculativeDecodingModule->getMaxDraftPathLen()}));
    }

    if (speculativeDecodingMode.isExplicitDraftTokens())
    {
        mJointDecodingOutput->explicitDraftTokensBuffers = runtime::ExplicitDraftTokensBuffers::Inputs();
        mJointDecodingOutput->explicitDraftTokensBuffers->create(
            mMaxNumSequences, bufferManager, modelConfig, worldConfig);
    }
    else if (speculativeDecodingMode.isEagle())
    {
        mJointDecodingOutput->eagleBuffers = runtime::EagleBuffers::Inputs();
        mJointDecodingOutput->eagleBuffers->create(mMaxNumSequences, bufferManager, modelConfig, worldConfig);
    }
    else if (speculativeDecodingMode.isLookaheadDecoding())
    {
        mJointDecodingOutput->lookaheadOutputs
            = runtime::LookaheadDecodingBuffers(mMaxNumSequences, mMaxDecodingEngineTokens, bufferManager);
        mJointDecodingInput->lookaheadInputs->tokensPerStep = mJointDecodingOutput->lookaheadOutputs->generationLengths;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Switch the decoder back to non-speculative (one token per step) decoding and
//! reset per-request token counts for the given generation requests.
void DecoderState::disableLookahead(RequestVector const& genRequests)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mSpeculativeDecodingMode = SpeculativeDecodingMode::None();
    mMaxDecodingEngineTokens = 1;
    mMaxDecodingDecoderTokens = 1;
    mJointDecodingInput->lookaheadInputs.reset();

    auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth});
    mJointDecodingOutput->newTokensSteps->reshape(maxNewTokensShape);

    for (auto const& llmReq : genRequests)
    {
        if (llmReq->mSeqSlot)
        {
            setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1);
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

TensorPtr DecoderState::getFinishedSum() const
{
    return mJointDecodingOutput->finishedSum;
}

TensorPtr DecoderState::getFinishReasons() const
{
    return mJointDecodingOutput->finishReasons;
}

TensorPtr DecoderState::getIds() const
{
    return mJointDecodingOutput->ids;
}

TensorPtr DecoderState::getIds(SizeType32 batchIdx) const
{
    return ITensor::at(mJointDecodingOutput->ids, {batchIdx});
}

TensorPtr DecoderState::getGatheredIds() const
{
    return mJointDecodingOutput->gatheredIds;
}

TensorPtr DecoderState::getGatheredIds(SizeType32 batchIdx) const
{
    return ITensor::at(mJointDecodingOutput->gatheredIds, {batchIdx});
}

TensorPtr DecoderState::getParentIds() const
{
    return mJointDecodingOutput->parentIds;
}

TensorPtr DecoderState::getCumLogProbs() const
{
    return mJointDecodingOutput->cumLogProbs;
}

TensorPtr DecoderState::getCumLogProbs(SizeType32 batchIdx) const
{
    return ITensor::at(mJointDecodingOutput->cumLogProbs, {batchIdx});
}

TensorPtr DecoderState::getLogProbs() const
{
    return mJointDecodingOutput->logProbs;
}

TensorPtr DecoderState::getLogProbs(SizeType32 batchIdx) const
{
    return ITensor::at(mJointDecodingOutput->logProbs, {batchIdx});
}

TensorPtr DecoderState::getSequenceLengths() const
{
    return mJointDecodingOutput->lengths;
}

TensorPtr DecoderState::getSequenceLengths(SizeType32 batchIdx) const
{
    return ITensor::at(mJointDecodingOutput->lengths, {batchIdx});
}

TensorPtr DecoderState::getAllNewTokens() const
{
    return mJointDecodingOutput->newTokensSteps;
}

TensorPtr DecoderState::getNextDraftTokens() const
{
    return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokens;
}

TensorPtr DecoderState::getPrevDraftTokensLengths() const
{
    return mJointDecodingOutput->speculativeDecodingOutputs->prevDraftTokensLen;
}

TensorPtr DecoderState::getNextDraftTokensLengths() const
{
    return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokensLen;
}

TensorPtr DecoderState::getAcceptedLengthsCumSum() const
{
    return mJointDecodingOutput->speculativeDecodingOutputs->acceptedLengthsCumSum;
}

TensorPtr DecoderState::getAcceptedPackedPaths() const
{
    return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
}

SizeType32 DecoderState::getMaxNumSequences() const
{
    return mMaxNumSequences;
}

SizeType32 DecoderState::getMaxBeamWidth() const
{
    return mMaxBeamWidth;
}

SizeType32 DecoderState::getMaxSequenceLength() const
{
    return mMaxSequenceLength;
}

SizeType32 DecoderState::getMaxDecodingDecoderTokens() const
{
    return mMaxDecodingDecoderTokens;
}

SizeType32 DecoderState::getMaxDecodingEngineTokens() const
{
    return mMaxDecodingEngineTokens;
}

SpeculativeDecodingMode DecoderState::getSpeculativeDecodingMode() const
{
    return mSpeculativeDecodingMode;
}

ExplicitDraftTokensBuffers::Inputs const& DecoderState::getExplicitDraftTokensBuffers() const
{
    return *mJointDecodingOutput->explicitDraftTokensBuffers;
}

EagleBuffers::Inputs const& DecoderState::getEagleBuffers() const
{
    return *mJointDecodingOutput->eagleBuffers;
}

LookaheadDecodingBuffers const& DecoderState::getLookaheadBuffers() const
{
    return *mJointDecodingOutput->lookaheadOutputs;
}

std::vector<SizeType32> const& DecoderState::getNumDecodingEngineTokens() const
{
    return mNumDecodingEngineTokens;
}

SizeType32 DecoderState::getNumDecodingEngineTokens(SizeType32 batchIdx) const
{
    TLLM_CHECK_WITH_INFO(
        batchIdx < mMaxNumSequences, "Batch index %d out of bounds (max %d)", batchIdx, mMaxNumSequences);
    return mNumDecodingEngineTokens[batchIdx];
}

void DecoderState::setNumDecodingEngineTokens(SizeType32 batchIdx, SizeType32 numTokens)
{
    TLLM_CHECK_WITH_INFO(
        batchIdx < mMaxNumSequences, "Batch index %d out of bounds (max %d)", batchIdx, mMaxNumSequences);
    mNumDecodingEngineTokens[batchIdx] = numTokens;
}

BeamSearchBuffers const& DecoderState::getBeamSearchBuffers() const
{
    return *mBeamSearchBuffers;
}

TensorPtr DecoderState::getCacheIndirectionInput() const
{
    return mJointDecodingInput->cacheIndirection;
}

TensorPtr DecoderState::getCacheIndirectionOutput() const
{
    return mJointDecodingOutput->cacheIndirection;
}

std::optional<std::vector<SizeType32>> const& DecoderState::getGenerationSteps() const
{
    return mJointDecodingInput->generationSteps;
}

void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationSteps)
{
    mJointDecodingInput->generationSteps = generationSteps;
}

void DecoderState::setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth)
{
    mJointDecodingInput->beamWidths.at(batchIdx) = beamWidth;
}

DecodingInput& DecoderState::getJointDecodingInput() const
{
    return *mJointDecodingInput;
}

DecodingOutput& DecoderState::getJointDecodingOutput() const
{
    return *mJointDecodingOutput;
}

} // namespace tensorrt_llm::runtime::decoder