/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "explicitDraftTokensLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
{

template <typename T>
ExplicitDraftTokensLayer<T>::ExplicitDraftTokensLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mTemperature = mBufferManager->pinnedPool(
        ITensor::makeShape({mDecoderDomain.getBatchSize()}), TRTDataType<float>::value);

    // Query the workspace sizes of the scan and reduce kernels. Both kernels reuse the same workspace buffer,
    // so it is sized for the larger of the two.
    mScanWorkspaceSizeInBytes = invokeScanGenerationLengths(
        nullptr, mScanWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), getStream());
    mReduceWorkspaceSizeInBytes = invokeReduceMaxGenerationLengths(
        nullptr, mReduceWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), getStream());

    auto workspaceSizeInBytes = std::max(mScanWorkspaceSizeInBytes, mReduceWorkspaceSizeInBytes);
    mWorkspaceDevice = mBufferManager->gpu(workspaceSizeInBytes, nvinfer1::DataType::kINT8);

    mCurandStatesDevice = mBufferManager->gpu(
        ITensor::makeShape({mDecoderDomain.getBatchSize(), sizeof(curandState_t)}), TRTDataType<int8_t>::value);
    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});
    mRandomSeedsDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<uint64_t>::value);
    mGenerationLengthInclusiveSum = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mMaxGenerationLength = mBufferManager->gpu(ITensor::makeShape({1}), TRTDataType<SizeType32>::value);
    mTemperatureDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
    mBestPathIndicesSlots = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mLastDraftIndicesSlots = mBufferManager->gpu(
        ITensor::makeShape({mDecoderDomain.getBatchSize()
            * mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths()
            * mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen()}),
        TRTDataType<SizeType32>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
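// Seeds the per-request curand states (a single shared seed, one seed per request, or the default seed) and
// fills the temperature buffers that the explicit draft tokens kernels read during decoding.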
template <typename T>
void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, BufferConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<ExplicitDraftTokensSetupParams>(baseSetupParams);

    auto batchSlotsPtr = bufferCastOrNull<SizeType32>(batchSlots);
    auto randomSeedDevicePtr = bufferCast<uint64_t>(*mRandomSeedsDevice);
    auto curandStatesDevicePtr = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
    if (setupParams->randomSeed)
    {
        if (setupParams->randomSeed->size() == 1)
        {
            // One seed shared by all requests.
            invokeCurandInitialize(
                curandStatesDevicePtr, batchSlotsPtr, batchSize, setupParams->randomSeed->front(), getStream());
            sync_check_cuda_error();
        }
        else
        {
            // One seed per request.
            TLLM_CHECK_WITH_INFO(setupParams->randomSeed->size() == batchSize, "Random seed vector size mismatch.");
            mBufferManager->copy(setupParams->randomSeed.value().data(), *mRandomSeedsDevice);
            invokeCurandBatchInitialize(
                curandStatesDevicePtr, batchSlotsPtr, batchSize, randomSeedDevicePtr, getStream());
            sync_check_cuda_error();
        }
    }
    else
    {
        // Initialize curand states using the default seed 0.
        invokeCurandInitialize(
            curandStatesDevicePtr, batchSlotsPtr, batchSize, DefaultDecodingParams::getSeed(), getStream());
    }

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};

    fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
        batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");

    fillContextBuffers(batchSize, batchSlots, *setupParams);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::fillContextBuffers(
    SizeType32 batchSize, BufferConstPtr batchSlots, ExplicitDraftTokensSetupParams const& setupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    FillContextExplicitDraftTokensParams<T> params;
    params.randDataSample = bufferCast<T>(*setupParams.randomDataSample);
    params.outputTemperatures = bufferCast<T>(*setupParams.temperatures);
    params.inputTemperatures = bufferCastOrNull<float>(mTemperatureDevice);
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSlots = bufferCastOrNull<SizeType32>(batchSlots);
    params.batchSize = batchSize;

    params.checkParams();

    invokeFillContextBuffers(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
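// Runs one explicit draft tokens decoding step. The three stages are order-dependent:
// splitInputDataToBatchSlots() fills mBestPathIndicesSlots and mLastDraftIndicesSlots, which
// packAcceptedPaths() then reads to prepare the data for KV cache rewind.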
template <typename T>
void ExplicitDraftTokensLayer<T>::forwardAsync(
    std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<ExplicitDraftTokensOutputs>(baseOutputs);

    // DO NOT CHANGE THE ORDER.

    // Convert masks to packed masks per request.
    convertPackedMask(*outputs, *inputs);

    // Slice output ids, pos ids, next draft tokens.
    splitInputDataToBatchSlots(*outputs, *inputs);

    // Pack accepted paths for KV cache rewind.
    packAcceptedPaths(*outputs, *inputs);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t ExplicitDraftTokensLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceDevice->getSizeInBytes();
}

template <typename T>
void ExplicitDraftTokensLayer<T>::convertPackedMask(
    ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    auto masksDevice = bufferCast<bool>(*inputs.masks);
    auto generationLengths = bufferCast<SizeType32>(*inputs.generationLengths);
    auto packedMasksDevice = bufferCast<SizeType32>(*outputs.packedMasks);

    auto const batchSize = inputs.localBatchSize;

    auto generationLengthInclusiveSumPtr = bufferCastOrNull<SizeType32>(mGenerationLengthInclusiveSum);
    auto workSpaceDevicePtr = mWorkspaceDevice->data();
    auto maxGenerationLengthPtr = bufferCastOrNull<SizeType32>(mMaxGenerationLength);
    invokeScanReduceGenerationLengths(batchSize, generationLengths, workSpaceDevicePtr, mScanWorkspaceSizeInBytes,
        generationLengthInclusiveSumPtr, workSpaceDevicePtr, mReduceWorkspaceSizeInBytes, maxGenerationLengthPtr,
        getStream());

    invokeConvertMaskToPackedMask(batchSize, generationLengthInclusiveSumPtr, maxGenerationLengthPtr, masksDevice,
        batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens(), packedMasksDevice, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
    ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;
    auto const maxSeqLen = outputs.outputIds->getDimension<-1>();

    ExtractExplicitDraftTokensParams<T> params;

    params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
    params.outputPositionIdsBase = bufferCast<SizeType32>(*outputs.positionIdsBase);
    params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
    params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
    params.unpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
    params.unpackedNextDraftIndices = bufferCast<SizeType32>(*outputs.unpackedNextDraftIndices);
    params.acceptedLengths = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    params.nextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
    params.prevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
    params.sequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
    params.randDataSample = bufferCast<T>(*outputs.randomDataSample);
    params.randDataVerification = bufferCast<T>(*outputs.randomDataValidation);
    params.outputDraftProbs = bufferCast<T>(*outputs.nextDraftProbs);
    params.outputTemperatures = bufferCast<T>(*outputs.temperatures);
    params.outputGenerationLengths = bufferCast<SizeType32>(*outputs.generationLengths);
    params.outputBestPathIndices = bufferCast<SizeType32>(*mBestPathIndicesSlots);
    params.outputLastDraftIndices = bufferCast<SizeType32>(*mLastDraftIndicesSlots);

    params.batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    params.nextDraftTokens = bufferCast<TokenIdType>(*inputs.nextDraftTokens);
    params.lastDraftTokens = bufferCast<TokenIdType>(*inputs.lastDraftTokens);
    params.inputUnpackedNextDraftIndices = bufferCast<SizeType32>(*inputs.nextDraftIndices);
    params.bestPathLengths = bufferCast<SizeType32>(*inputs.bestPathLengths);
    params.bestPathIndices = bufferCast<SizeType32>(*inputs.bestPathIndices);
    params.inputPositionIdsBase = bufferCast<SizeType32>(*inputs.positionIdsBase);
    params.packedPositionIds = bufferCast<SizeType32>(*inputs.packedPosIds);
    params.nextFlatTokens = bufferCast<TokenIdType>(*inputs.nextFlatTokens);
    params.nextDraftProbs = bufferCast<T>(*inputs.nextDraftProbs);
    params.lastGenerationLengths = bufferCastOrNull<SizeType32>(inputs.lastGenerationLengths);
    params.generationLengthInclusiveSum = bufferCast<SizeType32>(*mGenerationLengthInclusiveSum);
    params.lastDraftIndices = bufferCast<SizeType32>(*inputs.lastDraftIndices);
    params.inputTemperatures = bufferCast<float>(*mTemperatureDevice);
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSize = batchSize;
    params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
    params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
    params.maxSeqLen = maxSeqLen;
    params.vocabSize = mDecoderDomain.getVocabSizePadded();
    params.numContextRequests = batchSize - inputs.lastDraftTokens->getDimension<0>();
    params.numGenerationRequests = inputs.lastDraftTokens->getDimension<0>();

    params.checkParams();

    // Copy max generation length.
    mBufferManager->copy(*inputs.maxGenLengthDevice, *outputs.maxGenLengthHost);

    invokeExtractExplicitDraftTokens(params, getStream());

    invokeCopyProbs(params, getStream());

    // Copy generation lengths.
    mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
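// Packs the accepted draft path data per batch slot for KV cache rewind: the accepted lengths are accumulated
// into numNewTokensCumSum and the corresponding path offsets are gathered into pathsOffsets.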
template <typename T>
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(
    ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;

    auto numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto numNewTokensCumSum = bufferCast<SizeType32>(*outputs.numNewTokensCumSum);
    auto pathsOffsets = bufferCast<SizeType32>(*outputs.pathsOffsets);
    auto batchSlots = bufferCast<SizeType32>(*inputs.batchSlots.value());

    auto bestPathIndicesSlotsPtr = bufferCastOrNull<SizeType32>(mBestPathIndicesSlots);
    auto lastDraftIndicesSlotsPtr = bufferCastOrNull<SizeType32>(mLastDraftIndicesSlots);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(
        numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for ExplicitDraftTokensLayer");
    invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, bestPathIndicesSlotsPtr,
        lastDraftIndicesSlotsPtr, batchSlots, batchSize,
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class ExplicitDraftTokensLayer<float>;
template class ExplicitDraftTokensLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
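// Minimal usage sketch, for orientation only; decoderDomain, bufferManager, batchSlots, setupParams, inputs, and
// outputs are assumed to be prepared by the surrounding decoder/runtime and are not defined in this file:
//
//   auto layer = std::make_shared<ExplicitDraftTokensLayer<float>>(decoderDomain, bufferManager);
//   layer->setup(batchSize, /* beamWidth */ 1, batchSlots, setupParams); // seed curand states, fill temperatures
//   layer->forwardAsync(outputs, inputs);                                // one speculative decoding step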