/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptDecoderBatched.h"

#include "common.h"
#include "iBuffer.h"
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <vector>

using namespace tensorrt_llm::runtime;

namespace tk = tensorrt_llm::kernels;

GptDecoderBatched::GptDecoderBatched(GptDecoderBatched::CudaStreamPtr stream,
    SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype)
    : mRuntimeStream{std::move(stream)}
    , mBufferManager{mRuntimeStream}
    , mSpeculativeDecodingMode{speculativeDecodingMode}
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
    auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
    auto constexpr nvFloatType = TRTDataType<float>::value;

    auto& dInput = mJointDecodingInput;
    { // prevent reusing these vars after std::move
        auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
        auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
        auto batchSlots = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
        dInput = std::make_unique<DecodingInput>(
            0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
    }
    dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);

    auto& dOutput = mJointDecodingOutput;
    { // prevent reusing these vars after std::move
        auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
        auto gatheredOutputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
        dOutput = std::make_unique<DecodingOutput>(std::move(outputIds), std::move(gatheredOutputIds));
    }
    dOutput->newTokensSteps = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
    dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    mFinishedSteps
        = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
    // use batchSize many entries instead of the usual 1
    dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
    // we don't need dOutput->lengths because lengths are passed from outside
    dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
    dOutput->logProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
    dOutput->beamHypotheses.empty(mBufferManager);
    dOutput->finishReasons
        = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
    dOutput->logProbsTiled = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);

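    // Per-request stop-word and bad-word lists reach the decoder as pointer + length pairs;
    // the pointer/length staging tensors below live in pinned host memory (kPINNEDPOOL).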
    dInput->stopWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<TokenIdType*>::value);
    dInput->stopWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
    dInput->badWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<TokenIdType*>::value);
    dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
    dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);

    int device;
    cudaGetDevice(&device);
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, device);
    mNumSMs = deviceProp.multiProcessorCount;

    if (!mSpeculativeDecodingMode.isNone())
    {
        allocateSpeculativeDecodingBuffers(dtype);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::allocateSpeculativeDecodingBuffers(nvinfer1::DataType dtype)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto constexpr nvSizeType = TRTDataType<SizeType32>::value;

    auto& dInput = mJointDecodingInput;
    auto& dOutput = mJointDecodingOutput;

    if (mSpeculativeDecodingMode.isMedusa())
    {
        DecodingInput::MedusaInputs medusaInputs;
        medusaInputs.medusaPaths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaTreeIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaCurTokensPerStep = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        medusaInputs.medusaTargetTokensPerStep = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        dInput->medusaInputs = medusaInputs;
    }

    DecodingOutput::SpeculativeDecodingOutputs speculativeDecodingOutputs;
    if (mSpeculativeDecodingMode.predictsDraftTokens())
    {
        speculativeDecodingOutputs.nextDraftTokens
            = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        if (mSpeculativeDecodingMode.variableDraftLength())
        {
            speculativeDecodingOutputs.nextDraftTokensLen
                = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
            speculativeDecodingOutputs.prevDraftTokensLen
                = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        }
    }
    if (mSpeculativeDecodingMode.isLookaheadDecoding())
    {
        dInput->lookaheadInputs = DecodingInput::LookaheadInputs();
    }
    if (mSpeculativeDecodingMode.needsKVCacheRewind())
    {
        speculativeDecodingOutputs.acceptedTokensLen
            = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        speculativeDecodingOutputs.acceptedLengthsCumSum
            = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
        speculativeDecodingOutputs.pathsOffsets
            = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
    }
    dOutput->speculativeDecodingOutputs = speculativeDecodingOutputs;

    if (mSpeculativeDecodingMode.isDraftTokensExternal())
    {
        DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;

        externalDraftTokensInputs.draftLogits = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.draftProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.targetProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
        externalDraftTokensInputs.numDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
        externalDraftTokensInputs.numDraftTokensHost = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
        externalDraftTokensInputs.useDraftLogits
            = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
        externalDraftTokensInputs.useDraftLogitsHost
            = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<bool>::value);
        externalDraftTokensInputs.draftTokenIds
            = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);

        dInput->externalDraftTokensInputs = externalDraftTokensInputs;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

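// A rough sketch of the intended call order, as far as this file shows (hypothetical driver
// code, not part of the library):
//
//   GptDecoderBatched decoder{stream, SpeculativeDecodingMode::None(), dtype};
//   decoder.setup(mode, maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength,
//       maxSequenceLength, maxTokensPerEngineStep, dtype, modelConfig, worldConfig);
//   // per engine iteration:
//   auto finishedEvent = decoder.forwardAsync(output, input); // or decoder.forward(...) to block
//   // for beam search, once a request finishes:
//   auto event = decoder.finalize(batchSlot, samplingConfig, /*streaming=*/false);
//
// The mode-specific setup* helpers below attach extra buffers when speculative decoding is used.
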
void GptDecoderBatched::setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(mSpeculativeDecodingMode.isExplicitDraftTokens());
    mJointDecodingOutput->explicitDraftTokensBuffers = std::move(explicitDraftTokensBuffers);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(mSpeculativeDecodingMode.isLookaheadDecoding());
    mJointDecodingOutput->lookaheadOutputs = std::move(lookaheadDecodingBuffers);
    mJointDecodingInput->lookaheadInputs->tokensPerStep = mJointDecodingOutput->lookaheadOutputs->generationLengths;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setupEagle(EagleBuffers::Inputs eagleBuffers)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(mSpeculativeDecodingMode.isEagle());
    mJointDecodingOutput->eagleBuffers = std::move(eagleBuffers);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::disableLookahead(
    SizeType32 maxBatchSize, RequestVector const& genRequests, TensorPtr const& batchSlots)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mSpeculativeDecodingMode = SpeculativeDecodingMode::None();
    mMaxDecodingEngineTokens = 1;
    mMaxDecodingDecoderTokens = 1;
    mDecodingMode = executor::DecodingMode::TopKTopP();
    mJointDecodingInput->lookaheadInputs.reset();
    mJointDecodingOutput->newTokensSteps->reshape(ITensor::makeShape({1, maxBatchSize, 1}));
    mFinishedSteps->reshape(ITensor::makeShape({1, maxBatchSize, 1}));
    mJointDecodingInput->numDecodingEngineTokens.clear();
    mJointDecodingInput->numDecodingEngineTokens.resize(maxBatchSize, 0);

    std::vector<SamplingConfig> samplingConfigs;
    samplingConfigs.reserve(genRequests.size());
    auto batchSlotsRange = BufferRange<SizeType32>(*batchSlots);
    SizeType32 batchIdx = 0;
    for (auto const& llmReq : genRequests)
    {
        mJointDecodingInput->numDecodingEngineTokens[llmReq->mSeqSlot.value()] = 1;
        samplingConfigs.push_back(llmReq->mSamplingConfig);
        batchSlotsRange[batchIdx] = llmReq->mSeqSlot.value();
        batchIdx += 1;
    }
    auto const batchSize = batchIdx;
    std::optional<SamplingConfig> samplingConfig;
    if (batchSize > 0)
    {
        samplingConfig = SamplingConfig(samplingConfigs);
    }
    TensorPtr batchSlotsView = ITensor::slice(batchSlots, 0, batchSize);
    mDecoder->disableLookahead(samplingConfig, batchSize, batchSlots);

    CudaEvent event{};
    mDecoderStream->record(event);
    mRuntimeStream->wait(event);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

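// setup() sizes all joint input/output buffers for the configured limits (batch size, beam width,
// attention window, sequence length, tokens per engine step) and creates the underlying
// IGptDecoder on a dedicated decoder stream. It must run before any forward call.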
void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
    SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
    SizeType32 maxTokensPerEngineStep, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_CHECK(maxBatchSize > 0);
    TLLM_CHECK(maxBeamWidth > 0);
    TLLM_CHECK(maxTokensPerEngineStep > 0);
    TLLM_CHECK(maxSequenceLength > 0);

    mActualBatchSize = maxBatchSize;
    mMaxSequenceLength = maxSequenceLength;
    mMaxDecodingEngineTokens = maxTokensPerEngineStep;
    mDecodingMode = mode;

    TLLM_CHECK_WITH_INFO((mMaxDecodingEngineTokens == 1 && mSpeculativeDecodingMode.isNone())
            || (mMaxDecodingEngineTokens > 1 && !mSpeculativeDecodingMode.isNone()),
        "Max tokens per engine step must be equal to 1 when no speculative decoding is configured, "
        "or > 1 for any speculative decoding mode");

    auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
    auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({maxBatchSize, maxBeamWidth});
    auto const maxTokensPerStepXmaxBatchSizeXmaxBeamWidth
        = ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize, maxBeamWidth});
    auto const maxBatchSizeXmaxTokensPerStep = ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep});
    auto const jointOutputIdsShape = ITensor::makeShape({maxBatchSize, maxBeamWidth, maxSequenceLength});

    auto& dInput = *mJointDecodingInput;
    dInput.maxLength = mMaxSequenceLength;
    dInput.maxAttentionWindow = maxAttentionWindow;
    dInput.sinkTokenLength = sinkTokenLength;
    dInput.stopWordsLists.resize(maxBatchSize);
    dInput.badWordsLists.resize(maxBatchSize);

    const_cast<ITensor&>(*dInput.endIds).reshape(maxBatchSizeShape);
    const_cast<ITensor&>(*dInput.batchSlots).reshape(maxBatchSizeShape);
    auto& sequenceLimitLength = const_cast<ITensor&>(*dInput.sequenceLimitLength);
    sequenceLimitLength.reshape(maxBatchSizeShape);
    kernels::invokeFill(sequenceLimitLength, mMaxSequenceLength, *mRuntimeStream);
    auto& inputLengths = const_cast<ITensor&>(*dInput.lengths);
    inputLengths.reshape(maxBatchSizeXmaxBeamWidth);
    mBufferManager.setZero(inputLengths);

    dInput.beamWidths.clear();
    dInput.beamWidths.resize(maxBatchSize, 0);

    dInput.numDecodingEngineTokens.clear();
    dInput.numDecodingEngineTokens.resize(maxBatchSize, 0);

    auto& dOutput = *mJointDecodingOutput;
    dOutput.ids->reshape(jointOutputIdsShape);
    if (maxBeamWidth > 1)
    {
        dOutput.gatheredIds->reshape(jointOutputIdsShape);

        mOutputBeamHypotheses = std::make_shared<DecodingOutput::BeamHypotheses>();
        mOutputBeamHypotheses->empty(mBufferManager);
        mOutputBeamHypotheses->reshape(1, maxBeamWidth, mMaxSequenceLength);
        mCumLogProbsTmp = mBufferManager.gpu(ITensor::makeShape({1, maxBeamWidth}), nvinfer1::DataType::kFLOAT);
    }
    else
    {
        dOutput.gatheredIds = dOutput.ids;
    }
    mFinishedSteps->reshape(maxTokensPerStepXmaxBatchSizeXmaxBeamWidth);
    mBufferManager.setZero(*mFinishedSteps);

    dOutput.finishReasons->reshape(maxBatchSizeXmaxBeamWidth);
    mBufferManager.setZero(*dOutput.finishReasons);

    auto const vocabSize = modelConfig.getVocabSize();
    auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());

    if (mSpeculativeDecodingMode.isDraftTokensExternal())
    {
        dInput.externalDraftTokensInputs->draftProbs->reshape(ITensor::makeShape(
            {maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(vocabSizePadded)}));
        dInput.externalDraftTokensInputs->targetProbs->reshape(ITensor::makeShape(
            {maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(vocabSizePadded)}));
        dInput.externalDraftTokensInputs->draftLogits->reshape(
            ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep, static_cast<SizeType32>(vocabSizePadded)}));
        dInput.externalDraftTokensInputs->draftTokenIds->reshape(maxBatchSizeXmaxTokensPerStep);
        dInput.externalDraftTokensInputs->numDraftTokens->reshape(ITensor::makeShape({maxBatchSize}));
        dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(ITensor::makeShape({maxBatchSize}));
        dInput.externalDraftTokensInputs->useDraftLogits->reshape(ITensor::makeShape({maxBatchSize}));
        dInput.externalDraftTokensInputs->useDraftLogitsHost->reshape(ITensor::makeShape({maxBatchSize}));
    }

    dOutput.parentIds->reshape(jointOutputIdsShape);
    // use batchSize many entries instead of the usual 1
    dOutput.finishedSum->reshape(maxBatchSizeShape);
    mBufferManager.setZero(*dOutput.finishedSum);

    dOutput.newTokensSteps->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize, maxBeamWidth}));
    mBufferManager.setZero(*dOutput.newTokensSteps);

    dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidth);
    mBufferManager.setZero(*dOutput.cumLogProbs);

    dOutput.logProbs->reshape(jointOutputIdsShape);
    mBufferManager.setZero(*dOutput.logProbs);

    if (maxBeamWidth > 1)
    {
        dOutput.beamHypotheses.reshape(maxBatchSize, maxBeamWidth, mMaxSequenceLength);
    }

    dOutput.logProbsTiled->reshape(ITensor::makeShape({maxSequenceLength, maxBatchSize, maxBeamWidth}));
    mBufferManager.setZero(*dOutput.logProbsTiled);

    const_cast<ITensor&>(*dInput.embeddingBias)
        .reshape(ITensor::makeShape({maxBatchSize, static_cast<SizeType32>(vocabSizePadded)}));
    const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(ITensor::makeShape({maxBatchSize}));
    const_cast<ITensor&>(*dInput.badWordsLens).reshape(ITensor::makeShape({maxBatchSize}));
    const_cast<ITensor&>(*dInput.stopWordsPtrs).reshape(ITensor::makeShape({maxBatchSize}));
    const_cast<ITensor&>(*dInput.stopWordsLens).reshape(ITensor::makeShape({maxBatchSize}));

    std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModulePtr = nullptr;
    if (mSpeculativeDecodingMode.predictsDraftTokens())
    {
        speculativeDecodingModulePtr = modelConfig.getSpeculativeDecodingModulePtr();
        setupSpeculativeDecoding(modelConfig);
    }
    else
    {
        mMaxDecodingDecoderTokens = 1;
    }

    auto const device = mRuntimeStream->getDevice();
    mDecoderStream = std::make_shared<CudaStream>();
    TLLM_CHECK(mDecoderStream->getDevice() == device);

    mDecoder = IGptDecoder::create(mode, dtype, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
        mMaxSequenceLength, mDecoderStream, speculativeDecodingModulePtr);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

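// setupSpeculativeDecoding() gives the mode-specific buffers allocated in
// allocateSpeculativeDecodingBuffers() their actual shapes, based on the batch size and the
// draft-token limits reported by the model's SpeculativeDecodingModule.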
void GptDecoderBatched::setupSpeculativeDecoding(ModelConfig const& modelConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto& dInput = *mJointDecodingInput;
    auto& dOutput = *mJointDecodingOutput;

    auto const speculativeDecodingModule = modelConfig.getSpeculativeDecodingModulePtr();
    if (mSpeculativeDecodingMode.isMedusa())
    {
        auto& medusaPaths = const_cast<ITensor&>(*dInput.medusaInputs->medusaPaths);
        medusaPaths.reshape(ITensor::makeShape({mActualBatchSize, speculativeDecodingModule->getMaxDecodingTokens(),
            speculativeDecodingModule->getMaxPathLen()}));
        mBufferManager.setMem(medusaPaths, -1);

        auto& medusaTreeIds = const_cast<ITensor&>(*dInput.medusaInputs->medusaTreeIds);
        medusaTreeIds.reshape(
            ITensor::makeShape({mActualBatchSize, speculativeDecodingModule->getMaxDecodingDraftTokens()}));
        mBufferManager.setZero(medusaTreeIds);
        auto& curTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaCurTokensPerStep);
        auto& targetTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaTargetTokensPerStep);
        curTokensPerStep.reshape(ITensor::makeShape({mActualBatchSize}));
        targetTokensPerStep.reshape(ITensor::makeShape({mActualBatchSize}));
        mBufferManager.setZero(curTokensPerStep);
        mBufferManager.setZero(targetTokensPerStep);
    }

    if (mSpeculativeDecodingMode.predictsDraftTokens())
    {
        dOutput.speculativeDecodingOutputs->nextDraftTokens->reshape(
            ITensor::makeShape({mActualBatchSize, mMaxDecodingEngineTokens - 1}));
        if (mSpeculativeDecodingMode.variableDraftLength())
        {
            dOutput.speculativeDecodingOutputs->nextDraftTokensLen->reshape(ITensor::makeShape({mActualBatchSize}));
            dOutput.speculativeDecodingOutputs->prevDraftTokensLen->reshape(ITensor::makeShape({mActualBatchSize}));
        }
    }
    if (mSpeculativeDecodingMode.needsKVCacheRewind())
    {
        dOutput.speculativeDecodingOutputs->acceptedTokensLen->reshape(ITensor::makeShape({mActualBatchSize}));
        dOutput.speculativeDecodingOutputs->acceptedLengthsCumSum->reshape(ITensor::makeShape({mActualBatchSize + 1}));
        dOutput.speculativeDecodingOutputs->pathsOffsets->reshape(
            ITensor::makeShape({mActualBatchSize * speculativeDecodingModule->getMaxDraftPathLen()}));
    }

    mMaxDecodingDecoderTokens = mMaxDecodingEngineTokens;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

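// setExplicitDraftTokensInputs() and setEagleInputs() wire the per-iteration tensors produced by
// the engine (decoder_batch::Input) into mJointDecodingInput so the decoding kernels can consume
// them; prepareForward() calls them for the corresponding speculative decoding modes.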
void GptDecoderBatched::setExplicitDraftTokensInputs(decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto explicitDraftTokensInputs = DecodingInput::ExplicitDraftTokensInputs();
    TLLM_CHECK(input.explicitDraftTokensInputs.has_value());
    TLLM_CHECK(input.explicitDraftTokensLastInputs.has_value());

    explicitDraftTokensInputs.nextDraftTokens = input.explicitDraftTokensInputs->nextDraftTokens;
    explicitDraftTokensInputs.nextFlatTokens = input.explicitDraftTokensInputs->nextFlatTokens;
    explicitDraftTokensInputs.nextDraftIndices = input.explicitDraftTokensInputs->nextDraftIndices;
    explicitDraftTokensInputs.nextDraftProbs = input.explicitDraftTokensInputs->nextDraftProbs;
    explicitDraftTokensInputs.lastDraftTokens = input.explicitDraftTokensLastInputs->draftTokens;
    explicitDraftTokensInputs.lastDraftIndices = input.explicitDraftTokensLastInputs->draftIndices;
    explicitDraftTokensInputs.lastPositionIdsBase = input.explicitDraftTokensLastInputs->positionIdsBase;
    explicitDraftTokensInputs.masks = input.explicitDraftTokensInputs->masks;
    explicitDraftTokensInputs.packedPositionIds = input.explicitDraftTokensInputs->packedPositionIds;
    explicitDraftTokensInputs.bestPathLengths = input.explicitDraftTokensInputs->bestPathLengths;
    explicitDraftTokensInputs.bestPathIndices = input.explicitDraftTokensInputs->bestPathIndices;
    explicitDraftTokensInputs.nextGenerationLengths = input.explicitDraftTokensInputs->nextGenerationLengths;
    explicitDraftTokensInputs.lastGenerationLengths = input.explicitDraftTokensLastInputs->generationLengths;
    explicitDraftTokensInputs.maxGenLengthDevice = input.explicitDraftTokensInputs->maxGenToken;
    explicitDraftTokensInputs.seqSlots = input.batchSlotsRequestOrder;
    mJointDecodingInput->explicitDraftTokensInputs = explicitDraftTokensInputs;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setEagleInputs(decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(input.eagleInputs.has_value());
    TLLM_CHECK(input.eagleLastInputs.has_value());

    auto eagleInputs = DecodingInput::EagleInputs(input.eagleInputs->nextDraftTokens, input.eagleInputs->nextDraftLens,
        input.eagleInputs->nextDraftPaths, input.eagleLastInputs->draftTokens, input.eagleLastInputs->draftLens,
        input.eagleLastInputs->draftPaths, input.eagleInputs->acceptedTokens, input.eagleInputs->acceptedLens,
        input.eagleInputs->acceptedPaths, input.eagleInputs->chunkedContextNextTokens, input.batchSlotsRequestOrder);

    mJointDecodingInput->eagleInputs = eagleInputs;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace
{
template <typename T>
T maxOfActiveSlots(std::vector<T> const& values, std::vector<bool> const& active)
{
    return std::transform_reduce(
        values.begin(), values.end(), active.begin(), std::numeric_limits<T>::min(),
        [](auto lhs, auto rhs) { return std::max(lhs, rhs); },
        [](auto numTokens, auto active) { return active ? numTokens : std::numeric_limits<T>::min(); });
}
} // namespace

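// maxOfActiveSlots() returns the largest value among the slots marked active; e.g. for
// values = {4, 2, 7} and active = {true, true, false} it yields 4. forwardDispatch() uses it to
// find how many decoding steps the current batch needs and then runs the decoder in chunks of
// mMaxDecodingDecoderTokens steps.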
void GptDecoderBatched::forwardDispatch(
    decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType)
{
    auto eventStart = CudaEvent{};
    mRuntimeStream->record(eventStart);

    bool const async = forwardType == ForwardType::kASYNC;

    if (async)
    {
        mDecoderStream->wait(eventStart.get());
    }

    auto const maxDecodingEngineTokens = maxOfActiveSlots(mJointDecodingInput->numDecodingEngineTokens, input.active);

    for (SizeType32 si = 0; si < maxDecodingEngineTokens; si += mMaxDecodingDecoderTokens)
    {
        prepareForward(si, output, input);
        forwardDecoder(*mJointDecodingOutput, *mJointDecodingInput, forwardType);
    }

    if (async)
    {
        CudaEvent event{};
        mDecoderStream->record(event);
        mRuntimeStream->wait(event);
    }
}

GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
    decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    forwardDispatch(output, input, ForwardType::kASYNC);

    CudaEvent eventStop{};
    mRuntimeStream->record(eventStop);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return std::make_unique<decoder_batch::DecoderFinishedEvent>(std::move(eventStop), input.active);
}

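// prepareForward() assembles the DecodingInput/DecodingOutput views for one decoding step: it
// gathers the batch slots and per-step logits of the requests that are still active at this step
// and points newTokens/finishReasons at the matching step slices.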
void GptDecoderBatched::prepareForward(
    SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const& allTargetLogits = input.logits;

    auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
    auto const maxBeamWidth = jointOutputIdsShape.d[1];

    auto constexpr singleRequest = 1;

    TLLM_CHECK(static_cast<SizeType32>(output.sequenceLengths->getSize()) == mActualBatchSize * maxBeamWidth);
    // TODO should remove this reshape and set shape to [batch_size, beam_width] outside
    TensorPtr sequenceLengths
        = ITensor::view(output.sequenceLengths, ITensor::makeShape({mActualBatchSize, maxBeamWidth}));
    TLLM_CHECK(sequenceLengths);

    auto& dInput = *mJointDecodingInput;
    auto& dOutput = *mJointDecodingOutput;

    if (maxBeamWidth > 1)
    {
        dInput.cacheIndirection = input.cacheIndirection;
        dOutput.cacheIndirection = output.cacheIndirection;
    }

    if (mSpeculativeDecodingMode.isExplicitDraftTokens())
    {
        setExplicitDraftTokensInputs(input);
    }
    else if (mSpeculativeDecodingMode.isEagle())
    {
        setEagleInputs(input);
    }

    TensorPtr batchSlotsSlice = ITensor::at(input.batchSlots, {step});
    auto batchSlotsRange = BufferRange<SizeType32>(*batchSlotsSlice);
    SizeType32 localBatchDecoderIdx = 0;
    std::vector<TensorPtr> logitsVec;
    for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
    {
        if (!input.active.at(bi) || step >= mJointDecodingInput->numDecodingEngineTokens.at(bi))
        {
            continue;
        }
        batchSlotsRange[localBatchDecoderIdx] = bi;
        localBatchDecoderIdx++;

        auto const& targetLogits = allTargetLogits[bi];
        TensorPtr logitsSlice = ITensor::slice(targetLogits, step, singleRequest);
        logitsVec.push_back(logitsSlice);
    }
    batchSlotsSlice->resize(localBatchDecoderIdx);
    dInput.batchSlots = batchSlotsSlice;
    dInput.batchSize = localBatchDecoderIdx;
    dInput.logitsVec = logitsVec;

    auto const maxDecodingEngineTokens = maxOfActiveSlots(mJointDecodingInput->numDecodingEngineTokens, input.active);

    TensorPtr finishedStepsInput = ITensor::slice(mFinishedSteps, step, 1);
    TensorPtr finishedStepsOutput = ITensor::slice(mFinishedSteps, std::min(maxDecodingEngineTokens - 1, step + 1), 1);
    finishedStepsInput->squeeze(0);
    finishedStepsOutput->squeeze(0);
    TensorPtr newTokensStepView = ITensor::slice(dOutput.newTokensSteps, step, mMaxDecodingDecoderTokens);

    dInput.finishReasons = finishedStepsInput;

    if (mSpeculativeDecodingMode.isMedusa())
    {
        dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
    }

    if (mSpeculativeDecodingMode.isDraftTokensExternal())
    {
        dInput.externalDraftTokensInputs->step = step;

        // WAR: reset finished state for generation requests
        if (step == 0)
        {
            BufferManager manager{mDecoderStream};

            for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
            {
                if (!input.active.at(bi))
                {
                    continue;
                }
                TensorPtr finishedStepsView = ITensor::slice(mFinishedSteps, 0, 1);
                finishedStepsView->squeeze(0);
                auto batchSlot = bi;
                // reset only the finished state of this active slot
                TensorPtr finishedSteps = ITensor::slice(finishedStepsView, batchSlot, 1);
                manager.setZero(*finishedSteps);
            }
        }
    }

    dOutput.newTokens = newTokensStepView;
    dOutput.finishReasons = finishedStepsOutput;
    dOutput.lengths = sequenceLengths;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (input.batchSize > 0)
    {
        if (forwardType == ForwardType::kASYNC)
        {
            mDecoder->forwardAsync(output, input);
        }
        else if (forwardType == ForwardType::kSYNC)
        {
            mDecoder->forwardSync(output, input);
        }
        else
        {
            TLLM_THROW("Unknown ForwardType");
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto decoderFinishEvent = forwardAsync(output, input);
    decoderFinishEvent->event.synchronize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

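// finalize() runs gatherTree for one sequence slot to turn the (beam-search) decoder state into
// gathered output ids, and returns the CUDA event recorded after the gather so the caller can
// synchronize on completion.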
// TODO call this at the end of forward if mFinished[i] changes from false to true?
CudaEvent GptDecoderBatched::finalize(SizeType32 batchSlot, SamplingConfig const& samplingConfig, bool streaming) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto& stream = mRuntimeStream;
    auto manager = BufferManager{stream};
    auto& dJointInput = *mJointDecodingInput;
    auto& dJointOutput = *mJointDecodingOutput;

    auto slice = [batchSlot](auto& a, auto const& b)
    {
        if (b && b->getShape().d[0] > 0)
        {
            a = ITensor::slice(b, batchSlot, 1);
        }
    };

    // Prepare a slice of dJointInput and dJointOutput for gatherTree
    DecodingInput dInput{dJointInput};
    slice(dInput.endIds, dJointInput.endIds);
    slice(dInput.lengths, dJointInput.lengths);

    DecodingOutput dOutput{
        ITensor::slice(dJointOutput.ids, batchSlot, 1), ITensor::slice(dJointOutput.gatheredIds, batchSlot, 1)};
    dOutput.beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
    slice(dOutput.parentIds, dJointOutput.parentIds);
    slice(dOutput.cumLogProbs, dJointOutput.cumLogProbs);
    slice(dOutput.cacheIndirection, dJointOutput.cacheIndirection);
    slice(dOutput.lengths, dJointOutput.lengths);
    slice(dOutput.finishReasons, dJointOutput.finishReasons);
    slice(dOutput.logProbs, dJointOutput.logProbs);

    dOutput.newTokens = ITensor::view(dJointOutput.newTokens);
    TLLM_CHECK(dOutput.newTokens->getShape().d[0] == 1);
    dOutput.newTokens->squeeze(0);
    dOutput.newTokens = ITensor::slice(dOutput.newTokens, batchSlot, 1);
    dOutput.logProbsTiled = dJointOutput.logProbsTiled;
    if (streaming)
    {
        // in case of streaming we shouldn't overwrite the data in beamHypotheses, since the beam search kernels
        // expect ungathered data, but the kernels in gatherTree write in-place. Thus, we need to make a copy of the
        // beamHypotheses.
        tensorrt_llm::kernels::invokeCopyBeamHypotheses(
            dOutput.beamHypotheses, *mOutputBeamHypotheses, *dOutput.cumLogProbs, *mCumLogProbsTmp, *stream, mNumSMs);
        dOutput.beamHypotheses = *mOutputBeamHypotheses;
        dOutput.cumLogProbs = mCumLogProbsTmp;
    }

    kernels::gatherTree(dOutput, dInput, manager, samplingConfig);

    CudaEvent event{};
    stream->record(event);
    mRuntimeStream->wait(event);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return event;
}