/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "externalDraftTokensLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/kernels/samplingTopPKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

#include <algorithm>

namespace tksd = tensorrt_llm::kernels::speculative_decoding;

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

//! Decoding layer that validates externally supplied draft tokens against the target model's
//! logits (speculative decoding). Beam search is explicitly unsupported.
template <typename T>
ExternalDraftTokensLayer<T>::ExternalDraftTokensLayer(executor::DecodingMode const& mode,
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager, bool isDeterministic,
    bool isAirTopP)
    : BaseLayer(decoderDomain, bufferManager)
    , mDecodingMode(mode)
    , mIsDeterministic(isDeterministic)
    , mIsAirTopP(isAirTopP)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "ExternalDraftTokensLayer does not support Beam search mode");

    // Device properties are cached so setup() can size AirTopP launches without re-querying.
    auto const deviceId = getDevice();
    TLLM_CUDA_CHECK(cudaGetDeviceProperties(&mDeviceProp, deviceId));

    allocateBuffer(decoderDomain.getBatchSize());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Allocate all device/pinned/host scratch buffers sized for the maximum batch size, and compute
//! the worst-case workspace requirements across the TopK, TopP and AirTopP sampling kernels.
template <typename T>
void ExternalDraftTokensLayer<T>::allocateBuffer(SizeType32 batchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // top k workspace size
    auto workspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());
    mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
    // top p workspace size
    workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
    mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
    // multinomial (top p == 1) workspace size
    workspaceSize = getAirTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded(), mIsDeterministic);
    mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);

    // batchsize here is maxBatchSize
    auto const batchSizeShape = ITensor::makeShape({batchSize});

    // curand states are stored as raw bytes; one curandState_t per slot.
    mCurandStatesDevice
        = mBufferManager->gpu(ITensor::makeShape({batchSize, sizeof(curandState_t)}), TRTDataType<int8_t>::value);
    mBatchIsAccepted = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
    mRuntimeMultinomialDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);

    // host buffers.
    mSkipTopKDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
    mSkipTopKDecodeHost = BufferManager::pinnedPool(batchSizeShape, TRTDataType<bool>::value);
    mSkipTopPDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
    mSkipTopPDecodeHost = BufferManager::pinnedPool(batchSizeShape, TRTDataType<bool>::value);
    // Default to skipping TopP for every slot until setup() decides otherwise.
    auto skipTopPDecodeHostRange = BufferRange<bool>(*mSkipTopPDecodeHost);
    std::fill(skipTopPDecodeHostRange.begin(), skipTopPDecodeHostRange.end(), true);

    mOutputIdsAfterSampling = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<TokenIdType>::value);
    mOutputIdsAfterSamplingPtrsHost = BufferManager::pinned(batchSizeShape, TRTDataType<TokenIdType*>::value);
    mOutputIdsAfterSamplingPtrsDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<TokenIdType*>::value);
    mTargetOutputIds = mBufferManager->gpu(batchSizeShape, TRTDataType<TokenIdType>::value);

    mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mRuntimeTopKHost = BufferManager::cpu(batchSizeShape, TRTDataType<SizeType32>::value);

    mRuntimeTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);

    mReturnAllSelectedTokensPerSlotHost = BufferManager::pinned(batchSizeShape, TRTDataType<bool>::value);
    mReturnAllSelectedTokensPerSlotDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);

    mMaskBuffer = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<bool>::value);

    mSetupWorkspaceSize = std::max({mBatchIsAccepted->getSizeInBytes(), mRuntimeMultinomialDevice->getSizeInBytes(),
        mOutputIdsAfterSampling->getSizeInBytes(), mTargetOutputIds->getSizeInBytes(), mMaskBuffer->getSizeInBytes()});

    mTargetLogits = mBufferManager->gpu(
        ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<T>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Configure per-slot runtime TopK/TopP values, seed curand states, and precompute the
//! host-side skip flags so forwardAsync() can avoid launching unneeded sampling kernels.
template <typename T>
void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_setup);

    auto setupParams = std::dynamic_pointer_cast<ExternalDraftTokensSetupParams>(baseSetupParams);

    workspace->initializeDeviceCurandStates(
        setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);

    // Multinomial sampling is implemented as TopP with p == 1 for every slot.
    auto& runtimeMultinomialDeviceTensor = const_cast<ITensor&>(*mRuntimeMultinomialDevice);
    tensorrt_llm::runtime::kernels::invokeFill(runtimeMultinomialDeviceTensor, 1.0f, mBufferManager->getStream());

    // Prepare runtime top K
    auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector<SizeType32>{DefaultDecodingParams::getTopK()});
    auto runtimeTopP = setupParams->runtimeTopP.value_or(std::vector<float>{DefaultDecodingParams::getTopP()});

    auto const paramsSize = expandMatchElements(batchSize, runtimeTopK, runtimeTopP);
    TLLM_CHECK_WITH_INFO(paramsSize != 0,
        fmtstr("ExternalDraftTokensLayer got parameter with unexpected size, want 1 or batchSize(%d), got "
               "runtimeTopK.size() = %zu, runtimeTopP.size() = %zu",
            batchSize, runtimeTopK.size(), runtimeTopP.size()));

    for (size_t i = 0; i < paramsSize; ++i)
    {
        auto& topK = runtimeTopK[i];
        auto& topP = runtimeTopP[i];
        clampTopK(topK);
        clampTopP(topP);
        regularizeTopKTopP(topK, topP);
    }

    // Update parameters on both device and host, so we can
    // - determine whether we can skip launch TopK / TopP kernel by examine mSkipTopKDecodeHost / mSkipTopPDecodeHost
    // - select best kernel by examine mRuntimeTopKHost
    // without consulting device memory, or we'll have to do an expensive synchronization.
    SizeType32* topKsPtr = nullptr;
    float* topPsPtr = nullptr;

    if (paramsSize > 1)
    {
        auto initWorkspaceSizes = getTopKInitWorkspaceSizes(batchSize);
        calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(topKsPtr, topPsPtr);
        DecodingLayerWorkspace::copyToWorkspace(
            *mBufferManager, runtimeTopK, IBuffer::wrap(topKsPtr, initWorkspaceSizes[0] / sizeof(*topKsPtr)));
        DecodingLayerWorkspace::copyToWorkspace(
            *mBufferManager, runtimeTopP, IBuffer::wrap(topPsPtr, initWorkspaceSizes[1] / sizeof(*topPsPtr)));
    }
    auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
    auto* skipTopKDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
    auto* skipTopPDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopPDecodeDevice);
    invokeSetupTopKTopPRuntimeArgs(batchSize,                                         //
        {topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKDevice)}, //
        {topPsPtr, runtimeTopP.front(), bufferCast<float>(*mRuntimeTopPDevice)},      //
        skipTopKDecodeDevicePtr, skipTopPDecodeDevicePtr, batchSlotsDevicePtr, true, getStream());

    // Mirror the same setup on host buffers (no device sync needed later).
    auto const* batchSlotsHostPtr = bufferCast<SizeType32>(*batchSlots);
    auto* skipDecodeTopKHostPtr = bufferCastOrNull<bool>(mSkipTopKDecodeHost);
    auto* skipDecodeTopPHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
    topKsPtr = paramsSize > 1 ? runtimeTopK.data() : nullptr;
    invokeSetupTopKTopPRuntimeArgs(batchSize,                                           //
        {topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKHost)}, {}, //
        skipDecodeTopKHostPtr, skipDecodeTopPHostPtr, batchSlotsHostPtr, false);

    if (mIsAirTopP)
    {
        auto smCnt = mDeviceProp.multiProcessorCount;
        if (smCnt <= 0)
        {
            // Fall back to a fresh query if the cached properties look invalid.
            auto const deviceId = getDevice();
            cudaDeviceProp prop{};
            TLLM_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceId));
            smCnt = prop.multiProcessorCount;
        }
        mAirTopPBlockNum
            = calcAirTopPBlockNum<T>(batchSize, mDecoderDomain.getVocabSizePadded(), smCnt, mIsDeterministic);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Per-step input preparation: zero the sampled-ids scratch and, on step 0, decide per slot
//! whether sampling should return all selected tokens (slots that actually have draft tokens)
//! and set up the per-slot output-id pointers accordingly.
template <typename T>
void ExternalDraftTokensLayer<T>::prepareInputs(
    std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Fill the buffer for selected ids from sampling with zero. -1 will be set as a boundary if topP kernel is required
    auto& outputIdsAfterSamplingTensor = const_cast<ITensor&>(*mOutputIdsAfterSampling);
    mBufferManager->setZero(outputIdsAfterSamplingTensor);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    if (inputs->step == 0)
    {
        // Prepare mReturnAllSelectedTokensPerSlot
        auto numDraftTokensHost = BufferRange<SizeType32>(*inputs->numDraftTokensHost);
        auto returnAllSelectedTokensPerSlot = BufferRange<bool>(*mReturnAllSelectedTokensPerSlotHost);
        std::transform(numDraftTokensHost.begin(), numDraftTokensHost.end(), returnAllSelectedTokensPerSlot.begin(),
            [](auto numDraftTokens) { return numDraftTokens > 0; });
        mBufferManager->copy(*mReturnAllSelectedTokensPerSlotHost, *mReturnAllSelectedTokensPerSlotDevice);

        // Prepare mOutputIdsAfterSamplingPtrs: slots with draft tokens write into the scratch
        // buffer; slots without draft tokens write straight into the real output ids.
        auto outputIdsAfterSamplingPtrsHost = BufferRange<TokenIdType*>(*mOutputIdsAfterSamplingPtrsHost);
        auto outputIdsPtrs = BufferRange<TokenIdType*>(*outputs->outputIdsPtrHost);
        auto const maxBatchSize = mDecoderDomain.getBatchSize();
        for (auto batchSlot = 0; batchSlot < maxBatchSize; ++batchSlot)
        {
            auto outputIdsAfterSamplingSlice = ITensor::slice(mOutputIdsAfterSampling, batchSlot);
            auto* outputIdsAfterSamplingPtr = bufferCast<TokenIdType>(*outputIdsAfterSamplingSlice);
            outputIdsAfterSamplingPtrsHost[batchSlot]
                = returnAllSelectedTokensPerSlot[batchSlot] ? outputIdsAfterSamplingPtr : outputIdsPtrs[batchSlot];
        }
        mBufferManager->copy(*mOutputIdsAfterSamplingPtrsHost, *mOutputIdsAfterSamplingPtrsDevice);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExternalDraftTokensLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_forwardAsync);

    targetSoftmax(baseInputs, workspace);

    prepareInputs(outputs, baseInputs);

    // The logits from target engine should go through samplings first.
    // gptDecoderBatched.cpp is calling dynamic decoder step by step, in this step, dynamic Decoder already forwarded
    // PenaltyLayer, BanWordsLayer. For (TopK > 0) && (TopK == 0 && TopP == 0), we invoke TopK sampling kernel. The same
    // logic is implemented in SamplingLayer.cpp
    getAllTopKs(outputs, baseInputs, workspace);

    // Only for (TopK == 0 && TopP > 0), we invoke TopP sampling
    getAllTopPs(outputs, baseInputs, workspace);

    // After all selected tokens are filled in mOutputIdsAfterSampling by topK, topP kernels, token acceptance logics
    // starts. First we mask the logits of unselected token id to -inf as HF's TopK, TopP implementation. We compute the
    // logit probs of draft and target and go through acceptance logics.
    acceptDraftTokens(outputs, baseInputs, workspace);

    // If the token of the sequence is not accepted, a multinomial sampling is required for the bonus token.
    // Multinomial sampling is achieved through TopP kernel with TopP = 1 and already weighted-sum target logits.
    // The acceptance result of each batch is used as skipDecode in topP kernel. If is accepted, no sampling is needed
    // (early exit). Forwarding for the next step is also set in this kernel.
    multinomialSampling(outputs, baseInputs, workspace);

    // For the sequence with accepted tokens, we simply forward a step.
    forwardAcceptedTokens(outputs, baseInputs, workspace);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Worst-case device workspace across setup() and the sampling kernels.
template <typename T>
size_t ExternalDraftTokensLayer<T>::getWorkspaceSize() const noexcept
{
    return std::max(mWorkspaceSize, mSetupWorkspaceSize);
}

//! Run softmax over the target-model logits (keeping a copy of the raw logits in mTargetLogits)
//! and attach curand states to the inputs for the downstream sampling kernels.
template <typename T>
void ExternalDraftTokensLayer<T>::targetSoftmax(std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    auto const batchSize = inputs->logits.value()->getDimension<0>();

    auto const* endIds = bufferCast<TokenIdType>(*inputs->endIds);

    FinishedState const* finishedInput = (inputs->finished)
        ? reinterpret_cast<FinishedState const*>(bufferCast<FinishedState::UnderlyingType>(*inputs->finished.value()))
        : nullptr;

    inputs->curandStates = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
    inputs->probsComputed = true;

    auto runtimeLogitsPtr = bufferCast<T>(*workspace->getDeviceRuntimeLogits());
    auto logitsPtrsPtr = static_cast<T**>(nullptr);
    auto biasPtr = static_cast<T*>(nullptr);
    auto const* batchSlotsPtr = workspace->getDeviceBatchSlotsPtr();
    // Preserve the pre-softmax logits; acceptDraftTokens() masks and re-normalizes them later.
    mBufferManager->copy(runtimeLogitsPtr, *mTargetLogits);

    BiasSoftmaxParams<T> biasSoftmaxParams;
    biasSoftmaxParams.logits = runtimeLogitsPtr;
    biasSoftmaxParams.logitsPtrs = logitsPtrsPtr;
    biasSoftmaxParams.probs = runtimeLogitsPtr;
    biasSoftmaxParams.bias = biasPtr;
    biasSoftmaxParams.endIds = endIds;
    biasSoftmaxParams.finished = finishedInput;
    biasSoftmaxParams.batchSlots = batchSlotsPtr;
    biasSoftmaxParams.batchSize = batchSize;
    biasSoftmaxParams.maxBatchSize = mDecoderDomain.getBatchSize();
    biasSoftmaxParams.maxBeamWidth = 1;
    biasSoftmaxParams.vocabSize = mDecoderDomain.getVocabSize();
    biasSoftmaxParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    biasSoftmaxParams.skipSoftMax = false;
    biasSoftmaxParams.batchSlotsLogits = false;
    biasSoftmaxParams.checkParams();
    invokeAddBiasSoftMax(biasSoftmaxParams, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Mask non-selected token logits to -inf, softmax draft and target logits into probabilities,
//! then run the draft-token acceptance kernel which records per-slot acceptance in mBatchIsAccepted.
template <typename T>
void ExternalDraftTokensLayer<T>::acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_acceptDraftTokens);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    auto const draftLogitsShape = (*inputs->draftLogits).getShape();
    auto const maxBatchSize = mDecoderDomain.getBatchSize();
    auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1
    auto const batchSize = static_cast<SizeType32>(inputs->logits.value()->getDimension<0>());
    auto constexpr beamWidth = 1;

    FinishedState const* finishedInput = (inputs->finished)
        ? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
        : nullptr;
    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;

    tksd::invokeMaskTargetLogits(batchSize, bufferCast<T>(*mTargetLogits), workspace->getDeviceBatchSlotsPtr(),
        beamWidth, mDecoderDomain.getVocabSizePadded(), finishedInput, maxBatchSize,
        bufferCast<TokenIdType>(*mOutputIdsAfterSampling), bufferCastOrNull<SizeType32>(mRuntimeTopKDevice),
        bufferCast<bool>(*mMaskBuffer), getStream());

    auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
    auto const* useDraftLogitsHostPtr = bufferCastOrNull<bool>(inputs->useDraftLogitsHost);
    auto const skipDraftLogits = allOfBatchSlots(batchSlotsHost, useDraftLogitsHostPtr, batchSize, false);

    // Draft probabilities only need to be computed once, on the first step.
    if (!skipDraftLogits && inputs->step == 0)
    {
        BiasSoftmaxParams<T> biasSoftmaxParams;
        biasSoftmaxParams.logits = bufferCast<T>(*inputs->draftLogits);
        biasSoftmaxParams.probs = bufferCast<T>(*inputs->draftProbs);
        biasSoftmaxParams.finished = finishedInput;
        biasSoftmaxParams.batchSlots = workspace->getDeviceBatchSlotsPtr();
        biasSoftmaxParams.batchSize = batchSize;
        biasSoftmaxParams.maxBatchSize = maxBatchSize;
        biasSoftmaxParams.maxBeamWidth = beamWidth * maxTokensPerStep;
        biasSoftmaxParams.vocabSize = mDecoderDomain.getVocabSize();
        biasSoftmaxParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
        biasSoftmaxParams.skipSoftMax = false;
        biasSoftmaxParams.batchSlotsLogits = true;
        biasSoftmaxParams.checkParams();
        invokeAddBiasSoftMax(biasSoftmaxParams, getStream());
    }

    // Softmax over the masked target logits -> target probabilities.
    {
        BiasSoftmaxParams<T> biasSoftmaxParams;
        biasSoftmaxParams.logits = bufferCast<T>(*mTargetLogits);
        biasSoftmaxParams.probs = bufferCast<T>(*inputs->targetProbs);
        biasSoftmaxParams.finished = finishedInput;
        biasSoftmaxParams.batchSlots = workspace->getDeviceBatchSlotsPtr();
        biasSoftmaxParams.batchSize = batchSize;
        biasSoftmaxParams.maxBatchSize = maxBatchSize;
        biasSoftmaxParams.maxBeamWidth = beamWidth;
        biasSoftmaxParams.vocabSize = mDecoderDomain.getVocabSize();
        biasSoftmaxParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
        biasSoftmaxParams.skipSoftMax = false;
        biasSoftmaxParams.batchSlotsLogits = false;
        biasSoftmaxParams.checkParams();
        invokeAddBiasSoftMax(biasSoftmaxParams, getStream());
    }

    sync_check_cuda_error(getStream());

    tksd::invokeAcceptDraftTokens(batchSize, bufferCast<T>(*inputs->draftProbs), bufferCast<T>(*inputs->targetProbs),
        bufferCast<SizeType32>(*inputs->numDraftTokens), bufferCast<bool>(*inputs->useDraftLogits),
        bufferCast<TokenIdType>(*inputs->draftTokenIds), finishedInput, finishedOutput, inputs->curandStates,
        workspace->getDeviceBatchSlotsPtr(), maxTokensPerStep, beamWidth, mDecoderDomain.getVocabSizePadded(),
        inputs->useRandomAcceptanceThreshold, inputs->constantThreshold, inputs->step,
        bufferCast<bool>(*mBatchIsAccepted), bufferCast<TokenIdType>(*mTargetOutputIds), getStream());

    sync_check_cuda_error(getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Sample the bonus token for rejected sequences: TopP with p == 1 over the weighted target
//! probabilities; accepted slots are skipped via mBatchIsAccepted.
template <typename T>
void ExternalDraftTokensLayer<T>::multinomialSampling(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_multinomialSampling);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    auto const batchSize = inputs->logits.value()->getDimension<0>();
    auto probs = bufferCastOrNull<T>(inputs->targetProbs);
    auto* sequenceLength = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
    auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);

    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;
    TopPSamplingKernelParams<T> params{};
    params.probs = probs;
    params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.topPs = bufferCastOrNull<float>(mRuntimeMultinomialDevice);
    params.sequenceLength = sequenceLength;
    params.endIds = endIds;
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.finishedInput = nullptr;
    params.finishedOutput = finishedOutput;
    params.skipDecode = bufferCastOrNull<bool>(mBatchIsAccepted);
    params.cumLogProbs = nullptr;
    params.outputLogProbs = nullptr;
    params.curandState = inputs->curandStates;
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();

    if (!mIsAirTopP)
    {
        invokeBatchTopPSampling<T>(params, getStream());
    }
    else
    {
        params.blockNum = mAirTopPBlockNum;
        params.isDeterministic = mIsDeterministic;
        invokeBatchAirTopPSampling<T>(params, getStream());
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! TopK sampling over the target probabilities, returning all selected tokens per slot so the
//! acceptance step can compare them against the draft tokens. Skipped entirely when every active
//! slot has TopK sampling disabled.
template <typename T>
void ExternalDraftTokensLayer<T>::getAllTopKs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_getAllTopKs);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    auto logits = bufferCastOrNull<T>(inputs->logits);

    auto const batchSize = static_cast<SizeType32>(inputs->logits.value()->getDimension<0>());

    auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
    auto const* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopKDecodeHost);
    auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
    if (skip)
    {
        return;
    }

    auto* sequenceLength = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
    auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);

    FinishedState const* finishedInput = (inputs->finished)
        ? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
        : nullptr;
    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;

    auto const* runtimeTopKHostPtr = bufferCast<SizeType32>(*mRuntimeTopKHost);

    TopKSamplingKernelParams<T> params{};
    params.logProbs = logits;
    params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(mOutputIdsAfterSamplingPtrsDevice);
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.endIds = endIds;
    params.sequenceLengths = sequenceLength;
    params.maxTopP = 1.0F;
    params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
    params.maxTopK = maxOfBatchSlots(batchSlotsHost, runtimeTopKHostPtr, batchSize);
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.finishedInput = finishedInput;
    params.finishedOutput = finishedOutput;
    params.skipDecode = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
    params.curandState = inputs->curandStates;
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.maxTokensPerStep = 1;
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    params.returnAllSelectedTokens = true;
    params.returnAllSelectedTokensPerSlot = bufferCastOrNull<bool>(mReturnAllSelectedTokensPerSlotDevice);
    params.logitsHasProbs = inputs->probsComputed;
    params.outputIdCurrentStep = bufferCastOrNull<TokenIdType>(mTargetOutputIds);
    params.skipOutputIdCurrentStep = bufferCast<bool>(*inputs->useDraftLogits);

    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! TopP sampling for slots where TopK == 0 && TopP > 0, mirroring getAllTopKs(). Skipped when
//! every active slot has TopP sampling disabled.
template <typename T>
void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_getAllTopPs);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);

    auto logits = bufferCastOrNull<T>(inputs->logits);

    auto const batchSize = static_cast<SizeType32>(inputs->logits.value()->getDimension<0>());

    auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
    auto const* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
    auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
    if (skip)
    {
        return;
    }

    auto* sequenceLength = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
    auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);

    FinishedState const* finishedInput = (inputs->finished)
        ? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
        : nullptr;
    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;

    TopPSamplingKernelParams<T> params{};
    params.probs = logits;
    params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(mOutputIdsAfterSamplingPtrsDevice);
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.endIds = endIds;
    params.sequenceLength = sequenceLength;
    params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.finishedInput = finishedInput;
    params.finishedOutput = finishedOutput;
    params.skipDecode = bufferCastOrNull<bool>(mSkipTopPDecodeDevice);
    params.curandState = inputs->curandStates;
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    params.returnAllSelectedTokens = true;
    params.returnAllSelectedTokensPerSlot = bufferCastOrNull<bool>(mReturnAllSelectedTokensPerSlotDevice);
    params.outputIdCurrentStep = bufferCastOrNull<TokenIdType>(mTargetOutputIds);
    params.skipOutputIdCurrentStep = bufferCast<bool>(*inputs->useDraftLogits);

    invokeBatchTopPSampling<T>(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! For sequences whose draft token was accepted, advance output ids / sequence lengths by one
//! step on device (no host-side sampling needed).
template <typename T>
void ExternalDraftTokensLayer<T>::forwardAcceptedTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(ExternalDraftTokensLayer_forwardAcceptedTokens);

    auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
    auto const batchSize = inputs->logits.value()->getDimension<0>();

    auto const draftLogitsShape = (*inputs->draftLogits).getShape();
    auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1

    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;

    tksd::invokeForwardAcceptedTokens(batchSize, workspace->getDeviceBatchSlotsPtr(),
        bufferCast<bool>(*mBatchIsAccepted), bufferCastOrNull<SizeType32>(outputs->sequenceLength),
        bufferCast<TokenIdType>(*inputs->draftTokenIds), bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr),
        inputs->step, maxTokensPerStep, bufferCastOrNull<TokenIdType>(inputs->endIds), finishedOutput, getStream());

    sync_check_cuda_error(getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class ExternalDraftTokensLayer<float>;
template class ExternalDraftTokensLayer<half>;

} // namespace tensorrt_llm::layers