/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "medusaDecodingLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

template <typename T>
MedusaDecodingLayer<T>::MedusaDecodingLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();

    // Get sampling workspace size
    {
        auto samplingSizePrimarySampling = getTopKWorkspaceSize<T>(mDecoderDomain.getBatchSize(),
            mDecoderDomain.getMaxDecodingTokens(), TOP_K_MAX, mDecoderDomain.getVocabSizePadded());

        auto const maxBatchSizeHeadNums = mDecoderDomain.getBatchSize() * maxDraftPathLen;
        auto samplingSizeMedusaHeadsSampling
            = getTopKWorkspaceSize<T>(maxBatchSizeHeadNums, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());

        mWorkspaceSize = std::max(samplingSizePrimarySampling, samplingSizeMedusaHeadsSampling);
    }

    mDraftIdsPtrHost = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize()), maxDraftPathLen}),
        TRTDataType<TokenIdType*>::value);
    mCummulativeTopK.resize(mDecoderDomain.getBatchSize() * maxDraftPathLen);

    auto const batchSize = mDecoderDomain.getBatchSize();
    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});
    mCurandStatesDevice = mBufferManager->gpu(
        ITensor::makeShape({static_cast<SizeType32>(batchSize * sizeof(curandState_t))}), TRTDataType<int8_t>::value);
    mSetupWorkspaceDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize * maxDraftPathLen}), TRTDataType<SizeType32>::value);
    mSamplingWorkspaceDevice = mBufferManager->gpu(mWorkspaceSize, TRTDataType<int8_t>::value);
    mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mTargetTokensDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getMaxDecodingTokens()}), TRTDataType<TokenIdType>::value);
    mRandomSeedsDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, maxDraftPathLen}), TRTDataType<uint64_t>::value);
    mMedusaSelectedLogitsPtrsDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, maxDraftPathLen}), TRTDataType<T*>::value);
    mCurandStatesMedusaLogitsDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, maxDraftPathLen, sizeof(curandState_t)}), TRTDataType<int8_t>::value);
    mRuntimeTopKPerRequestPerMedusaHeadDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, maxDraftPathLen}), TRTDataType<SizeType32>::value);
    mNewDraftTokensDevice = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getMaxDecodingTokens()}), TRTDataType<TokenIdType>::value);
    mBestPathIdsDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mTiledBatchSlotsSetup = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        nvinfer1::DataType::kINT32);
    mTiledBatchSlotsForward = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        nvinfer1::DataType::kINT32);
    mMedusaInputLogitsPtrs = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType32>(mDecoderDomain.getBatchSize() * maxDraftPathLen)}),
        TRTDataType<T*>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
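
// Initializes per-request decoding state for Medusa: curand states for the primary head and for every
// Medusa head (seeded from setupParams->randomSeed or the default seed), and the runtime top-K values
// scattered to their batch slots on the device.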
template <typename T>
void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, BufferConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<MedusaSetupParams>(baseSetupParams);

    // Prepare random seed
    auto initCurandStates = [this](std::optional<std::vector<uint64_t>>& randomSeed, SizeType32 batchSize,
                                BufferConstPtr batchSlots, TensorPtr statesDevice)
    {
        auto batchSlotsPtr = bufferCastOrNull<SizeType32>(batchSlots);
        auto curandStatesDevicePtr = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*statesDevice));
        if (randomSeed)
        {
            if (randomSeed->size() == 1)
            {
                invokeCurandInitialize(
                    curandStatesDevicePtr, batchSlotsPtr, batchSize, randomSeed->front(), this->getStream());
                sync_check_cuda_error();
            }
            else
            {
                TLLM_CHECK_WITH_INFO(randomSeed->size() == batchSize, "Random seed vector size mismatch.");
                this->mBufferManager->copy(randomSeed->data(), *this->mRandomSeedsDevice, runtime::MemoryType::kCPU);
                auto randomSeedsDevicePtr = bufferCastOrNull<uint64_t>(this->mRandomSeedsDevice);
                invokeCurandBatchInitialize(
                    curandStatesDevicePtr, batchSlotsPtr, batchSize, randomSeedsDevicePtr, this->getStream());
                sync_check_cuda_error();
            }
        }
        else
        {
            // Initialize curand states using the default seed 0.
            invokeCurandInitialize(
                curandStatesDevicePtr, batchSlotsPtr, batchSize, DefaultDecodingParams::getSeed(), this->getStream());
        }
    };

    initCurandStates(setupParams->randomSeed, batchSize, batchSlots, mCurandStatesDevice);
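
    // Each Medusa head draws from its own curand state. Tile the per-request seeds and batch slots
    // across maxDraftPathLen heads before initializing mCurandStatesMedusaLogitsDevice below.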
    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
    auto const batchSizeMaxNumHeads = batchSize * maxDraftPathLen;
    auto randomSeed = setupParams->randomSeed.value_or(std::vector<uint64_t>(batchSize, uint64_t{0}));
    std::vector<uint64_t> tiledRandomSeed(batchSizeMaxNumHeads);
    if (randomSeed.size() > 1)
    {
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                tiledRandomSeed[bi * maxDraftPathLen + hi] = randomSeed[bi];
            }
        }
    }
    auto tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsSetup);
    BufferRange<SizeType32 const> batchSlotsRange(*batchSlots);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            tiledBatchSlots[bi * maxDraftPathLen + hi] = batchSlotsRange[bi] + hi;
        }
    }
    auto tiledRandomSeedOpt = std::make_optional(std::move(tiledRandomSeed));
    initCurandStates(tiledRandomSeedOpt, batchSizeMaxNumHeads, mTiledBatchSlotsSetup, mCurandStatesMedusaLogitsDevice);

    // Prepare runtime top K
    auto prepareRuntimeTopK = [this](std::vector<SizeType32> const& runtimeTopK, SizeType32 batchSize,
                                  BufferConstPtr batchSlots, BufferPtr runtimeTopKDevice)
    {
        TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
            fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize));

        this->mBufferManager->copy(runtimeTopK.data(), *this->mSetupWorkspaceDevice, runtime::MemoryType::kCPU);
        auto setupWorkspaceDevicePtr = bufferCastOrNull<SizeType32>(this->mSetupWorkspaceDevice);
        auto runtimeTopKDevicePtr = bufferCastOrNull<SizeType32>(runtimeTopKDevice);
        auto batchSlotsPtr = bufferCastOrNull<SizeType32>(batchSlots);
        invokeScatterDecodingParams(
            setupWorkspaceDevicePtr, runtimeTopKDevicePtr, batchSlotsPtr, batchSize, getStream());

        // FIXME(nkorobov): monotonically growing
        auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK));
        return curMaxTopK;
    };

    auto constexpr defaultTopK = 1u;
    {
        auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector<SizeType32>(batchSize, defaultTopK));
        auto const curMaxTopK = prepareRuntimeTopK(runtimeTopK, batchSize, batchSlots, mRuntimeTopKDevice);
        mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
    }
    {
        auto runtimeHeadsTopK = setupParams->runtimeHeadsTopK;
        std::vector<SizeType32> runtimeHeadsTopKFlatten;
        if (runtimeHeadsTopK.has_value() && runtimeHeadsTopK->size())
        {
            for (auto const& sub : runtimeHeadsTopK.value())
            {
                runtimeHeadsTopKFlatten.insert(runtimeHeadsTopKFlatten.end(), sub.begin(), sub.end());
            }
        }
        else
        {
            runtimeHeadsTopKFlatten = std::vector<SizeType32>(batchSizeMaxNumHeads, defaultTopK);
        }

        BufferRange<SizeType32 const> batchSlotsRange(*batchSlots);
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlotsRange[bi];
            SizeType32 cummulativeTopK = 0;
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                mCummulativeTopK[slot * maxDraftPathLen + hi] = cummulativeTopK;
                cummulativeTopK += runtimeHeadsTopKFlatten[bi * maxDraftPathLen + hi];
            }
        }

        auto tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsSetup);
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
            {
                tiledBatchSlots[bi * maxDraftPathLen + hi] = maxDraftPathLen * batchSlotsRange[bi] + hi;
            }
        }

        auto const curMaxTopK = prepareRuntimeTopK(runtimeHeadsTopKFlatten,
            static_cast<SizeType32>(batchSizeMaxNumHeads), mTiledBatchSlotsSetup,
            mRuntimeTopKPerRequestPerMedusaHeadDevice);
        mRuntimeMaxTopKPerRequestPerMedusaHead = std::max(mRuntimeMaxTopKPerRequestPerMedusaHead, curMaxTopK);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
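
// One Medusa decoding step: sample target tokens from the primary head logits, accept the longest matching
// draft path, sample new draft tokens from the Medusa head logits selected along that path, scatter them
// into the draft token tree, and pack the accepted path offsets for the runtime.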
template <typename T>
void MedusaDecodingLayer<T>::forwardAsync(
    std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<MedusaDecodingInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(baseOutputs);

    samplePrimeHeadTokens(*outputs, *inputs);

    acceptDraftTokens(*outputs, *inputs);

    sampleNewDraftTokens(*outputs, *inputs);

    scatterNewDraftTokens(*outputs, *inputs);

    packAcceptedPaths(*outputs, *inputs);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t MedusaDecodingLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceSize;
}

template <typename T>
void MedusaDecodingLayer<T>::samplePrimeHeadTokens(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();

    auto logits = bufferCast<T>(*inputs.logits.value());
    auto batchSlots = bufferCastOrNull<SizeType32>(inputs.batchSlots);
    auto sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);
    auto tokensPerStepDevice = bufferCast<SizeType32>(*inputs.curTokensPerStep.value());

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    TopKSamplingKernelParams<T> params;
    params.logProbs = logits;
    params.outputIds = bufferCastOrNull<TokenIdType>(mTargetTokensDevice);
    params.workspace = mSamplingWorkspaceDevice->data();
    params.maxTopK = mRuntimeMaxTopK;
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
    params.batchSlots = batchSlots;
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.tokensPerStep = tokensPerStepDevice;
    params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
    params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();

    // Sample multiple tokens per request and store them in a separate buffer to be accepted/rejected later.
    // Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
    // Finished state is not set.
    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
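
// Verifies the draft tokens generated at the previous step against the target tokens sampled in
// samplePrimeHeadTokens and records, per request, the accepted length and the best path id.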
template <typename T>
void MedusaDecodingLayer<T>::acceptDraftTokens(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto const maxSeqLen = outputs.outputIds->getDimension<-1>();

    auto outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
    auto endIds = bufferCast<TokenIdType>(*inputs.endIds);
    auto paths = bufferCast<SizeType32>(*inputs.paths);

    auto batchSlots = bufferCastOrNull<SizeType32>(inputs.batchSlots);
    auto sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);
    auto numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto curTokensPerStepDevice = bufferCast<SizeType32>(*inputs.curTokensPerStep.value());
    auto targetTokensPerStepDevice = bufferCast<SizeType32>(*inputs.targetTokensPerStep);

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();

    auto medusaInputLogitsPtrs = BufferRange<T const*>(*mMedusaInputLogitsPtrs);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        auto const slot = batchSlots[bi];
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            medusaInputLogitsPtrs[slot * maxDraftPathLen + hi] = bufferCast<T>(*inputs.medusaLogits[slot][hi]);
        }
    }

    auto draftIds = bufferCast<TokenIdType>(*outputs.nextDraftTokens);

    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        targetTokensPerStepDevice != nullptr, "Target tokens per step must be provided for MedusaDecoding");

    // Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
    // Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
    // Fill mMedusaSelectedLogitsPtrsDevice with the respective Medusa logits.
    auto targetTokensDevicePtr = bufferCast<TokenIdType>(*mTargetTokensDevice);
    auto finishedStatesPtr
        = reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs.finished));
    auto bestPathIdsDevicePtr = bufferCastOrNull<SizeType32>(mBestPathIdsDevice);
    auto medusaInputLogitsPtrsPtr = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
    auto medusaSelectedLogitsPtrsDevicePtr
        = const_cast<T const**>(bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice));
    acceptDraftTokensByIdsWithPaths(outputIds, draftIds, targetTokensDevicePtr, sequenceLengths, numNewTokens,
        finishedStatesPtr, batchSlots, paths, endIds, medusaInputLogitsPtrsPtr, medusaSelectedLogitsPtrsDevicePtr,
        curTokensPerStepDevice, targetTokensPerStepDevice, bestPathIdsDevicePtr, batchSize,
        mDecoderDomain.getVocabSize(), mDecoderDomain.getBatchSize(), maxSeqLen, maxDraftPathLen,
        mDecoderDomain.getMaxDecodingTokens(), getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
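
// Runs one top-K sampling per (request, Medusa head) pair on the logits selected by acceptDraftTokens.
// Each head writes its topK[hi] tokens at its cumulative top-K offset within mNewDraftTokensDevice.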
template <typename T>
void MedusaDecodingLayer<T>::sampleNewDraftTokens(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto batchSlots = bufferCastOrNull<SizeType32>(inputs.batchSlots);
    auto sequenceLengths = bufferCastOrNull<SizeType32>(outputs.sequenceLength);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
    // For each request we sample maxDraftPathLen times (once per Medusa head) for topK[hi] tokens.
    auto const batchSizeHeadNums = batchSize * maxDraftPathLen;
    auto const maxBatchSizeHeadNums = mDecoderDomain.getBatchSize() * maxDraftPathLen;

    auto tiledBatchSlots = bufferCast<SizeType32>(*mTiledBatchSlotsForward);
    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            tiledBatchSlots[bi * maxDraftPathLen + hi] = maxDraftPathLen * batchSlots[bi] + hi;
        }
    }

    auto draftIdsPtrs = reinterpret_cast<TokenIdType**>(bufferCast<int64_t>(*mDraftIdsPtrHost));
    auto newDraftTokensDeviceRange = bufferCast<TokenIdType>(*mNewDraftTokensDevice);

    for (SizeType32 bi = 0; bi < batchSize; ++bi)
    {
        auto slot = batchSlots[bi];
        for (SizeType32 hi = 0; hi < maxDraftPathLen; ++hi)
        {
            draftIdsPtrs[slot * maxDraftPathLen + hi] = newDraftTokensDeviceRange
                + slot * mDecoderDomain.getMaxDecodingTokens() + mCummulativeTopK[slot * maxDraftPathLen + hi];
        }
    }

    TopKSamplingKernelParams<T> params;
    params.logProbsPtrs = bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice);
    params.outputIdsPtrs = draftIdsPtrs;
    params.workspace = mSamplingWorkspaceDevice->data();
    params.maxTopK = mRuntimeMaxTopKPerRequestPerMedusaHead;
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKPerRequestPerMedusaHeadDevice);
    params.batchSlots = tiledBatchSlots;
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesMedusaLogitsDevice));
    params.batchSize = batchSizeHeadNums;
    params.maxBatchSize = maxBatchSizeHeadNums;
    params.maxTokensPerStep = 1;
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    params.returnAllTopK = true;

    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
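
// Copies the freshly sampled draft tokens from mNewDraftTokensDevice into outputs.nextDraftTokens,
// rearranging them according to the tree ids of the draft token tree.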
template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();
    auto batchSlots = bufferCastOrNull<SizeType32>(inputs.batchSlots);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");

    auto draftIds = bufferCastOrNull<TokenIdType>(outputs.nextDraftTokens);
    auto tokensPerStepDevice = bufferCastOrNull<SizeType32>(inputs.curTokensPerStep);
    auto treeIds = bufferCastOrNull<SizeType32>(inputs.treeIds);
    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(treeIds != nullptr, "Tree ids must be provided for MedusaDecoding");

    auto newDraftTokensDevice = bufferCastOrNull<TokenIdType>(mNewDraftTokensDevice);

    scatterMedusaDraftTokens(draftIds, newDraftTokensDevice, treeIds, tokensPerStepDevice, batchSlots,
        mDecoderDomain.getMaxDecodingTokens(), batchSize, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(
    SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.value()->getDimension<0>();

    auto paths = bufferCast<SizeType32>(*inputs.paths);
    auto batchSlots = bufferCastOrNull<SizeType32>(inputs.batchSlots);
    auto numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto numNewTokensCumSum = bufferCast<SizeType32>(*outputs.numNewTokensCumSum);
    auto pathsOffsets = bufferCast<SizeType32>(*outputs.pathsOffsets);
    auto bestPathIdsDevicePtr = bufferCastOrNull<SizeType32>(mBestPathIdsDevice);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for MedusaDecoding");
    invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, bestPathIdsDevicePtr, paths, batchSlots,
        batchSize, mDecoderDomain.getMaxDecodingTokens(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;

} // namespace tensorrt_llm::layers