TensorRT-LLMs/cpp/tensorrt_llm/runtime/gptDecoder.cpp
2024-04-30 17:19:10 +08:00

563 lines
24 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include <memory>
#include <NvInferRuntime.h>
namespace tc = tensorrt_llm::common;
namespace tl = tensorrt_llm::layers;
namespace tcc = tensorrt_llm::common::conversion;
using namespace tensorrt_llm::runtime;
template <typename T>
GptDecoder<T>::GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
    size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
    std::optional<runtime::SizeType> maxTokensPerStep, std::optional<runtime::SizeType> maxNumMedusaHeads)
    : mManager{stream}
    , mMaxBatchSize(maxBatchSize)
{
    // All decoding strategies are driven by a single DynamicDecodeLayer. It runs on `stream` and receives a
    // CudaAllocator bound to this decoder's BufferManager.
    tl::DecoderDomain const domain(
        maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxTokensPerStep, maxNumMedusaHeads);
    mDynamicDecodeLayer = std::make_shared<tl::DynamicDecodeLayer<T>>(
        mode, domain, stream->get(), std::make_shared<common::CudaAllocator>(mManager));

    // GPU float buffer shaped [maxSequenceLength, maxBatchSize, maxBeamWidth], zero-filled up front.
    // It is handed to the layer as output_log_probs_tiled and later read by gatherTree.
    auto const tiledShape = ITensor::makeShape({static_cast<SizeType>(maxSequenceLength),
        static_cast<SizeType>(maxBatchSize), static_cast<SizeType>(maxBeamWidth)});
    mLogProbsTiled = mManager.gpu(tiledShape, TRTDataType<float>::value);
    mManager.setZero(*mLogProbsTiled);
}
template <typename T>
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
    std::optional<TensorPtr> const& batchSlots)
{
    // Keep a copy of the sampling config on the decoder, then translate it field-by-field into the layer's
    // setup params and forward them to DynamicDecodeLayer::setup.
    // NOTE(review): maxSequenceLength is not used by this implementation.
    mSamplingConfig = samplingConfig;

    auto setupParams = std::make_shared<layers::DynamicDecodeSetupParams>();
    setupParams->randomSeed = samplingConfig.randomSeed;

    // Logit penalties / constraints.
    auto& penalties = setupParams->penaltyParams;
    penalties.repetitionPenalty = samplingConfig.repetitionPenalty;
    penalties.presencePenalty = samplingConfig.presencePenalty;
    penalties.frequencyPenalty = samplingConfig.frequencyPenalty;
    penalties.temperature = samplingConfig.temperature;
    penalties.minLength = samplingConfig.minLength;

    // Top-k / top-p sampling.
    auto& sampling = setupParams->samplingParams;
    sampling.normalize_log_probs = samplingConfig.normalizeLogProbs;
    if (samplingConfig.topK)
    {
        // Element-wise copy into a std::vector<SizeType>.
        // NOTE(review): an older comment called this a signed-to-unsigned conversion; confirm the element type
        // of SamplingConfig::topK to see whether any conversion actually happens.
        auto const& topKValues = samplingConfig.topK.value();
        sampling.runtime_top_k = std::vector<SizeType>(std::begin(topKValues), std::end(topKValues));
    }
    sampling.runtime_top_p = samplingConfig.topP;
    sampling.top_p_decay = samplingConfig.topPDecay;
    sampling.top_p_min = samplingConfig.topPMin;
    sampling.top_p_reset_ids = samplingConfig.topPResetIds;

    // Beam search.
    auto& beamSearch = setupParams->beamSearchParams;
    beamSearch.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate;
    beamSearch.length_penalty = samplingConfig.lengthPenalty;
    beamSearch.early_stopping = samplingConfig.earlyStopping;

    // Medusa.
    setupParams->medusaParams.topKMedusaHeads = samplingConfig.topKMedusaHeads;

    // The layer receives nullptr when no batch-slot mapping is supplied.
    auto const batchSlotsPtr = !batchSlots ? nullptr : bufferCast<SizeType>(**batchSlots);
    mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams);
}
namespace
{

//! Inserts `tensor` into `map` under `key` (converted with tcc::toTllmTensor); does nothing when `tensor` is null.
//! NOTE(review): appears unused in this translation unit — confirm before removing.
void safeInsert(tc::TensorMap& map, std::string const& key, DecodingOutput::TensorPtr const& tensor)
{
    if (tensor)
    {
        ITensor const& t{*tensor};
        map.insert({key, tcc::toTllmTensor(t)});
    }
}

//! Converts `inputs.medusaInputs` into the layer's Medusa input params.
//! Requires `inputs.medusaInputs` and `inputs.batchSlots` to be set; both are dereferenced unconditionally.
//! Per-head logits are placed in an outer vector resized to `maxBatchSize` and indexed by batch slot.
template <typename T>
tl::DynamicDecodeInputParams::MedusaInputs prepareMedusaInputs(DecodingInput const& inputs, size_t maxBatchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const& medusaInputs = inputs.medusaInputs.value();

    tl::DynamicDecodeInputParams::MedusaInputs medusaDecodingInputs;
    medusaDecodingInputs.medusaCurTokensPerStep = tcc::toTllmTensor(*medusaInputs.medusaCurTokensPerStep);
    medusaDecodingInputs.medusaTargetTokensPerStep = tcc::toTllmTensor(*medusaInputs.medusaTargetTokensPerStep);
    medusaDecodingInputs.medusaPaths = tcc::toTllmTensor(*medusaInputs.medusaPaths);
    medusaDecodingInputs.medusaTreeIds = tcc::toTllmTensor(*medusaInputs.medusaTreeIds);

    auto const batchSlots = bufferCast<SizeType>(*inputs.batchSlots);
    if (medusaInputs.medusaLogits.size())
    {
        std::vector<std::vector<tc::Tensor>> medusaLogits;
        // NOTE(review): the loop count is medusaLogits.size(), but entries are read at batchSlots[bi].
        // Confirm against callers whether medusaLogits is sized per active request or per slot.
        auto const batchSize = medusaInputs.medusaLogits.size();
        medusaLogits.resize(maxBatchSize);
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlots[bi];
            auto const& logitsHeads = medusaInputs.medusaLogits.at(slot);
            auto const medusaHeads = logitsHeads.size();
            medusaLogits[slot].resize(medusaHeads);
            for (size_t hi = 0; hi < medusaHeads; ++hi)
            {
                // Null heads are left as default-constructed tc::Tensor.
                if (logitsHeads[hi])
                {
                    medusaLogits[slot][hi] = tcc::toTllmTensor(*logitsHeads[hi]);
                }
            }
        }
        medusaDecodingInputs.medusaLogits = medusaLogits;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return medusaDecodingInputs;
}

//! Wraps the runtime DecodingInput buffers into DynamicDecodeInputParams for element type T.
//! `input.endIds` is always required. Exactly one of `input.logitsVec` / `input.logits` is used, and its dtype must
//! match T. Every other buffer is forwarded only when it is set.
template <typename T>
std::shared_ptr<tl::DynamicDecodeInputParams> prepareInputs(DecodingInput const& input, size_t maxBatchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto constexpr ite = 0; // no pipeline parallelism
    auto forwardParams = std::make_shared<tl::DynamicDecodeInputParams>(input.step, ite, input.maxLength,
        input.maxAttentionWindow, input.sinkTokenLength, input.maxBatchSize, tcc::toTllmTensor(*input.endIds));

    if (input.logitsVec)
    {
        std::vector<tc::Tensor> logitsVec;
        for (auto const& logits : input.logitsVec.value())
        {
            TLLM_CHECK(logits->getDataType() == TRTDataType<T>::value);
            logitsVec.push_back(tcc::toTllmTensor(*logits));
        }
        forwardParams->logits_vec = logitsVec;
    }
    else
    {
        // Without logitsVec a single logits tensor is mandatory; report it instead of dereferencing null.
        TLLM_CHECK_WITH_INFO(input.logits, "Logits must be provided when logitsVec is not given");
        TLLM_CHECK(input.logits->getDataType() == TRTDataType<T>::value);
        forwardParams->logits = tcc::toTllmTensor(*input.logits);
    }

    if (input.cacheIndirection)
    {
        forwardParams->src_cache_indirection = tcc::toTllmTensor(*input.cacheIndirection);
    }
    if (input.sequenceLimitLength)
    {
        forwardParams->sequence_limit_length = tcc::toTllmTensor(*input.sequenceLimitLength);
    }
    if (input.embeddingBias)
    {
        forwardParams->embedding_bias = tcc::toTllmTensor(*input.embeddingBias);
    }
    if (input.lengths)
    {
        forwardParams->input_lengths = tcc::toTllmTensor(*input.lengths);
    }
    if (input.badWordsPtrs)
    {
        // The lengths tensor is dereferenced below, so it must accompany the pointers.
        TLLM_CHECK_WITH_INFO(input.badWordsLens, "Bad word lengths must be provided when badWordsPtrs is given");
        forwardParams->bad_words_ptr = tcc::toTllmTensor(*input.badWordsPtrs);
        forwardParams->bad_words_lengths = tcc::toTllmTensor(*input.badWordsLens);
        forwardParams->max_bad_words_len = input.maxBadWordsLen;
    }
    if (input.stopWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.stopWordsLens, "Stop word lengths must be provided when stopWordsPtrs is given");
        forwardParams->stop_words_ptr = tcc::toTllmTensor(*input.stopWordsPtrs);
        forwardParams->stop_words_lengths = tcc::toTllmTensor(*input.stopWordsLens);
        forwardParams->max_stop_words_len = input.maxStopWordsLen;
    }
    if (input.finished)
    {
        forwardParams->finished = tcc::toTllmTensor(*input.finished);
    }
    if (input.batchSlots)
    {
        forwardParams->batch_slots = tcc::toTllmTensor(*input.batchSlots);
    }
    // Medusa
    if (input.medusaInputs)
    {
        forwardParams->medusaInputs = prepareMedusaInputs<T>(input, maxBatchSize);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return forwardParams;
}

//! Wraps the Medusa output buffers into layer output params; all four tensors are dereferenced unconditionally.
template <typename T>
tl::DynamicDecodeOutputParams::MedusaOutputs prepareMedusaOutputs(DecodingOutput::MedusaOutputs& output)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    tl::DynamicDecodeOutputParams::MedusaOutputs medusaOutputs;
    medusaOutputs.nextDraftTokens = tcc::toTllmTensor(*output.medusaNextDraftTokens);
    medusaOutputs.acceptedLengths = tcc::toTllmTensor(*output.medusaAcceptedTokensLen);
    medusaOutputs.acceptedLengthsCumSum = tcc::toTllmTensor(*output.medusaAcceptedLengthsCumSum);
    medusaOutputs.pathsOffsets = tcc::toTllmTensor(*output.medusaPathsOffsets);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return medusaOutputs;
}

//! Wraps the runtime DecodingOutput buffers into DynamicDecodeOutputParams.
//! `output.ids` and `output.newTokens` are required. `logProbsTiled` is forwarded alongside `output.logProbs`.
//! A BeamHypotheses struct is always allocated; each of its pointers is filled only when the matching buffer
//! (or `inputLengths`) is set.
template <typename T>
std::shared_ptr<tl::DynamicDecodeOutputParams> prepareOutputs(
    DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto outputParams = std::make_shared<tl::DynamicDecodeOutputParams>(tcc::toTllmTensor(*output.ids));
    outputParams->newTokens = tcc::toTllmTensor(*output.newTokens);
    if (output.cumLogProbs)
    {
        outputParams->cum_log_probs = tcc::toTllmTensor(*output.cumLogProbs);
    }
    if (output.parentIds)
    {
        outputParams->parent_ids = tcc::toTllmTensor(*output.parentIds);
    }
    if (output.cacheIndirection)
    {
        outputParams->tgt_cache_indirection = tcc::toTllmTensor(*output.cacheIndirection);
    }
    if (output.finished)
    {
        outputParams->finished = tcc::toTllmTensor(*output.finished);
    }
    if (output.finishedSum)
    {
        outputParams->finished_sum = tcc::toTllmTensor(*output.finishedSum);
    }
    if (output.lengths)
    {
        outputParams->sequence_length = tcc::toTllmTensor(*output.lengths);
    }
    if (output.logProbs)
    {
        outputParams->output_log_probs = tcc::toTllmTensor(*output.logProbs);
        outputParams->output_log_probs_tiled = tcc::toTllmTensor(*logProbsTiled);
    }

    outputParams->beamHypotheses = std::make_unique<tensorrt_llm::kernels::BeamHypotheses>();
    if (output.beamHypotheses.batchDones)
    {
        outputParams->beamHypotheses->batchDones = bufferCast<bool>(*output.beamHypotheses.batchDones);
    }
    if (output.beamHypotheses.cumLogProbsCBA)
    {
        outputParams->beamHypotheses->cumLogProbsCBA = bufferCast<float>(*output.beamHypotheses.cumLogProbsCBA);
    }
    if (output.beamHypotheses.logProbsCBA)
    {
        outputParams->beamHypotheses->logProbsCBA = bufferCast<float>(*output.beamHypotheses.logProbsCBA);
    }
    if (output.beamHypotheses.minNormedScoresCBA)
    {
        outputParams->beamHypotheses->minNormedScoresCBA = bufferCast<float>(*output.beamHypotheses.minNormedScoresCBA);
    }
    if (output.beamHypotheses.normedScoresCBA)
    {
        outputParams->beamHypotheses->normedScoresCBA = bufferCast<float>(*output.beamHypotheses.normedScoresCBA);
    }
    if (output.beamHypotheses.numBeamsCBA)
    {
        outputParams->beamHypotheses->numBeamsCBA = bufferCast<int>(*output.beamHypotheses.numBeamsCBA);
    }
    if (output.beamHypotheses.outputIdsCBA)
    {
        outputParams->beamHypotheses->outputIdsCBA = bufferCast<int>(*output.beamHypotheses.outputIdsCBA);
    }
    if (output.beamHypotheses.sequenceLengthsCBA)
    {
        outputParams->beamHypotheses->sequenceLengthsCBA = bufferCast<int>(*output.beamHypotheses.sequenceLengthsCBA);
    }
    if (inputLengths)
    {
        outputParams->beamHypotheses->inputLengths = bufferCast<int32_t>(*inputLengths);
    }
    // Medusa
    if (output.medusaOutputs)
    {
        outputParams->medusaOutputs = prepareMedusaOutputs<T>(output.medusaOutputs.value());
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return outputParams;
}

} // namespace
//! Runs one synchronous decoding step.
//! Returns true only when per-request finished counters were collected and their total equals
//! `output.finished->getSize()`. Counters are collected only if both `input.sequenceLimitLength` and
//! `output.finished` are set; in that case the decode stream is synchronized before the counters are read.
//! Returns false in every other case.
template <typename T>
bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxBatchSize);
    auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
    auto const maxBatchSize = input.maxBatchSize;

    // Owns a temporary pinned counter buffer when the caller did not supply output.finishedSum.
    // It must stay alive until after the stream synchronization below.
    BufferManager::ITensorPtr finishedSumScratch;
    std::int32_t* finishedSumHost = nullptr;
    if (input.sequenceLimitLength && output.finished)
    {
        if (output.finishedSum)
        {
            // NOTE(review): written and read from the host below, so this presumably must be host-accessible
            // (e.g. pinned) memory — confirm with callers.
            finishedSumHost = bufferCast<std::int32_t>(*output.finishedSum);
        }
        else
        {
            finishedSumScratch
                = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
            outputParams->finished_sum = tcc::toTllmTensor(*finishedSumScratch);
            finishedSumHost = bufferCast<std::int32_t>(*finishedSumScratch);
        }
        for (SizeType bi = 0; bi < maxBatchSize; ++bi)
        {
            finishedSumHost[bi] = 0;
        }
    }

    mDynamicDecodeLayer->forward(outputParams, forwardParams);

    bool allFinished = false;
    if (finishedSumHost)
    {
        auto const numToFinish = output.finished->getSize();
        // The counters are produced on the decode stream; wait before reading them on the host.
        TLLM_CUDA_CHECK(::cudaStreamSynchronize(mDynamicDecodeLayer->getStream()));
        SizeType finishedCount = 0;
        for (SizeType bi = 0; bi < maxBatchSize; ++bi)
        {
            finishedCount += finishedSumHost[bi];
        }
        allFinished = numToFinish == static_cast<std::size_t>(finishedCount);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return allFinished;
}
//! Enqueues one decoding step on the layer's stream without synchronizing or inspecting finished state.
template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Inputs are converted before outputs, so input validation errors surface first.
    auto layerInputs = prepareInputs<T>(input, mMaxBatchSize);
    auto layerOutputs = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
    mDynamicDecodeLayer->forward(layerOutputs, layerInputs);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Must be similar to [cpp/tensorrt_llm/thop/gatherTreeOp.cpp] gatherTree
// Beam-search finalization: builds the final [batchSize, beamWidth, maxSeqLength] token ids in finalOutputIds.
// It combines the decoder's running state (ids, parentIds, lengths, cumLogProbs, finished), the beam-hypotheses
// buffers and mLogProbsTiled. All work is enqueued on `manager`'s stream.
// Throws (via TLLM_CHECK_WITH_INFO) when beamWidth <= 1 or when decoder/final shapes disagree.
template <typename T>
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
DecodingInput const& decodingInput, BufferManager const& manager)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& finalOutputIdsShape = finalOutputIds.getShape();
auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
// finalOutputIds dims are read as [batch, beam, seq]; decoder ids must match in batch/beam and fit in seq.
auto const batchSize = finalOutputIdsShape.d[0];
auto const beamWidth = finalOutputIdsShape.d[1];
auto const maxSeqLength = finalOutputIdsShape.d[2];
TLLM_CHECK_WITH_INFO(beamWidth > 1, "gatherTree is only needed for beam search.");
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[0] == batchSize,
common::fmtstr("Decoder batch size (" FMT_DIM ") does not match final batch size (" FMT_DIM ")",
decodingOutputIdsShape.d[0], batchSize));
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[1] == beamWidth,
common::fmtstr("Decoder beam width (" FMT_DIM ") does not match final beam width (" FMT_DIM ")",
decodingOutputIdsShape.d[1], beamWidth));
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[2] <= maxSeqLength,
common::fmtstr("Decoder seq length size (" FMT_DIM ") is too large for final seq length (" FMT_DIM ")",
decodingOutputIdsShape.d[2], maxSeqLength));
auto const& stream = manager.getStream().get();
// Initialize finalOutputIds from endIds over batchSize * beamWidth rows of maxSeqLength
// (presumably pre-filling with each request's end id — confirm in decodingKernels).
tensorrt_llm::kernels::invokeInitializeOutput(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<TokenIdType>(*decodingInput.endIds), batchSize * beamWidth, maxSeqLength, stream);
sync_check_cuda_error();
// Wire every buffer the finalize kernels need into one BeamHypotheses descriptor.
// All tensors dereferenced below must be set; none are null-checked here.
tensorrt_llm::kernels::BeamHypotheses bh;
bh.nBatchSize = batchSize;
bh.nBeamWidth = beamWidth;
bh.nMaxSeqLen = maxSeqLength;
bh.lengthPenalties = nullptr; // TODO (bhsueh): A gpu tensor used in invokeInsertUnfinishedPath
// default value (1.0f) will be used when it is nullptr
bh.inputLengths = bufferCast<SizeType>(*decodingInput.lengths);
// Destination of the final ids.
bh.outputIds = bufferCast<TokenIdType>(finalOutputIds);
bh.logProbs = bufferCast<float>(*mLogProbsTiled);
bh.sequenceLengths = bufferCast<SizeType>(*decodingOutput.lengths);
bh.cumLogProbs = bufferCast<float>(*decodingOutput.cumLogProbs);
// Candidate-beam-array (CBA) buffers from the decoder's beam hypotheses.
bh.outputIdsCBA = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsCBA);
bh.logProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.logProbsCBA);
bh.sequenceLengthsCBA = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsCBA);
bh.cumLogProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbsCBA);
bh.normedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.normedScoresCBA);
bh.numBeamsCBA = bufferCast<SizeType>(*decodingOutput.beamHypotheses.numBeamsCBA);
bh.minNormedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.minNormedScoresCBA);
bh.batchDones = bufferCast<bool>(*decodingOutput.beamHypotheses.batchDones);
bh.finished = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*decodingOutput.finished));
// Running (not yet finalized) ids and parent ids from the decoder.
bh.outputIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.ids);
bh.parentIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.parentIds);
// This is where transpose is done
tensorrt_llm::kernels::invokeInsertUnfinishedPath(bh, stream);
sync_check_cuda_error();
tensorrt_llm::kernels::invokeFinalize(bh, stream);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Explicit instantiations: the GptDecoder<T> member definitions in this file are emitted for float and half.
namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime
//! Speculative decoding: accepts draft tokens by comparing them to the target tokens.
//! Validates the tensor dimensions, then launches invokeAcceptDraftTokensByIds on `stream`.
//! Only beam width 1 is supported; every per-request tensor must be sized by maxBatchSize (dim 1 of finishedVec).
void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
    ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
    ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    using FinishedState = tensorrt_llm::kernels::FinishedState;

    // Dimensions.
    // - maxBatchSize: dim 1 of finishedVec.
    // - active batch size: number of batch slots.
    // - beam width and max sequence length: dims 1 and 2 of targetTokenIds.
    // - max draft tokens: dim 1 of draftTokenIds.
    auto const targetShape = targetTokenIds.getShape();
    auto const draftShape = draftTokenIds.getShape();
    auto const maxBatchSize = finishedVec.getShape().d[1];
    auto const batchSize = batchSlots.getShape().d[0];
    auto const beamWidth = targetShape.d[1];
    auto const maxSeqLength = targetShape.d[2];
    auto const maxDraftTokens = draftShape.d[1];

    // Leading (batch) dimension of each per-request tensor.
    auto const draftBatch = draftShape.d[0];
    auto const contextBatch = contextLengths.getShape().d[0];
    auto const numDraftBatch = numDraftTokens.getShape().d[0];
    auto const seqLenBatch = sequenceLengths.getShape().d[0];

    TLLM_CHECK_WITH_INFO(beamWidth == 1,
        common::fmtstr("Beam width (" FMT_DIM ") > 1 is not supported for the speculative decoding", beamWidth));
    TLLM_CHECK_WITH_INFO(batchSize <= maxBatchSize,
        common::fmtstr("Batch size (" FMT_DIM ") is not smaller or equal to max batch size (" FMT_DIM ")", batchSize,
            maxBatchSize));
    TLLM_CHECK_WITH_INFO(draftBatch == maxBatchSize,
        common::fmtstr("Draft tokens batch size (" FMT_DIM ") is not equal to target batch size (" FMT_DIM ")",
            draftBatch, maxBatchSize));
    TLLM_CHECK_WITH_INFO(contextBatch == maxBatchSize,
        common::fmtstr("Context length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            contextBatch, maxBatchSize));
    TLLM_CHECK_WITH_INFO(numDraftBatch == maxBatchSize,
        common::fmtstr("Num draft tokens batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            numDraftBatch, maxBatchSize));
    TLLM_CHECK_WITH_INFO(seqLenBatch == maxBatchSize,
        common::fmtstr("Sequence length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            seqLenBatch, maxBatchSize));

    // Finished flags are stored as the FinishedState underlying integer type; view them as FinishedState.
    auto const* finishedIn
        = reinterpret_cast<FinishedState const*>(bufferCast<FinishedState::UnderlyingType>(finishedVec));
    auto* finishedOut = reinterpret_cast<FinishedState*>(bufferCast<FinishedState::UnderlyingType>(finishedFinal));

    tensorrt_llm::kernels::invokeAcceptDraftTokensByIds(bufferCast<SizeType>(draftTokenIds),
        bufferCast<SizeType>(targetTokenIds), bufferCast<SizeType>(contextLengths),
        bufferCast<SizeType>(numDraftTokens), bufferCast<SizeType>(sequenceLengths), finishedIn, finishedOut,
        bufferCast<int>(finishedSum), bufferCast<SizeType>(batchSlots), batchSize, maxBatchSize, beamWidth,
        maxSeqLength, maxDraftTokens, stream->get());
    sync_check_cuda_error();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
//! Speculative decoding: accepts draft tokens by comparing draft and target logits/probabilities.
//! Dispatches on draftLogits' dtype to the float32 or float16 kernel instantiation and throws for any other dtype.
//! `targetLogits` holds device pointers stored as int64 values, one per entry, pointing at logits of the same
//! element type. Only beam width 1 is supported.
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
    ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
    SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
    curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // draftLogits dims are read as [maxBatchSize, maxTokensPerStep, vocabSize].
    auto const draftLogitsShape = draftLogits.getShape();
    auto const maxBatchSize = draftLogitsShape.d[0];
    auto const maxTokensPerStep = draftLogitsShape.d[1];
    auto const batchSlotsShape = batchSlots.getShape();
    auto const batchSize = batchSlotsShape.d[0];
    auto constexpr beamWidth = 1;
    TLLM_CHECK_WITH_INFO(
        beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for the speculative decoding", beamWidth));
    TLLM_CHECK(draftLogitsShape.d[2] == vocabSize);

    // One launch path for every supported logits type, so the float and half calls cannot drift apart.
    // `typeTag` is a value-initialized dummy whose only job is to carry the element type DataT.
    auto const launch = [&](auto typeTag)
    {
        using DataT = decltype(typeTag);
        tensorrt_llm::kernels::acceptDraftTokensByLogits(bufferCast<DataT>(draftLogits),
            const_cast<DataT**>(reinterpret_cast<DataT const* const*>(bufferCast<int64_t>(targetLogits))),
            bufferCast<DataT>(draftProbs), bufferCast<DataT>(targetProbs), bufferCast<SizeType>(numDraftTokens),
            reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
                bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
            curandState, bufferCast<SizeType>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
            vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
    };

    if (draftLogits.getDataType() == nvinfer1::DataType::kFLOAT)
    {
        launch(float{});
    }
    else if (draftLogits.getDataType() == nvinfer1::DataType::kHALF)
    {
        launch(half{});
    }
    else
    {
        TLLM_THROW("Incorrect logits dtype. Only float32 and float16 are supported");
    }
    sync_check_cuda_error();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}