/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "medusaDecodingLayer.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

template <typename T>
MedusaDecodingLayer<T>::MedusaDecodingLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();

    // Get sampling workspace size
    {
        auto samplingSizePrimarySampling = getTopKWorkspaceSize<T>(mDecoderDomain.getBatchSize(),
            mDecoderDomain.getMaxDecodingTokens(), TOP_K_MAX, mDecoderDomain.getVocabSizePadded());

        auto const maxBatchSizeHeadNums = mDecoderDomain.getBatchSize() * maxDraftPathLen;
        auto samplingSizeMedusaHeadsSampling
            = getTopKWorkspaceSize<T>(maxBatchSizeHeadNums, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());

        mWorkspaceSize = std::max(samplingSizePrimarySampling, samplingSizeMedusaHeadsSampling);
    }

    mDraftIdsPtrHost = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize()), maxDraftPathLen}),
        TRTDataType<TokenIdType*>::value);
    mCummulativeTopK.resize(mDecoderDomain.getBatchSize() * maxDraftPathLen);

    auto const batchSize = mDecoderDomain.getBatchSize();
    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});
    mCurandStatesDevice = mBufferManager->gpu(
        ITensor::makeShape({static_cast<SizeType32>(batchSize * sizeof(curandState_t))}), TRTDataType<int8_t>::value);
    mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mTargetTokensDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getMaxDecodingTokens()}), TRTDataType<TokenIdType>::value);
    mRandomSeedsDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize * maxDraftPathLen}), TRTDataType<uint64_t>::value);
    mMedusaSelectedLogitsPtrsDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, maxDraftPathLen}), TRTDataType<T*>::value);
    mCurandStatesMedusaLogitsDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, maxDraftPathLen, sizeof(curandState_t)}), TRTDataType<int8_t>::value);
    mRuntimeTopKPerRequestPerMedusaHeadDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, maxDraftPathLen}), TRTDataType<SizeType32>::value);
    mNewDraftTokensDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getMaxDecodingTokens()}), TRTDataType<TokenIdType>::value);
    mBestPathIdsDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mTiledBatchSlotsSetup = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        nvinfer1::DataType::kINT32);
    mTiledBatchSlotsForward = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        nvinfer1::DataType::kINT32);
    mMedusaInputLogitsPtrs = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        TRTDataType<T*>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<MedusaSetupParams>(baseSetupParams);

    workspace->initializeDeviceCurandStates(
        setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
    auto const batchSizeMaxNumHeads = batchSize * maxDraftPathLen;
    auto randomSeed = setupParams->randomSeed.value_or(std::vector<uint64_t>(batchSize, uint64_t{0}));
    std::vector<uint64_t> tiledRandomSeed(batchSizeMaxNumHeads);
    if (randomSeed.size() > 1)
    {
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                tiledRandomSeed[bi * maxDraftPathLen + hi] = randomSeed[bi];
            }
        }
    }
    auto* tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsSetup);
    BufferRange<SizeType32 const> batchSlotsRange(*batchSlots);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            tiledBatchSlots[bi * maxDraftPathLen + hi] = batchSlotsRange[bi] + hi;
        }
    }

    auto tiledRandomSeedOpt = std::make_optional(std::move(tiledRandomSeed));
    workspace->initializeDeviceCurandStates(
        tiledRandomSeedOpt, batchSizeMaxNumHeads, mTiledBatchSlotsSetup, mCurandStatesMedusaLogitsDevice);

    // Prepare runtime top K
    auto prepareRuntimeTopK = [this, workspace](std::vector<SizeType32> const& runtimeTopK, SizeType32 batchSize,
                                  BufferConstPtr const& batchSlots, BufferPtr const& runtimeTopKDevice)
    {
        TLLM_CHECK_WITH_INFO(runtimeTopK.size() == 1 || runtimeTopK.size() == static_cast<size_t>(batchSize),
            fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize));

        SizeType32* topKSetupPtr = nullptr;
        if (runtimeTopK.size() > 1)
        {
            DecodingLayerWorkspace::copyToWorkspace(
                *this->mBufferManager, runtimeTopK, workspace->getWorkspaceDeviceBuffer());
            topKSetupPtr = workspace->getWorkspaceDevicePtrAs<SizeType32>();
        }
        auto* runtimeTopKDevicePtr = bufferCastOrNull<SizeType32>(runtimeTopKDevice);
        auto const* batchSlotsPtr = bufferCastOrNull<SizeType32>(batchSlots);
        invokeScatterDecodingParams(
            topKSetupPtr, runtimeTopK.front(), runtimeTopKDevicePtr, batchSlotsPtr, batchSize, getStream());

        // FIXME(nkorobov): monotonically growing
        auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK));
        return curMaxTopK;
    };

    SizeType32 constexpr defaultTopK = 1;
    {
        auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector<SizeType32>{defaultTopK});
        auto const curMaxTopK
            = prepareRuntimeTopK(runtimeTopK, batchSize, workspace->getDeviceBatchSlots(), mRuntimeTopKDevice);
        mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
    }
    {
        auto runtimeHeadsTopK = setupParams->runtimeHeadsTopK;
        std::vector<SizeType32> runtimeHeadsTopKFlatten;
        if (runtimeHeadsTopK.has_value() && static_cast<bool>(runtimeHeadsTopK->size()))
        {
            for (auto const& sub : runtimeHeadsTopK.value())
            {
                runtimeHeadsTopKFlatten.insert(runtimeHeadsTopKFlatten.end(), sub.begin(), sub.end());
            }
        }
        else
        {
            runtimeHeadsTopKFlatten = std::vector<SizeType32>(batchSizeMaxNumHeads, defaultTopK);
        }

        BufferRange<SizeType32 const> batchSlotsRange(*batchSlots);
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlotsRange[bi];
            SizeType32 cummulativeTopK = 0;
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                mCummulativeTopK[slot * maxDraftPathLen + hi] = cummulativeTopK;
                cummulativeTopK += runtimeHeadsTopKFlatten[bi * maxDraftPathLen + hi];
            }
        }

        auto* tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsSetup);
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                tiledBatchSlots[bi * maxDraftPathLen + hi] = maxDraftPathLen * batchSlotsRange[bi] + hi;
            }
        }

        auto const curMaxTopK = prepareRuntimeTopK(runtimeHeadsTopKFlatten,
            static_cast<SizeType32>(batchSizeMaxNumHeads), mTiledBatchSlotsSetup,
            mRuntimeTopKPerRequestPerMedusaHeadDevice);
        mRuntimeMaxTopKPerRequestPerMedusaHead = std::max(mRuntimeMaxTopKPerRequestPerMedusaHead, curMaxTopK);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(MedusaDecodingLayer_forwardAsync);

    auto inputs = std::dynamic_pointer_cast<MedusaDecodingInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(baseOutputs);

    // TODO add typical acceptance similarly to EagleSampleAndAcceptDraftTokensPlugin::doTypicalAcceptance.

    samplePrimeHeadTokens(*outputs, *inputs, workspace);

    acceptDraftTokens(*outputs, *inputs, workspace);

    sampleNewDraftTokens(*outputs, *inputs, workspace);

    scatterNewDraftTokens(*outputs, *inputs);

    packAcceptedPaths(*outputs, *inputs, workspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t MedusaDecodingLayer<T>::getWorkspaceSize() const noexcept
{
    return std::max(mWorkspaceSize, mSetupWorkspaceSize);
}

template <typename T>
void MedusaDecodingLayer<T>::samplePrimeHeadTokens(SpeculativeDecodingOutputs const& outputs,
    MedusaDecodingInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();

    auto logits = bufferCast<T>(*inputs.logits.value());
    auto const* batchSlots = workspace->getDeviceBatchSlotsPtr();
    auto* sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);
    auto* tokensPerStepDevice = bufferCast<SizeType32>(*inputs.curTokensPerStep.value());

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    TopKSamplingKernelParams<T> params{};
    params.logProbs = logits;
    params.outputIds = bufferCastOrNull<TokenIdType>(mTargetTokensDevice);
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.maxTopK = mRuntimeMaxTopK;
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
    params.batchSlots = batchSlots;
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.tokensPerStep = tokensPerStepDevice;
    params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
    params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    // Sample multiple tokens per request and store them separately, to be accepted/rejected later.
    // Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
    // Finished state is not set.
    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::acceptDraftTokens(SpeculativeDecodingOutputs const& outputs,
    MedusaDecodingInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto const maxSeqLen = outputs.outputIds->getDimension<-1>();

    auto* outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
    auto const* endIds = bufferCast<TokenIdType>(*inputs.endIds);
    auto const* paths = bufferCast<SizeType32>(*inputs.paths);

    auto const* batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);
    auto* sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);
    auto* numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto* curTokensPerStepDevice = bufferCast<SizeType32>(*inputs.curTokensPerStep.value());
    auto const* targetTokensPerStepDevice = bufferCast<SizeType32>(*inputs.targetTokensPerStep);

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();

    auto medusaInputLogitsPtrs = BufferRange<T*>(*mMedusaInputLogitsPtrs);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        auto const slot = batchSlots[bi];
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            medusaInputLogitsPtrs[slot * maxDraftPathLen + hi] = bufferCast<T>(*inputs.medusaLogits[slot][hi]);
        }
    }

    auto* draftIds = bufferCast<TokenIdType>(*outputs.nextDraftTokens);

    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        targetTokensPerStepDevice != nullptr, "Target tokens per step must be provided for MedusaDecoding");

    // Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
    // Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
    // Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits
    auto* targetTokensDevicePtr = bufferCast<TokenIdType>(*mTargetTokensDevice);
    auto* finishedStatesPtr
        = reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs.finished));
    auto* bestPathIdsDevicePtr = bufferCastOrNull<SizeType32>(mBestPathIdsDevice);
    auto medusaInputLogitsPtrsPtr = reinterpret_cast<T const**>(bufferCast<T*>(*mMedusaInputLogitsPtrs));
    auto medusaSelectedLogitsPtrsDevicePtr
        = const_cast<T const**>(bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice));

    AcceptDraftTokensByIdsWithPathsParams<T> params;
    params.outputIds = outputIds;
    params.draftIds = draftIds;
    params.targetIds = targetTokensDevicePtr;
    params.sequenceLengths = sequenceLengths;
    params.acceptedLengths = numNewTokens;
    params.finishedFinal = finishedStatesPtr;
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.paths = paths;
    params.endIds = endIds;
    params.medusaLogits = medusaInputLogitsPtrsPtr;
    params.logitsPtrs = medusaSelectedLogitsPtrsDevicePtr;
    params.curTokensPerStep = curTokensPerStepDevice;
    params.targetTokensPerStep = targetTokensPerStepDevice;
    params.bestPathIds = bestPathIdsDevicePtr;
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.vocabSize = mDecoderDomain.getVocabSize();
    params.maxSeqLen = maxSeqLen;
    params.maxDraftPathLen = maxDraftPathLen;
    params.maxDecodingTokens = mDecoderDomain.getMaxDecodingTokens();
    params.stream = getStream();

    params.checkParams();

    acceptDraftTokensByIdsWithPaths(params);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::sampleNewDraftTokens(SpeculativeDecodingOutputs const& outputs,
    MedusaDecodingInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto const* batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);
    auto* sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
    // For each request we sample Head Num times for topK[hi] tokens
    auto const batchSizeHeadNums = batchSize * maxDraftPathLen;
    auto const maxBatchSizeHeadNums = mDecoderDomain.getBatchSize() * maxDraftPathLen;

    auto* tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsForward);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            tiledBatchSlots[bi * maxDraftPathLen + hi] = maxDraftPathLen * batchSlots[bi] + hi;
        }
    }

    auto* draftIdsPtrs = reinterpret_cast<TokenIdType**>(bufferCast<TokenIdType*>(*mDraftIdsPtrHost));
    auto* newDraftTokensDeviceRange = bufferCast<TokenIdType>(*mNewDraftTokensDevice);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        auto slot = batchSlots[bi];
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            draftIdsPtrs[slot * maxDraftPathLen + hi] = newDraftTokensDeviceRange
                + slot * mDecoderDomain.getMaxDecodingTokens() + mCummulativeTopK[slot * maxDraftPathLen + hi];
        }
    }

    TopKSamplingKernelParams<T> params{};
    params.logProbsPtrs = bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice);
    params.outputIdsPtrs = draftIdsPtrs;
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.maxTopK = mRuntimeMaxTopKPerRequestPerMedusaHead;
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKPerRequestPerMedusaHeadDevice);
    params.batchSlots = tiledBatchSlots;
    params.curandState
        = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesMedusaLogitsDevice));
    params.batchSize = batchSizeHeadNums;
    params.maxBatchSize = maxBatchSizeHeadNums;
    params.maxTokensPerStep = 1;
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    params.returnAllSelectedTokens = true;

    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto const* batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");

    auto* draftIds = bufferCastOrNull<TokenIdType>(outputs.nextDraftTokens);
    auto* tokensPerStepDevice = bufferCastOrNull<SizeType32>(inputs.curTokensPerStep);
    auto const* treeIds = bufferCastOrNull<SizeType32>(inputs.treeIds);
    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(treeIds != nullptr, "Tree ids must be provided for MedusaDecoding");

    auto* newDraftTokensDevice = bufferCastOrNull<TokenIdType>(mNewDraftTokensDevice);

    scatterMedusaDraftTokens(draftIds, newDraftTokensDevice, treeIds, tokensPerStepDevice, batchSlots,
        mDecoderDomain.getMaxDecodingTokens(), batchSize, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(SpeculativeDecodingOutputs const& outputs,
    MedusaDecodingInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();

    auto const* paths = bufferCast<SizeType32>(*inputs.paths);
    auto const* batchSlots = workspace->getDeviceBatchSlotsPtr();
    auto* numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto* numNewTokensCumSum = bufferCast<SizeType32>(*outputs.numNewTokensCumSum);
    auto* pathsOffsets = bufferCast<SizeType32>(*outputs.pathsOffsets);
    auto* bestPathIdsDevicePtr = bufferCastOrNull<SizeType32>(mBestPathIdsDevice);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for MedusaDecoding");
    invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, bestPathIdsDevicePtr, paths, batchSlots,
        batchSize, batchSize, mDecoderDomain.getMaxDecodingTokens(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;

} // namespace tensorrt_llm::layers