/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm> // for std::max; the include target was truncated in the source, <algorithm> is an assumption

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
{

template <typename T>
ExplicitDraftTokensLayer<T>::ExplicitDraftTokensLayer(
    DecoderDomain const& decoderDomain, cudaStream_t stream, std::shared_ptr<IAllocator> allocator)
    : BaseLayer(decoderDomain, stream, std::move(allocator))
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
ExplicitDraftTokensLayer<T>::~ExplicitDraftTokensLayer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    freeBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mTemperature.resize(mDecoderDomain.getBatchSize());

    // Query the workspace sizes required by the scan and reduce kernels (null pointers
    // request a size query only) and share one workspace of the larger size between them.
    mScanWorkspaceSizeInBytes = invokeScanSpecDecodingGenerationLengths(
        nullptr, mScanWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);
    mReduceWorkspaceSizeInBytes = invokeReduceMaxSpecDecodingGenerationLengths(
        nullptr, mReduceWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);
    mWorkspaceSizeInBytes = std::max(mScanWorkspaceSizeInBytes, mReduceWorkspaceSizeInBytes);

    std::array<size_t, 6> deviceBufferSizes = {sizeof(curandState_t) * mDecoderDomain.getBatchSize(),
        sizeof(uint64_t) * mDecoderDomain.getBatchSize(), mWorkspaceSizeInBytes,
        sizeof(SizeType32) * mDecoderDomain.getBatchSize(), sizeof(SizeType32),
        sizeof(float) * mDecoderDomain.getBatchSize()};

    mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
    mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
    mWorkspaceDevice = mAllocator->reMalloc(mWorkspaceDevice, deviceBufferSizes[2], false);
    mGenerationLengthInclusiveSum = mAllocator->reMalloc(mGenerationLengthInclusiveSum, deviceBufferSizes[3], false);
    mMaxGenerationLength = mAllocator->reMalloc(mMaxGenerationLength, deviceBufferSizes[4], false);
    mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, deviceBufferSizes[5], false);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::freeBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mAllocator->free((void**) (&mCurandStatesDevice));
    mAllocator->free((void**) (&mRandomSeedsDevice));
    mAllocator->free((void**) (&mWorkspaceDevice));
    mAllocator->free((void**) (&mGenerationLengthInclusiveSum));
    mAllocator->free((void**) (&mMaxGenerationLength));
    mAllocator->free((void**) (&mTemperatureDevice));

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
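// Initializes per-request decoding state for newly added requests: curand states are
// seeded from a single shared seed, from one seed per batch slot, or from the default
// seed when none is given, and the temperature penalty is filled into the host and
// device buffers for the same slots.
//
// A minimal setup sketch (field types are assumed from their use below; how the
// runtime actually builds these params lives outside this file):
//
//   auto setupParams = std::make_shared<ExplicitDraftTokensSetupParams>();
//   setupParams->randomSeed = std::vector<uint64_t>{42};       // one seed shared by all requests
//   setupParams->temperature = std::vector<float>{0.7f, 1.0f}; // per-request temperatures
//   layer.setup(/* batchSize */ 2, /* beamWidth */ 1, batchSlotsDevicePtr, setupParams);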
template <typename T>
void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
    std::shared_ptr<BaseSetupParams> baseSetupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<ExplicitDraftTokensSetupParams>(baseSetupParams);

    if (setupParams->randomSeed)
    {
        if (setupParams->randomSeed->size() == 1)
        {
            // Seed all requests with the same seed.
            invokeCurandInitialize(
                mCurandStatesDevice, batchSlots, batchSize, setupParams->randomSeed->front(), mStream);
            sync_check_cuda_error();
        }
        else
        {
            // Use a different seed per request.
            TLLM_CHECK_WITH_INFO(setupParams->randomSeed->size() == batchSize, "Random seed vector size mismatch.");
            cudaAutoCpy(mRandomSeedsDevice, setupParams->randomSeed->data(), batchSize, mStream);
            invokeCurandBatchInitialize(mCurandStatesDevice, batchSlots, batchSize, mRandomSeedsDevice, mStream);
            sync_check_cuda_error();
        }
    }
    else
    {
        // Initialize curand states using the default seed 0.
        invokeCurandInitialize(mCurandStatesDevice, batchSlots, batchSize, DefaultDecodingParams::getSeed(), mStream);
    }

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream};

    fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
        batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::forwardAsync(
    std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputParams>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);

    // DO NOT CHANGE THE ORDER.

    // Convert masks to packed masks per request.
    convertPackedMask(*outputs, *inputs);

    // Slice output ids, pos ids, next draft tokens.
    splitInputDataToBatchSlots(*outputs, *inputs);

    // Pack accepted paths for KV cache rewind.
    packAcceptedPaths(*outputs, *inputs);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
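// Converts the boolean attention masks of the draft tokens into bit-packed masks so
// that downstream masked-attention kernels can read a whole mask row as a few int32
// words instead of one byte per token. Conceptually (the exact bit order is owned by
// the kernel and not asserted here), a row {1, 0, 1, 1} would pack into the low bits
// of one SizeType32 as 0b1101. The fused scan/reduce pass first produces the inclusive
// sum and the maximum of the per-request generation lengths, which the conversion
// kernel needs to locate each request's rows in the flattened mask buffer.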
template <typename T>
void ExplicitDraftTokensLayer<T>::convertPackedMask(
    DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
    auto masksDevice = inputs.masks.template getPtr<bool const>();
    auto specDecodingGenerationLengths = inputs.specDecodingGenerationLengths.template getPtr<SizeType32 const>();
    auto packedMasksDevice = outputs.explicitDraftTokensOutputs->packedMasks.template getPtr<SizeType32>();

    auto const batchSize = inputs.batch_slots->shape[0];

    invokeScanReduceSpecDecodingGenerationLengths(batchSize, specDecodingGenerationLengths, mWorkspaceDevice,
        mScanWorkspaceSizeInBytes, mGenerationLengthInclusiveSum, mWorkspaceDevice, mReduceWorkspaceSizeInBytes,
        mMaxGenerationLength, mStream);

    invokeConvertSpecDecodingMaskToPackedMask(batchSize, mGenerationLengthInclusiveSum, mMaxGenerationLength,
        masksDevice, batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens(), packedMasksDevice, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

// Slices the engine outputs (output ids, position ids, next draft tokens and probs)
// from the flattened batch into per-slot buffers indexed by batch slots.
template <typename T>
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
    DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.batch_slots->shape[0];
    auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];

    ExtractExplicitDraftTokensParams<T> params;

    // Destination buffers, indexed by batch slot.
    params.outputIds = outputs.output_ids.template getPtr<TokenIdType>();
    params.outputPositionIdsBase = outputs.explicitDraftTokensOutputs->positionIdsBase.template getPtr<SizeType32>();
    params.outputPositionIds = outputs.explicitDraftTokensOutputs->nextDraftPosIds.template getPtr<SizeType32>();
    params.outputNextDraftTokens = outputs.explicitDraftTokensOutputs->nextDraftTokens.template getPtr<TokenIdType>();
    params.unpackedNextDraftTokens
        = outputs.explicitDraftTokensOutputs->unpackedNextDraftTokens.template getPtr<TokenIdType>();
    params.unpackedNextDraftIndices
        = outputs.explicitDraftTokensOutputs->unpackedNextDraftIndices.template getPtr<SizeType32>();
    params.acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32>();
    params.nextDraftLengths = outputs.explicitDraftTokensOutputs->nextDraftLengths.template getPtr<SizeType32>();
    params.sequenceLengths = outputs.sequence_length->template getPtr<SizeType32>();
    params.randDataSample = outputs.explicitDraftTokensOutputs->randomDataSample.template getPtr<T>();
    params.randDataVerification = outputs.explicitDraftTokensOutputs->randomDataValidation.template getPtr<T>();
    params.outputDraftProbs = outputs.explicitDraftTokensOutputs->nextDraftProbs.template getPtr<T>();
    params.outputTemperatures = outputs.explicitDraftTokensOutputs->temperatures.template getPtr<T>();

    // Source buffers produced by the engine for the flattened batch.
    params.batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
    params.nextDraftTokens = inputs.nextDraftTokens.template getPtr<TokenIdType const>();
    params.lastDraftTokens = inputs.lastDraftTokens.template getPtr<TokenIdType const>();
    params.inputUnpackedNextDraftIndices = inputs.nextDraftIndices.template getPtr<SizeType32 const>();
    params.bestPathLengths = inputs.bestPathLengths.template getPtr<SizeType32 const>();
    params.bestPathIndices = inputs.bestPathIndices.template getPtr<SizeType32 const>();
    params.inputPositionIdsBase = inputs.positionIdsBase.template getPtr<SizeType32 const>();
    params.packedPositionIds = inputs.packedPosIds.template getPtr<SizeType32 const>();
    params.nextFlatTokens = inputs.nextFlatTokens.template getPtr<TokenIdType const>();
    params.nextDraftProbs = inputs.nextDraftProbs.template getPtr<T const>();
    params.generationLengthInclusiveSum = mGenerationLengthInclusiveSum;
    params.inputTemperatures = mTemperatureDevice;
    params.curandState = mCurandStatesDevice;
    params.batchSize = batchSize;
    params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
    params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
    params.maxSeqLen = maxSeqLen;
    params.vocabSize = mDecoderDomain.getVocabSizePadded();

    invokeExtractExplicitDraftTokens(params, mStream);
    invokeCopyProbs(params, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
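// Packs the indices of the accepted draft path of each request into a contiguous
// buffer, together with a cumulative sum of accepted lengths over the batch. As the
// comment in forwardAsync notes, the runtime consumes these offsets to rewind the KV
// cache, dropping entries that were written for rejected draft tokens.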
template <typename T>
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(
    DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.batch_slots->shape[0];

    auto paths = inputs.lastDraftIndices.template getPtr<SizeType32 const>();
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
    auto acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32>();
    auto acceptedLengthsCumSum
        = outputs.explicitDraftTokensOutputs->acceptedLengthsCumSum.template getPtr<SizeType32>();
    auto pathsOffsets = outputs.explicitDraftTokensOutputs->pathsOffsets.template getPtr<SizeType32>();
    auto bestPathIndices = inputs.bestPathIndices.template getPtr<SizeType32 const>();

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(
        acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for ExplicitDraftTokensLayer");
    invokePackAcceptedPaths(acceptedLengthsCumSum, pathsOffsets, acceptedLengths, bestPathIndices, paths, batchSlots,
        batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), true, mStream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class ExplicitDraftTokensLayer<float>;
template class ExplicitDraftTokensLayer<half>;

} // namespace layers
} // namespace tensorrt_llm