/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "eagleDecodingLayer.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

// NOTE(review): the original #include target was stripped by the mangling;
// <memory> is required for std::shared_ptr / std::dynamic_pointer_cast used below.
#include <memory>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

/// @brief Decoding layer for EAGLE speculative decoding: accepts/unpacks draft
/// tokens produced by the EagleNet drafter and prepares buffers for the next
/// drafting iteration.
/// @tparam T logits/compute type; instantiated for float and half at the bottom
///           of this file.
template <typename T>
EagleDecodingLayer<T>::EagleDecodingLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Allocates all per-batch-slot device (and pinned host) buffers sized
/// for the maximum batch size of the decoder domain. Called once from the ctor.
template <typename T>
void EagleDecodingLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});

    // Temperature is kept both in pinned host memory (filled during setup) and
    // on device (consumed by the kernels).
    mTemperature = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mTemperatureDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);

    // One curandState_t per batch slot, stored as raw bytes.
    // NOTE(review): element type was stripped in the mangled source; int8_t
    // matches the byte-sized shape {batchSize, sizeof(curandState_t)} — confirm
    // against the member declaration in eagleDecodingLayer.h.
    mCurandStatesDevice = mBufferManager->gpu(
        ITensor::makeShape({mDecoderDomain.getBatchSize(), sizeof(curandState_t)}), TRTDataType<int8_t>::value);

    // Per-slot EagleNet metadata for the context phase and generation phase
    // (request type, context length, past-KV length).
    mEagleNetCtxRequestTypes = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mEagleNetCtxContextLengths = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mEagleNetCtxPastKeyValueLengths = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mEagleNetGenRequestTypes = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mEagleNetGenContextLengths = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mEagleNetGenPastKeyValueLengths = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Per-request setup: seeds curand states, fills the temperature
/// penalty buffers for the new requests and initializes their context-phase
/// Eagle buffers.
/// @param batchSize number of new requests being set up
/// @param beamWidth unused here; Eagle decoding is beam-width-1
/// @param batchSlots slots of the new requests within the decoder batch
/// @param baseSetupParams must be an EagleSetupParams
/// @param workspace shared decoding workspace (device batch slots, curand init)
template <typename T>
void EagleDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<EagleSetupParams>(baseSetupParams);

    workspace->initializeDeviceCurandStates(
        setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};

    fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
        batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");

    fillContextBuffers(batchSize, batchSlots, *setupParams, workspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Launches the kernel that initializes per-request context-phase data
/// (random sample value and temperature) for freshly set-up requests.
template <typename T>
void EagleDecodingLayer<T>::fillContextBuffers(SizeType32 batchSize, BufferConstPtr batchSlots,
    EagleSetupParams const& setupParams, std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    FillContextEagleParams params;
    params.outputRandDataSample = bufferCast<float>(*setupParams.randomDataSample);
    params.outputTemperatures = bufferCast<float>(*setupParams.temperatures);

    params.inputTemperatures = bufferCastOrNull<float>(mTemperatureDevice);
    params.inputCurandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.batchSize = batchSize;

    params.checkParams();

    invokeFillContextEagleData(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Main per-step entry point. Unpacks accepted/draft tokens, builds the
/// packed attention masks for the next drafting step, and packs accepted paths
/// for KV-cache rewind.
template <typename T>
void EagleDecodingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<EagleInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<EagleOutputs>(baseOutputs);

    // Slice output ids, pos ids, next draft tokens.
    unpackData(*outputs, *inputs, workspace);

    // Convert masks to packed masks per request.
    convertToPackedMask(*outputs, *inputs, workspace);

    // Pack accepted paths for KV cache rewind.
    packAcceptedPaths(*outputs, *inputs, workspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Launches invokeUnpackEagleData to scatter accepted tokens into the
/// output ids, advance sequence lengths, and lay out the next iteration's
/// draft tokens/paths/lengths; then copies the EagleNet request metadata to
/// the host-visible output buffers.
template <typename T>
void EagleDecodingLayer<T>::unpackData(EagleOutputs const& outputs, EagleInputs const& inputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;
    auto const maxSeqLen = outputs.outputIds->getDimension<-1>();

    UnpackEagleDataParams params;
    params.batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    params.inputCurandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
    params.inputTemperatures = bufferCast<float>(*mTemperatureDevice);
    params.inputNextDraftTokens = bufferCast<TokenIdType>(*inputs.nextDraftTokens);
    params.inputNextDraftLens = bufferCast<SizeType32>(*inputs.nextDraftLens);
    params.inputNextDraftPaths = bufferCast<SizeType32>(*inputs.nextDraftPaths);
    params.inputLastDraftTokens = bufferCast<TokenIdType>(*inputs.lastDraftTokens);
    params.inputLastDraftLens = bufferCast<SizeType32>(*inputs.lastDraftLens);
    params.inputAcceptedTokens = bufferCast<TokenIdType>(*inputs.acceptedTokens);
    params.inputAcceptedLens = bufferCast<SizeType32>(*inputs.acceptedLens);

    params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
    params.outputNumNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    params.outputSequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());

    // FIXME outputUnpackedNextDraftTokens is the same as outputNextDraftTokens.
    // outputUnpackedNextDraftTokens is used in eagleBuffers and outputNextDraftTokens is used in the runtime
    params.outputUnpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
    params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
    params.outputNextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
    params.outputNextGenerationLength = bufferCast<SizeType32>(*outputs.generationLengths);
    params.outputNextDraftPaths = bufferCast<SizeType32>(*outputs.nextDraftPaths);
    params.outputPrevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
    params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
    params.outputRandDataSample = bufferCast<float>(*outputs.randomDataSample);
    params.outputRandDataVerification = bufferCast<float>(*outputs.randomDataValidation);
    params.outputTemperatures = bufferCast<float>(*outputs.temperatures);

    params.outputEagleNetCtxRequestTypes = bufferCast<SizeType32>(*mEagleNetCtxRequestTypes);
    params.outputEagleNetCtxContextLengths = bufferCast<SizeType32>(*mEagleNetCtxContextLengths);
    params.outputEagleNetCtxPastKeyValueLengths = bufferCast<SizeType32>(*mEagleNetCtxPastKeyValueLengths);
    params.outputEagleNetGenRequestTypes = bufferCast<SizeType32>(*mEagleNetGenRequestTypes);
    params.outputEagleNetGenContextLengths = bufferCast<SizeType32>(*mEagleNetGenContextLengths);
    params.outputEagleNetGenPastKeyValueLengths = bufferCast<SizeType32>(*mEagleNetGenPastKeyValueLengths);

    params.batchSize = batchSize;
    params.maxDecodingTokens = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens();
    params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
    params.maxSeqLen = maxSeqLen;

    params.checkParams();

    invokeUnpackEagleData(params, getStream());

    // Mirror the device-side EagleNet metadata into the host-side output
    // tensors consumed by the runtime between decoding steps.
    mBufferManager->copy(*mEagleNetCtxRequestTypes, *outputs.eagleNetCtxRequestTypesHost);
    mBufferManager->copy(*mEagleNetCtxContextLengths, *outputs.eagleNetCtxContextLengthsHost);
    mBufferManager->copy(*mEagleNetCtxPastKeyValueLengths, *outputs.eagleNetCtxPastKeyValueLengthsHost);
    mBufferManager->copy(*mEagleNetGenRequestTypes, *outputs.eagleNetGenRequestTypesHost);
    mBufferManager->copy(*mEagleNetGenContextLengths, *outputs.eagleNetGenContextLengthsHost);
    mBufferManager->copy(*mEagleNetGenPastKeyValueLengths, *outputs.eagleNetGenPastKeyValueLengthsHost);
    mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Converts the next-iteration draft paths into packed (bitmask)
/// attention masks, one per request.
template <typename T>
void EagleDecodingLayer<T>::convertToPackedMask(EagleOutputs const& outputs, EagleInputs const& inputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
    // NOTE(review): element type of packedMasks was stripped in the mangled
    // source; int32_t matches the packed-bitmask kernels — confirm against
    // invokeGetPackedMaskFromPath's declaration.
    auto packedMasksDevice = bufferCast<int32_t>(*outputs.packedMasks);
    auto nextDraftPaths = bufferCast<SizeType32>(*outputs.nextDraftPaths);

    auto const batchSize = inputs.localBatchSize;
    auto const maxDecodingTokens = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens();
    auto const maxPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();

    invokeGetPackedMaskFromPath(
        packedMasksDevice, batchSlots, nextDraftPaths, batchSize, maxDecodingTokens, maxPathLen, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief Packs the accepted path offsets and the exclusive cumulative sum of
/// accepted token counts, which the runtime uses to rewind the KV cache.
template <typename T>
void EagleDecodingLayer<T>::packAcceptedPaths(EagleOutputs const& outputs, EagleInputs const& inputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.localBatchSize;

    auto numNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
    auto numNewTokensCumSum = bufferCast<SizeType32>(*outputs.numNewTokensCumSum);
    auto pathsOffsets = bufferCast<SizeType32>(*outputs.pathsOffsets);
    auto batchSlots = workspace->getDeviceBatchSlotsPtr();
    auto bestPathIndicesSlotsPtr = bufferCast<SizeType32>(*inputs.acceptedPathIds);
    auto lastDraftPathsSlotsPtr = bufferCast<SizeType32>(*inputs.lastDraftPaths);

    TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for EagleDecodingLayer");
    TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for EagleDecodingLayer");
    TLLM_CHECK_WITH_INFO(
        numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for EagleDecodingLayer");
    TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for EagleDecodingLayer");

    invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, bestPathIndicesSlotsPtr,
        lastDraftPathsSlotsPtr, batchSlots, batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
        mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), true, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

/// @brief This layer keeps no extra scratch beyond the shared workspace.
/// @return the workspace size recorded by the base layer.
template <typename T>
size_t EagleDecodingLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceSize;
}

// Explicit instantiations for the supported compute types.
// NOTE(review): the instantiation arguments were stripped in the mangled
// source; float and half follow the convention of the sibling decoding layers
// — confirm against eagleDecodingLayer.h.
template class EagleDecodingLayer<float>;
template class EagleDecodingLayer<half>;

} // namespace tensorrt_llm::layers