/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
{

template <typename T>
MedusaDecodingLayer<T>::MedusaDecodingLayer(SizeType maxBatchSize, SizeType vocabSize, SizeType vocabSizePadded,
    SizeType maxTokensPerStep, SizeType maxNumHeads, cudaStream_t stream, std::shared_ptr<IAllocator> allocator)
    : BaseLayer(stream, std::move(allocator), nullptr)
    , mMaxBatchSize(maxBatchSize)
    , mVocabSize(vocabSize)
    , mVocabSizePadded(vocabSizePadded)
    , mMaxTokensPerStep(maxTokensPerStep)
    , mMaxNumHeads(maxNumHeads)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
MedusaDecodingLayer<T>::~MedusaDecodingLayer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    freeBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
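// Buffer layout summary for allocateBuffer() below: a single sampling workspace is sized for
// the larger of the two Top-K passes (primary sampling over up to mMaxTokensPerStep tokens per
// request, and per-head sampling over mMaxBatchSize * mMaxNumHeads entries); the remaining
// device buffers hold curand states, random seeds, scattered runtime Top-K values, sampled
// target tokens, selected Medusa logits pointers, new draft tokens, and best path indices.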
template <typename T>
void MedusaDecodingLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Get sampling workspace size
    {
        auto samplingSizePrimarySampling
            = getTopKWorkspaceSize<T>(mMaxBatchSize, mMaxTokensPerStep, TOP_K_MAX, mVocabSizePadded);

        auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads;
        auto samplingSizeMedusaHeadsSampling
            = getTopKWorkspaceSize<T>(maxBatchSizeHeadNums, 1, TOP_K_MAX, mVocabSizePadded);

        mSamplingWorkspaceSize = std::max(samplingSizePrimarySampling, samplingSizeMedusaHeadsSampling);
    }

    mDraftIdsPtrHost = runtime::BufferManager::pinned(
        ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize), mMaxNumHeads}),
        runtime::TRTDataType<TokenIdType*>::value);
    mCummulativeTopK.resize(mMaxBatchSize * mMaxNumHeads);

    std::array<size_t, 11> deviceBufferSizes;
    deviceBufferSizes[0] = mMaxBatchSize * sizeof(curandState_t);
    deviceBufferSizes[1] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
    deviceBufferSizes[2] = mSamplingWorkspaceSize;
    deviceBufferSizes[3] = mMaxBatchSize * sizeof(SizeType);
    deviceBufferSizes[4] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
    deviceBufferSizes[5] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t);
    deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(T*);
    deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t);
    deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
    deviceBufferSizes[9] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
    deviceBufferSizes[10] = mMaxBatchSize * sizeof(SizeType);

    mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
    mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[1], false);
    mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[2], false);
    mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[3], false);
    mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[4], false);
    mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[5], false);
    mMedusaSelectedLogitsPtrsDevice
        = mAllocator->reMalloc(mMedusaSelectedLogitsPtrsDevice, deviceBufferSizes[6], false);
    mCurandStatesMedusaLogitsDevice
        = mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[7], false);
    mRuntimeTopKPerRequestPerMedusaHeadDevice
        = mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[8], false);
    mNewDraftTokensDevice = mAllocator->reMalloc(mNewDraftTokensDevice, deviceBufferSizes[9], false);
    mBestPathIdsDevice = mAllocator->reMalloc(mBestPathIdsDevice, deviceBufferSizes[10], false);

    mTiledBatchSlotsSetup = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
    mTiledBatchSlotsForward = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
    mMedusaInputLogitsPtrs = BufferManager::pinnedPool(
        ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), TRTDataType<T*>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::freeBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mAllocator->free((void**) (&mCurandStatesDevice));
    mAllocator->free((void**) (&mSetupWorkspaceDevice));
    mAllocator->free((void**) (&mSamplingWorkspaceDevice));
    mAllocator->free((void**) (&mRuntimeTopKDevice));
    mAllocator->free((void**) (&mTargetTokensDevice));
    mAllocator->free((void**) (&mRandomSeedsDevice));
    mAllocator->free((void**) (&mMedusaSelectedLogitsPtrsDevice));
    mAllocator->free((void**) (&mCurandStatesMedusaLogitsDevice));
    mAllocator->free((void**) (&mRuntimeTopKPerRequestPerMedusaHeadDevice));
    mAllocator->free((void**) (&mNewDraftTokensDevice));
    mAllocator->free((void**) (&mBestPathIdsDevice));

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
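// setup() is called on the host before decoding the given requests: it initializes curand
// states for the primary sampling pass and for the per-head draft sampling pass (tiling batch
// slots so that every (request, head) pair owns its own state), scatters the runtime Top-K
// values per request and per Medusa head into device buffers, and records the cumulative
// per-head Top-K offsets used later to place each head's new draft tokens.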
template <typename T>
void MedusaDecodingLayer<T>::setup(SizeType batchSize, SizeType const* batchSlots, MedusaSetupParams const& setupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Prepare random seed
    auto initCurandStates = [this](std::optional<std::vector<uint64_t>> const& randomSeed, SizeType batchSize,
                                SizeType const* batchSlots, curandState_t* statesDevice)
    {
        if (randomSeed)
        {
            if (randomSeed->size() == 1)
            {
                invokeCurandInitialize(statesDevice, batchSlots, batchSize, randomSeed->front(), this->mStream);
                sync_check_cuda_error();
            }
            else
            {
                TLLM_CHECK_WITH_INFO(randomSeed->size() == batchSize, "Random seed vector size mismatch.");
                cudaAutoCpy(this->mRandomSeedsDevice, randomSeed->data(), batchSize, this->mStream);
                invokeCurandBatchInitialize(
                    statesDevice, batchSlots, batchSize, this->mRandomSeedsDevice, this->mStream);
                sync_check_cuda_error();
            }
        }
        else
        {
            // Initialize curand states using the default seed 0.
            invokeCurandInitialize(statesDevice, batchSlots, batchSize, 0, this->mStream);
        }
    };

    initCurandStates(setupParams.randomSeed, batchSize, batchSlots, mCurandStatesDevice);

    auto batchSizeMaxNumHeads = batchSize * mMaxNumHeads;
    auto randomSeed = setupParams.randomSeed.value_or(std::vector<uint64_t>(batchSize, uint64_t{0}));
    std::vector<uint64_t> tiledRandomSeed(batchSizeMaxNumHeads);
    if (randomSeed.size() > 1)
    {
        for (SizeType bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
            {
                tiledRandomSeed[bi * mMaxNumHeads + hi] = randomSeed[bi];
            }
        }
    }
    auto tiledBatchSlots = bufferCast<SizeType>(*mTiledBatchSlotsSetup);
    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
        {
            tiledBatchSlots[bi * mMaxNumHeads + hi] = batchSlots[bi] + hi;
        }
    }
    initCurandStates({tiledRandomSeed}, batchSizeMaxNumHeads, tiledBatchSlots, mCurandStatesMedusaLogitsDevice);

    // Prepare runtime top K
    auto prepareRuntimeTopK = [this](std::vector<SizeType> const& runtimeTopK, SizeType batchSize,
                                  SizeType const* batchSlots, SizeType* runtimeTopKDevice)
    {
        TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
            fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize));

        cudaAutoCpy(
            reinterpret_cast<SizeType*>(this->mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, this->mStream);
        invokeScatterDecodingParams(reinterpret_cast<SizeType*>(this->mSetupWorkspaceDevice), runtimeTopKDevice,
            batchSlots, batchSize, this->mStream);

        // FIXME(nkorobov): monotonically growing
        auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK));
        return curMaxTopK;
    };

    auto constexpr defaultTopK = 1u;
    {
        auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector<SizeType>(batchSize, defaultTopK));
        auto const curMaxTopK = prepareRuntimeTopK(runtimeTopK, batchSize, batchSlots, mRuntimeTopKDevice);
        mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
    }
    {
        auto runtimeHeadsTopK = setupParams.runtimeHeadsTopK;
        std::vector<SizeType> runtimeHeadsTopKFlatten;
        if (runtimeHeadsTopK.has_value())
        {
            for (auto const& sub : runtimeHeadsTopK.value())
            {
                runtimeHeadsTopKFlatten.insert(runtimeHeadsTopKFlatten.end(), sub.begin(), sub.end());
            }
        }
        else
        {
            runtimeHeadsTopKFlatten = std::vector<SizeType>(batchSizeMaxNumHeads, defaultTopK);
        }

        for (SizeType bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlots[bi];
            SizeType cummulativeTopK = 0;
            for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
            {
                mCummulativeTopK[slot * mMaxNumHeads + hi] = cummulativeTopK;
                cummulativeTopK += runtimeHeadsTopKFlatten[bi * mMaxNumHeads + hi];
            }
        }

        auto tiledBatchSlots = bufferCast<SizeType>(*mTiledBatchSlotsSetup);
        for (SizeType bi = 0; bi < batchSize; ++bi)
        {
            for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
            {
                tiledBatchSlots[bi * mMaxNumHeads + hi] = mMaxNumHeads * batchSlots[bi] + hi;
            }
        }

        auto const curMaxTopK = prepareRuntimeTopK(runtimeHeadsTopKFlatten,
            static_cast<SizeType>(batchSizeMaxNumHeads), tiledBatchSlots, mRuntimeTopKPerRequestPerMedusaHeadDevice);
        mRuntimeMaxTopKPerRequestPerMedusaHead = std::max(mRuntimeMaxTopKPerRequestPerMedusaHead, curMaxTopK);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
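// forward() runs one Medusa decoding step for the whole batch. The five stages below execute
// in order on mStream:
//   1. samplePrimeHeadTokens:  Top-K sample target tokens from the primary head logits for up
//                              to mMaxTokensPerStep positions per request.
//   2. acceptDraftTokens:      compare the previous draft tokens with the sampled target
//                              tokens along the tree paths, accept the longest matching path
//                              and select the Medusa logits of the accepted position.
//   3. sampleNewDraftTokens:   Top-K sample every Medusa head to produce new draft tokens.
//   4. scatterNewDraftTokens:  scatter those tokens into the tree layout given by treeIds.
//   5. packAcceptedPaths:      pack accepted lengths and path offsets for the runtime.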
template <typename T>
void MedusaDecodingLayer<T>::forward(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    samplePrimeHeadTokens(outputs, inputs);

    acceptDraftTokens(outputs, inputs);

    sampleNewDraftTokens(outputs, inputs);

    scatterNewDraftTokens(outputs, inputs);

    packAcceptedPaths(outputs, inputs);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::samplePrimeHeadTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];

    auto logits = inputs.logits.template getPtr<T>();
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
    auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;
    auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType const>();

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    // Sample multiple tokens per request and store them separately to be accepted/rejected later.
    // Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
    // Finished state is not set.
    invokeBatchTopKSampling(mSamplingWorkspaceDevice, logits, /* logProbsPtrs */ static_cast<T const* const*>(nullptr),
        /* outputIdsPtrs */ nullptr, mTargetTokensDevice, /* sequenceLengths */ nullptr, /* finishedInput */ nullptr,
        /* finishedOutput */ nullptr, /* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesDevice,
        mRuntimeMaxTopK, mRuntimeTopKDevice, 1.0f, /* runtimeTopPDevice */ nullptr, mVocabSizePadded,
        /* endIds */ nullptr, batchSlots, mStream, batchSize, mMaxBatchSize, tokensPerStepDevice, mMaxTokensPerStep,
        mMaxTokensPerStep, /* skipDecode */ nullptr, /* normalizeLogProbs */ false, /* probsComputed */ false,
        /* return all Top-K */ false);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
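// acceptDraftTokens: gathers per-(slot, head) pointers to the Medusa head logits, then calls
// acceptDraftTokensByIdsWithPaths to compare the previously emitted draft tokens with the
// target tokens sampled in samplePrimeHeadTokens, accept the longest matching path, advance
// sequence lengths and finished states, and record both the best path id and the logits
// pointers that the next draft tokens are sampled from.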
template <typename T>
void MedusaDecodingLayer<T>::acceptDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];
    auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];

    auto outputIds = outputs.output_ids.template getPtr<TokenIdType>();
    auto endIds = inputs.end_ids.template getPtr<TokenIdType const>();
    auto paths = inputs.paths.template getPtr<SizeType const>();

    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>()
                                         : static_cast<SizeType const*>(nullptr);
    auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr<SizeType>()
                                                   : static_cast<SizeType*>(nullptr);
    auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>()
                                                   : static_cast<SizeType*>(nullptr);
    auto curTokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType const>();
    auto targetTokensPerStepDevice = inputs.medusaTargetTokensPerStep.template getPtr<SizeType const>();

    auto medusaInputLogitsPtrs = BufferRange<T*>(*mMedusaInputLogitsPtrs);
    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        auto const slot = batchSlots[bi];
        for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
        {
            medusaInputLogitsPtrs[slot * mMaxNumHeads + hi] = inputs.medusaLogits[slot][hi].template getPtr<T>();
        }
    }

    auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;

    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(
        targetTokensPerStepDevice != nullptr, "Target tokens per step must be provided for MedusaDecoding");

    auto finishedStates
        = reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());

    // Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
    // Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
    // Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits.
    acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, acceptedLengths,
        finishedStates, batchSlots, paths, endIds,
        reinterpret_cast<T const**>(bufferCast<T*>(*mMedusaInputLogitsPtrs)),
        const_cast<T const**>(mMedusaSelectedLogitsPtrsDevice), curTokensPerStepDevice, targetTokensPerStepDevice,
        mBestPathIdsDevice, batchSize, mVocabSize, mMaxBatchSize, mMaxTokensPerStep, maxSeqLen, mMaxNumHeads,
        mMaxTokensPerStep, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
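// sampleNewDraftTokens: each request is sampled once per Medusa head from the logits selected
// by acceptDraftTokens. Batch slots are tiled to (slot * mMaxNumHeads + head) so that every
// head uses its own runtime Top-K value and curand state, and the per-head output pointers are
// offset by the cumulative Top-K recorded in setup() so that the tokens of consecutive heads
// are packed back-to-back in mNewDraftTokensDevice.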
template <typename T>
void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>()
                                         : static_cast<SizeType const*>(nullptr);
    auto sequenceLengths
        = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");

    // For each request we sample Head Num times for topK[hi] tokens
    auto const batchSizeHeadNums = batchSize * mMaxNumHeads;
    auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads;

    auto tiledBatchSlots = bufferCast<SizeType>(*mTiledBatchSlotsForward);
    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
        {
            tiledBatchSlots[bi * mMaxNumHeads + hi] = mMaxNumHeads * batchSlots[bi] + hi;
        }
    }

    auto draftIdsPtrs = reinterpret_cast<TokenIdType**>(bufferCast<int64_t>(*mDraftIdsPtrHost));
    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        auto slot = batchSlots[bi];
        for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
        {
            draftIdsPtrs[slot * mMaxNumHeads + hi]
                = mNewDraftTokensDevice + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi];
        }
    }

    invokeBatchTopKSampling(mSamplingWorkspaceDevice, /* logits */ static_cast<T const*>(nullptr),
        const_cast<T const* const*>(mMedusaSelectedLogitsPtrsDevice), draftIdsPtrs, /* outputIds */ nullptr,
        /* sequenceLength */ nullptr, /* finishedInput */ nullptr, /* finishedOutput */ nullptr,
        /* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesMedusaLogitsDevice,
        mRuntimeMaxTopKPerRequestPerMedusaHead, mRuntimeTopKPerRequestPerMedusaHeadDevice, 1.0f,
        /* runtimeTopPDevice */ nullptr, mVocabSizePadded, /* endIds */ nullptr, tiledBatchSlots, mStream,
        batchSizeHeadNums, maxBatchSizeHeadNums, /* tokensPerStep */ nullptr, /* maxTokensPerStep */ 1,
        /* maxSeqLen (not required as outputIds is nullptr) */ 0, /* skipDecode */ nullptr,
        /* normalizeLogProbs */ false, /* probsComputed */ false, /* return all Top-K */ true);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>()
                                         : static_cast<SizeType const*>(nullptr);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");

    auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;
    auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType const>();
    auto treeIds = inputs.treeIds.template getPtr<SizeType const>();
    TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(treeIds != nullptr, "Tree ids must be provided for MedusaDecoding");

    scatterMedusaDraftTokens(draftIds, mNewDraftTokensDevice, treeIds, tokensPerStepDevice, batchSlots,
        mMaxTokensPerStep, batchSize, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
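// packAcceptedPaths: given the accepted lengths and best path ids computed in acceptDraftTokens,
// produces the cumulative sum of accepted lengths over the batch and the offsets of the accepted
// paths (each path holds mMaxNumHeads + 1 entries: the primary token plus one per Medusa head).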
template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];

    auto paths = inputs.paths.template getPtr<SizeType const>();
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
    auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>() : nullptr;
    auto acceptedLengthsCumSum
        = outputs.acceptedLengthsCumSum ? outputs.acceptedLengthsCumSum->template getPtr<SizeType>() : nullptr;
    auto medusaPathsOffsets
        = outputs.medusaPathsOffsets ? outputs.medusaPathsOffsets->template getPtr<SizeType>() : nullptr;

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for MedusaDecoding");
    TLLM_CHECK_WITH_INFO(medusaPathsOffsets != nullptr, "medusaPathsOffsets must be provided for MedusaDecoding");

    invokePackAcceptedPaths(acceptedLengthsCumSum, medusaPathsOffsets, acceptedLengths, mBestPathIdsDevice, paths,
        batchSlots, batchSize, mMaxTokensPerStep, mMaxNumHeads + 1, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;

} // namespace layers
} // namespace tensorrt_llm