/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "explicitDraftTokensLayer.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

// NOTE(review): the original include's bracketed name was lost in transit
// (source had a bare `#include`); <algorithm> is the conventional choice for
// this file — confirm against upstream.
#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

//! \brief Decoding layer for explicit-draft-tokens (ReDrafter-style) speculative decoding.
//! Allocates its device scratch buffers at construction time via allocateBuffer().
template <typename T>
ExplicitDraftTokensLayer<T>::ExplicitDraftTokensLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Allocates host (pinned) and device buffers sized for the maximum batch size
//! and the speculative-decoding module's path geometry.
template <typename T>
void ExplicitDraftTokensLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Per-request temperature staged on pinned host memory, filled in setup().
    mTemperature = mBufferManager->pinnedPool(
        ITensor::makeShape({mDecoderDomain.getBatchSize()}), TRTDataType<float>::value);

    // Calling the scan/reduce kernel with null pointers queries its workspace requirement.
    mWorkspaceSize = invokeScanReduceGenerationLengths(
        mDecoderDomain.getBatchSize(), nullptr, nullptr, 0, nullptr, nullptr, getStream());

    // Raw byte storage for one curandState_t per request.
    // NOTE(review): element type was stripped from the original; int8_t matches the
    // {batchSize, sizeof(curandState_t)} byte-sized shape — confirm against upstream.
    mCurandStatesDevice = mBufferManager->gpu(
        ITensor::makeShape({mDecoderDomain.getBatchSize(), sizeof(curandState_t)}), TRTDataType<int8_t>::value);

    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});
    mGenerationLengthInclusiveSum = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mMaxGenerationLength = mBufferManager->gpu(ITensor::makeShape({1}), TRTDataType<SizeType32>::value);
    mTemperatureDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
    mBestPathIndicesSlots = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    // Flat [batchSize * maxNumPaths * maxPathLen] buffer of last-iteration draft indices.
    mLastDraftIndicesSlots = mBufferManager->gpu(
        ITensor::makeShape({mDecoderDomain.getBatchSize()
            * mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths()
            * mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen()}),
        TRTDataType<SizeType32>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Per-batch setup: seeds curand states, fills the temperature penalty buffers,
//! latches the decoder dtype, and fills context-phase buffers for that dtype.
//! \param batchSize number of active requests in this setup call
//! \param beamWidth unused here (explicit-draft-tokens decoding is beamWidth 1)
//! \param batchSlots host tensor of slot indices for the active requests
//! \param baseSetupParams must be an ExplicitDraftTokensSetupParams
//! \param workspace shared decoding workspace (provides device batch slots and curand init)
template <typename T>
void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<ExplicitDraftTokensSetupParams>(baseSetupParams);

    workspace->initializeDeviceCurandStates(
        setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};

    // Set decoder dtype to WAR the lack of bf16 support in decoder.
    // Latched once: subsequent setup() calls keep the first dtype seen.
    if (!mDecoderDtype)
    {
        mDecoderDtype = setupParams->dtype;
    }

    fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
        batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");

    // Dispatch context buffer fill on the latched decoder dtype.
    if (mDecoderDtype == nvinfer1::DataType::kFLOAT)
    {
        fillContextBuffers<float>(batchSize, batchSlots, *setupParams, workspace);
    }
    else if (mDecoderDtype == nvinfer1::DataType::kHALF)
    {
        fillContextBuffers<half>(batchSize, batchSlots, *setupParams, workspace);
    }
    else if (mDecoderDtype == nvinfer1::DataType::kBF16)
    {
        fillContextBuffers<__nv_bfloat16>(batchSize, batchSlots, *setupParams, workspace);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief One decoding step: converts per-request masks to packed masks, scatters the
//! engine's flat outputs into per-slot buffers, then packs accepted paths for KV-cache rewind.
template <typename T>
void ExplicitDraftTokensLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<ExplicitDraftTokensOutputs>(baseOutputs);

    // DO NOT CHANGE THE ORDER.

    // Convert masks to packed masks per request.
    convertPackedMask(*outputs, *inputs, workspace);

    // Slice output ids, pos ids, next draft tokens.
    if (mDecoderDtype == nvinfer1::DataType::kFLOAT)
    {
        splitInputDataToBatchSlots<float>(*outputs, *inputs, workspace);
    }
    else if (mDecoderDtype == nvinfer1::DataType::kHALF)
    {
        splitInputDataToBatchSlots<half>(*outputs, *inputs, workspace);
    }
    else if (mDecoderDtype == nvinfer1::DataType::kBF16)
    {
        splitInputDataToBatchSlots<__nv_bfloat16>(*outputs, *inputs, workspace);
    }

    // Pack accepted paths for KV cache rewind.
    packAcceptedPaths(*outputs, *inputs, workspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \return device workspace bytes required by convertPackedMask's scan/reduce,
//! as queried in allocateBuffer().
template <typename T>
size_t ExplicitDraftTokensLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceSize;
}

//! \brief Fills context-phase buffers (random sample data, temperatures) for requests
//! entering the context phase. \tparam Dtype engine dtype (float/half/__nv_bfloat16).
template <typename T>
template <typename Dtype>
void ExplicitDraftTokensLayer<T>::fillContextBuffers(SizeType32 batchSize, BufferConstPtr batchSlots,
    ExplicitDraftTokensSetupParams const& setupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    FillContextExplicitDraftTokensParams<Dtype> params;
    params.randDataSample = bufferCast<Dtype>(*setupParams.randomDataSample);
    params.outputTemperatures = bufferCast<Dtype>(*setupParams.temperatures);
    params.inputTemperatures = bufferCastOrNull<float>(mTemperatureDevice);
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.batchSize = batchSize;

    params.checkParams();

    invokeFillContextBuffers(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Scatters the engine's flat generation outputs (ids, position ids, draft tokens,
//! probs, temperatures, lengths) into per-batch-slot output buffers, and copies
//! generation-length metadata to host. \tparam Dtype engine dtype.
template <typename T>
template <typename Dtype>
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(ExplicitDraftTokensOutputs const& outputs,
    ExplicitDraftTokensInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;
    auto const maxSeqLen = outputs.outputIds->getDimension<-1>();

    ExtractExplicitDraftTokensParams<Dtype> params;

    // Per-slot destination buffers.
    params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
    params.outputPositionIdsBase = bufferCast<SizeType32>(*outputs.positionIdsBase);
    params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
    params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
    params.unpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
    params.unpackedNextDraftIndices = bufferCast<SizeType32>(*outputs.unpackedNextDraftIndices);
    params.acceptedLengths = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    params.nextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
    params.prevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
    params.sequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
    params.randDataSample = bufferCast<Dtype>(*outputs.randomDataSample);
    params.randDataVerification = bufferCast<Dtype>(*outputs.randomDataValidation);
    params.outputDraftProbs = bufferCast<Dtype>(*outputs.nextDraftProbs);
    params.outputTemperatures = bufferCast<Dtype>(*outputs.temperatures);
    params.outputGenerationLengths = bufferCast<SizeType32>(*outputs.generationLengths);
    params.outputBestPathIndices = bufferCast<SizeType32>(*mBestPathIndicesSlots);
    params.outputLastDraftIndices = bufferCast<SizeType32>(*mLastDraftIndicesSlots);

    // Flat engine-produced source buffers.
    params.batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    params.nextDraftTokens = bufferCast<TokenIdType>(*inputs.nextDraftTokens);
    params.lastDraftTokens = bufferCast<TokenIdType>(*inputs.lastDraftTokens);
    params.inputUnpackedNextDraftIndices = bufferCast<SizeType32>(*inputs.nextDraftIndices);
    params.bestPathLengths = bufferCast<SizeType32>(*inputs.bestPathLengths);
    params.bestPathIndices = bufferCast<SizeType32>(*inputs.bestPathIndices);
    params.inputPositionIdsBase = bufferCast<SizeType32>(*inputs.positionIdsBase);
    params.packedPositionIds = bufferCast<SizeType32>(*inputs.packedPosIds);
    params.nextFlatTokens = bufferCast<TokenIdType>(*inputs.nextFlatTokens);
    params.nextDraftProbs = bufferCast<Dtype>(*inputs.nextDraftProbs);
    params.lastGenerationLengths = bufferCastOrNull<SizeType32>(inputs.lastGenerationLengths);
    params.generationLengthInclusiveSum = bufferCast<SizeType32>(*mGenerationLengthInclusiveSum);
    params.lastDraftIndices = bufferCast<SizeType32>(*inputs.lastDraftIndices);
    params.inputTemperatures = bufferCast<float>(*mTemperatureDevice);
    params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));

    params.batchSize = batchSize;
    params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
    params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
    params.maxSeqLen = maxSeqLen;
    params.vocabSize = mDecoderDomain.getVocabSizePadded();
    // lastDraftTokens' leading dim counts generation requests; the remainder are context requests.
    params.numContextRequests = batchSize - inputs.lastDraftTokens->getDimension<0>();
    params.numGenerationRequests = inputs.lastDraftTokens->getDimension<0>();

    params.checkParams();

    // Copy max generation length
    mBufferManager->copy(*inputs.maxGenLengthDevice, *outputs.maxGenLengthHost);

    invokeExtractExplicitDraftTokens(params, getStream());

    invokeCopyProbs(params, getStream());

    // Copy generation lengths
    mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Scans generation lengths (inclusive sum + max) and converts the boolean
//! attention masks into bit-packed masks per request.
template <typename T>
void ExplicitDraftTokensLayer<T>::convertPackedMask(ExplicitDraftTokensOutputs const& outputs,
    ExplicitDraftTokensInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    auto masksDevice = bufferCast<bool>(*inputs.masks);
    auto generationLengths = bufferCast<SizeType32>(*inputs.generationLengths);
    auto packedMasksDevice = bufferCast<SizeType32>(*outputs.packedMasks);

    auto const batchSize = inputs.localBatchSize;

    auto generationLengthInclusiveSumPtr = bufferCastOrNull<SizeType32>(mGenerationLengthInclusiveSum);
    auto workSpaceDevicePtr = workspace->getRawWorkspaceDevicePtr();
    auto maxGenerationLengthPtr = bufferCastOrNull<SizeType32>(mMaxGenerationLength);
    invokeScanReduceGenerationLengths(batchSize, generationLengths, workSpaceDevicePtr, mWorkspaceSize,
        generationLengthInclusiveSumPtr, maxGenerationLengthPtr, getStream());

    invokeConvertMaskToPackedMask(batchSize, generationLengthInclusiveSumPtr, maxGenerationLengthPtr, masksDevice,
        batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens(), packedMasksDevice, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Packs the accepted draft paths (cumulative new-token counts + path offsets)
//! used by the runtime to rewind the KV cache after speculation.
template <typename T>
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(ExplicitDraftTokensOutputs const& outputs,
    ExplicitDraftTokensInputs const& inputs, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;

    auto numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto numNewTokensCumSum = bufferCast<SizeType32>(*outputs.numNewTokensCumSum);
    auto pathsOffsets = bufferCast<SizeType32>(*outputs.pathsOffsets);
    auto batchSlots = workspace->getDeviceBatchSlotsPtr();
    auto bestPathIndicesSlotsPtr = bufferCastOrNull<SizeType32>(mBestPathIndicesSlots);
    auto lastDraftIndicesSlotsPtr = bufferCastOrNull<SizeType32>(mLastDraftIndicesSlots);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(
        numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for ExplicitDraftTokensLayer");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for ExplicitDraftTokensLayer");
    invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, bestPathIndicesSlotsPtr,
        lastDraftIndicesSlotsPtr, batchSlots, batchSize, batchSize,
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class ExplicitDraftTokensLayer<float>;
template class ExplicitDraftTokensLayer<half>;

} // namespace tensorrt_llm::layers