/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptDecoder.h"

#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"

#include <NvInferRuntime.h>

#include <memory>

namespace tc = tensorrt_llm::common;
namespace tl = tensorrt_llm::layers;
namespace tcc = tensorrt_llm::common::conversion;

using namespace tensorrt_llm::runtime;

template <typename T>
GptDecoder<T>::GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
    size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
    std::optional<SizeType32> maxTokensPerStep, std::optional<SizeType32> maxNumMedusaHeads)
    : mManager{stream}
    , mMaxBatchSize(maxBatchSize)
{
    auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(
        maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxTokensPerStep, maxNumMedusaHeads);
    auto allocator = std::make_shared<tc::CudaAllocator>(mManager);
    mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(
        mode, decodingDomain, stream->get(), std::move(allocator));

    auto constexpr nvFloatType = TRTDataType<float>::value;
    mLogProbsTiled = mManager.gpu(ITensor::makeShape({static_cast<SizeType32>(maxSequenceLength),
                                      static_cast<SizeType32>(maxBatchSize), static_cast<SizeType32>(maxBeamWidth)}),
        nvFloatType);
    mManager.setZero(*mLogProbsTiled);
}

template <typename T>
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType32 maxSequenceLength,
    std::optional<TensorPtr> const& batchSlots)
{
    mSamplingConfig = samplingConfig;

    auto setupParams = std::make_shared<tl::DynamicDecodeSetupParams>();

    setupParams->penaltyParams.repetitionPenalty = samplingConfig.repetitionPenalty;
    setupParams->penaltyParams.presencePenalty = samplingConfig.presencePenalty;
    setupParams->penaltyParams.frequencyPenalty = samplingConfig.frequencyPenalty;
    setupParams->penaltyParams.temperature = samplingConfig.temperature;
    setupParams->penaltyParams.minLength = samplingConfig.minLength;

    setupParams->randomSeed = samplingConfig.randomSeed;

    setupParams->samplingParams.normalize_log_probs = samplingConfig.normalizeLogProbs;
    // signed to unsigned
    if (samplingConfig.topK)
    {
        auto const& topK = samplingConfig.topK.value();
        setupParams->samplingParams.runtime_top_k = std::vector<uint32_t>(std::begin(topK), std::end(topK));
    }

    setupParams->samplingParams.runtime_top_p = samplingConfig.topP;
    setupParams->samplingParams.top_p_decay = samplingConfig.topPDecay;
    setupParams->samplingParams.top_p_min = samplingConfig.topPMin;
    setupParams->samplingParams.top_p_reset_ids = samplingConfig.topPResetIds;

    setupParams->beamSearchParams.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate;
    setupParams->beamSearchParams.length_penalty = samplingConfig.lengthPenalty;
    setupParams->beamSearchParams.early_stopping = samplingConfig.earlyStopping;

    setupParams->medusaParams.topKMedusaHeads = samplingConfig.topKMedusaHeads;

    auto const batchSlotsPtr = batchSlots.has_value() ? bufferCast<SizeType32>(*(batchSlots.value())) : nullptr;
    mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams);
}
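
// Minimal usage sketch (illustrative only; the decoding input/output buffers are created and owned by the caller):
//   GptDecoder<float> decoder(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSeqLen, stream);
//   decoder.setup(samplingConfig, batchSize, maxSeqLen, batchSlots);
//   bool allFinished = decoder.forward(decodingOutput, decodingInput); // or forwardAsync() to avoid the stream sync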

namespace
{
void safeInsert(tc::TensorMap& map, std::string const& key, DecodingOutput::TensorPtr const& tensor)
{
    if (tensor)
    {
        ITensor const& t{*tensor};
        map.insert({key, tcc::toTllmTensor(t)});
    }
}

template <typename T>
tl::DynamicDecodeInputParams::MedusaInputs prepareMedusaInputs(DecodingInput const& inputs, size_t maxBatchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const& medusaInputs = inputs.medusaInputs.value();

    tl::DynamicDecodeInputParams::MedusaInputs medusaDecodingInputs;
    medusaDecodingInputs.medusaCurTokensPerStep = tcc::toTllmTensor(*medusaInputs.medusaCurTokensPerStep);
    medusaDecodingInputs.medusaTargetTokensPerStep = tcc::toTllmTensor(*medusaInputs.medusaTargetTokensPerStep);
    medusaDecodingInputs.medusaPaths = tcc::toTllmTensor(*medusaInputs.medusaPaths);
    medusaDecodingInputs.medusaTreeIds = tcc::toTllmTensor(*medusaInputs.medusaTreeIds);
    auto const batchSlots = bufferCast<SizeType32>(*inputs.batchSlots);
    if (medusaInputs.medusaLogits.size())
    {
        std::vector<std::vector<tc::Tensor>> medusaLogits;
        auto const batchSize = medusaInputs.medusaLogits.size();
        medusaLogits.resize(maxBatchSize);
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            auto const slot = batchSlots[bi];
            auto const& logitsHeads = medusaInputs.medusaLogits.at(slot);
            auto const medusaHeads = logitsHeads.size();
            medusaLogits[slot].resize(medusaHeads);
            for (size_t hi = 0; hi < medusaHeads; ++hi)
            {
                if (logitsHeads[hi])
                {
                    medusaLogits[slot][hi] = tcc::toTllmTensor(*logitsHeads[hi]);
                }
            }
        }
        medusaDecodingInputs.medusaLogits = medusaLogits;
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return medusaDecodingInputs;
}
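
// Converts the runtime DecodingInput into the DynamicDecodeLayer forward parameters. Optional members
// (cache indirection, embedding bias, bad/stop word lists, Medusa inputs, ...) are only forwarded when
// they are present on the input.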
template <typename T>
std::shared_ptr<tl::DynamicDecodeInputParams> prepareInputs(DecodingInput const& input, size_t maxBatchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto constexpr ite = 0; // no pipeline parallelism
    auto forwardParams = std::make_shared<tl::DynamicDecodeInputParams>(input.step, ite, input.maxLength,
        input.maxAttentionWindow, input.sinkTokenLength, input.maxBatchSize, tcc::toTllmTensor(*input.endIds));

    if (input.logitsVec)
    {
        std::vector<tc::Tensor> logitsVec;
        for (auto const& logits : input.logitsVec.value())
        {
            TLLM_CHECK(logits->getDataType() == TRTDataType<T>::value);
            logitsVec.push_back(tcc::toTllmTensor(*logits));
        }
        forwardParams->logits_vec = logitsVec;
    }
    else
    {
        TLLM_CHECK(input.logits->getDataType() == TRTDataType<T>::value);
        forwardParams->logits = tcc::toTllmTensor(*input.logits);
    }

    if (input.cacheIndirection)
    {
        forwardParams->src_cache_indirection = tcc::toTllmTensor(*input.cacheIndirection);
    }

    if (input.sequenceLimitLength)
    {
        forwardParams->sequence_limit_length = tcc::toTllmTensor(*input.sequenceLimitLength);
    }

    if (input.embeddingBias)
    {
        forwardParams->embedding_bias = tcc::toTllmTensor(*input.embeddingBias);
    }

    if (input.lengths)
    {
        forwardParams->input_lengths = tcc::toTllmTensor(*input.lengths);
    }

    if (input.badWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.badWordsLens, "Bad word lengths must be provided when badWordsPtrs is given");
        forwardParams->bad_words_ptr = tcc::toTllmTensor(*input.badWordsPtrs);
        forwardParams->bad_words_lengths = tcc::toTllmTensor(*input.badWordsLens);
        forwardParams->max_bad_words_len = input.maxBadWordsLen;
    }

    if (input.stopWordsPtrs)
    {
        TLLM_CHECK_WITH_INFO(input.stopWordsLens, "Stop word lengths must be provided when stopWordsPtrs is given");
        forwardParams->stop_words_ptr = tcc::toTllmTensor(*input.stopWordsPtrs);
        forwardParams->stop_words_lengths = tcc::toTllmTensor(*input.stopWordsLens);
        forwardParams->max_stop_words_len = input.maxStopWordsLen;
    }

    if (input.finished)
    {
        forwardParams->finished = tcc::toTllmTensor(*input.finished);
    }

    if (input.batchSlots)
    {
        forwardParams->batch_slots = tcc::toTllmTensor(*input.batchSlots);
    }

    // Medusa
    if (input.medusaInputs)
    {
        forwardParams->medusaInputs = prepareMedusaInputs<T>(input, maxBatchSize);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

    return forwardParams;
}

template <typename T>
tl::DynamicDecodeOutputParams::MedusaOutputs prepareMedusaOutputs(DecodingOutput::MedusaOutputs& output)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    tl::DynamicDecodeOutputParams::MedusaOutputs medusaOutputs;
    medusaOutputs.nextDraftTokens = tcc::toTllmTensor(*output.medusaNextDraftTokens);
    medusaOutputs.acceptedLengths = tcc::toTllmTensor(*output.medusaAcceptedTokensLen);
    medusaOutputs.acceptedLengthsCumSum = tcc::toTllmTensor(*output.medusaAcceptedLengthsCumSum);
    medusaOutputs.pathsOffsets = tcc::toTllmTensor(*output.medusaPathsOffsets);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return medusaOutputs;
}

template <typename T>
std::shared_ptr<tl::DynamicDecodeOutputParams> prepareOutputs(
    DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto outputParams = std::make_shared<tl::DynamicDecodeOutputParams>(tcc::toTllmTensor(*output.ids));

    outputParams->newTokens = tcc::toTllmTensor(*output.newTokens);

    if (output.cumLogProbs)
    {
        outputParams->cum_log_probs = tcc::toTllmTensor(*output.cumLogProbs);
    }

    if (output.parentIds)
    {
        outputParams->parent_ids = tcc::toTllmTensor(*output.parentIds);
    }

    if (output.cacheIndirection)
    {
        outputParams->tgt_cache_indirection = tcc::toTllmTensor(*output.cacheIndirection);
    }

    if (output.finished)
    {
        outputParams->finished = tcc::toTllmTensor(*output.finished);
    }

    if (output.finishedSum)
    {
        outputParams->finished_sum = tcc::toTllmTensor(*output.finishedSum);
    }

    if (output.lengths)
    {
        outputParams->sequence_length = tcc::toTllmTensor(*output.lengths);
    }

    if (output.logProbs)
    {
        outputParams->output_log_probs = tcc::toTllmTensor(*output.logProbs);
        outputParams->output_log_probs_tiled = tcc::toTllmTensor(*logProbsTiled);
    }

    outputParams->beamHypotheses = std::make_unique<tensorrt_llm::kernels::BeamHypotheses>();
    if (output.beamHypotheses.batchDones)
    {
        outputParams->beamHypotheses->batchDones = bufferCast<bool>(*output.beamHypotheses.batchDones);
    }
    if (output.beamHypotheses.cumLogProbsCBA)
    {
        outputParams->beamHypotheses->cumLogProbsCBA = bufferCast<float>(*output.beamHypotheses.cumLogProbsCBA);
    }
    if (output.beamHypotheses.logProbsCBA)
    {
        outputParams->beamHypotheses->logProbsCBA = bufferCast<float>(*output.beamHypotheses.logProbsCBA);
    }
    if (output.beamHypotheses.minNormedScoresCBA)
    {
        outputParams->beamHypotheses->minNormedScoresCBA = bufferCast<float>(*output.beamHypotheses.minNormedScoresCBA);
    }
    if (output.beamHypotheses.normedScoresCBA)
    {
        outputParams->beamHypotheses->normedScoresCBA = bufferCast<float>(*output.beamHypotheses.normedScoresCBA);
    }
    if (output.beamHypotheses.numBeamsCBA)
    {
        outputParams->beamHypotheses->numBeamsCBA = bufferCast<SizeType32>(*output.beamHypotheses.numBeamsCBA);
    }
    if (output.beamHypotheses.outputIdsCBA)
    {
        outputParams->beamHypotheses->outputIdsCBA = bufferCast<TokenIdType>(*output.beamHypotheses.outputIdsCBA);
    }
    if (output.beamHypotheses.sequenceLengthsCBA)
    {
        outputParams->beamHypotheses->sequenceLengthsCBA
            = bufferCast<SizeType32>(*output.beamHypotheses.sequenceLengthsCBA);
    }
    if (inputLengths)
    {
        outputParams->beamHypotheses->inputLengths = bufferCast<SizeType32>(*inputLengths);
    }

    // Medusa
    if (output.medusaOutputs)
    {
        outputParams->medusaOutputs = prepareMedusaOutputs<T>(output.medusaOutputs.value());
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

    return outputParams;
}

} // namespace
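
// Runs one decoding step synchronously. When both a sequence-limit-length tensor and a finished state
// are provided, the call synchronizes on the decoder stream and returns true once every beam has
// finished; otherwise it returns false without inspecting the result.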
template <typename T>
bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxBatchSize);
    auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);

    auto const maxBatchSize = input.maxBatchSize;

    BufferManager::ITensorPtr finishedSum;
    std::int32_t* finishedSumHost = nullptr;
    if (input.sequenceLimitLength && output.finished)
    {
        if (output.finishedSum)
        {
            finishedSumHost = bufferCast<std::int32_t>(*output.finishedSum);
        }
        else
        {
            finishedSum = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
            outputParams->finished_sum = tcc::toTllmTensor(*finishedSum);
            finishedSumHost = bufferCast<std::int32_t>(*finishedSum);
        }
        for (SizeType32 bi = 0; bi < maxBatchSize; ++bi)
        {
            finishedSumHost[bi] = 0;
        }
    }

    mDynamicDecodeLayer->forward(outputParams, forwardParams);

    if (finishedSumHost)
    {
        auto const numToFinish = output.finished->getSize();
        TLLM_CUDA_CHECK(::cudaStreamSynchronize(mDynamicDecodeLayer->getStream()));

        SizeType32 finishedCount = 0;
        for (SizeType32 bi = 0; bi < maxBatchSize; ++bi)
        {
            finishedCount += finishedSumHost[bi];
        }
        return numToFinish == static_cast<std::size_t>(finishedCount);
    }
    else
    {
        return false;
    }
}

template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto forwardParams = prepareInputs<T>(input, mMaxBatchSize);
    auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);

    mDynamicDecodeLayer->forward(outputParams, forwardParams);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
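
// Beam-search finalization: the final output is first filled with end ids, the still-unfinished beams
// are then inserted alongside the stored candidate hypotheses (CBA buffers), and the best hypotheses
// are gathered into finalOutputIds.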
// Must be similar to [cpp/tensorrt_llm/thop/gatherTreeOp.cpp] gatherTree
template <typename T>
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
    DecodingInput const& decodingInput, BufferManager const& manager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const& finalOutputIdsShape = finalOutputIds.getShape();
    auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
    auto const batchSize = finalOutputIdsShape.d[0];
    auto const beamWidth = finalOutputIdsShape.d[1];
    auto const maxSeqLength = finalOutputIdsShape.d[2];

    TLLM_CHECK_WITH_INFO(beamWidth > 1, "gatherTree is only needed for beam search.");

    TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[0] == batchSize,
        common::fmtstr("Decoder batch size (" FMT_DIM ") does not match final batch size (" FMT_DIM ")",
            decodingOutputIdsShape.d[0], batchSize));
    TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[1] == beamWidth,
        common::fmtstr("Decoder beam width (" FMT_DIM ") does not match final beam width (" FMT_DIM ")",
            decodingOutputIdsShape.d[1], beamWidth));
    TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[2] <= maxSeqLength,
        common::fmtstr("Decoder seq length size (" FMT_DIM ") is too large for final seq length (" FMT_DIM ")",
            decodingOutputIdsShape.d[2], maxSeqLength));

    auto const& stream = manager.getStream().get();

    tensorrt_llm::kernels::invokeInitializeOutput(bufferCast<TokenIdType>(finalOutputIds),
        bufferCast<TokenIdType>(*decodingInput.endIds), batchSize * beamWidth, maxSeqLength, stream);
    sync_check_cuda_error();

    tensorrt_llm::kernels::BeamHypotheses bh;
    bh.nBatchSize = batchSize;
    bh.nBeamWidth = beamWidth;
    bh.nMaxSeqLen = maxSeqLength;
    bh.lengthPenalties = nullptr; // TODO (bhsueh): a GPU tensor used in invokeInsertUnfinishedPath;
                                  // the default value (1.0f) is used when it is nullptr
    bh.inputLengths = bufferCast<SizeType32>(*decodingInput.lengths);
    bh.outputIds = bufferCast<TokenIdType>(finalOutputIds);
    bh.logProbs = bufferCast<float>(*mLogProbsTiled);
    bh.sequenceLengths = bufferCast<SizeType32>(*decodingOutput.lengths);
    bh.cumLogProbs = bufferCast<float>(*decodingOutput.cumLogProbs);
    bh.outputIdsCBA = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsCBA);
    bh.logProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.logProbsCBA);
    bh.sequenceLengthsCBA = bufferCast<SizeType32>(*decodingOutput.beamHypotheses.sequenceLengthsCBA);
    bh.cumLogProbsCBA = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbsCBA);
    bh.normedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.normedScoresCBA);
    bh.numBeamsCBA = bufferCast<SizeType32>(*decodingOutput.beamHypotheses.numBeamsCBA);
    bh.minNormedScoresCBA = bufferCast<float>(*decodingOutput.beamHypotheses.minNormedScoresCBA);
    bh.batchDones = bufferCast<bool>(*decodingOutput.beamHypotheses.batchDones);
    bh.finished = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
        bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*decodingOutput.finished));
    bh.outputIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.ids);
    bh.parentIdsUnfinish = bufferCast<TokenIdType>(*decodingOutput.parentIds);

    // This is where the transpose is done
    tensorrt_llm::kernels::invokeInsertUnfinishedPath(bh, stream);
    sync_check_cuda_error();

    tensorrt_llm::kernels::invokeFinalize(bh, stream);
    sync_check_cuda_error();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime
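
// Speculative-decoding helpers. acceptDraftTokensByIds accepts the longest draft-token prefix that
// matches the target model's tokens, updating sequence lengths and finished states accordingly;
// acceptDraftTokensByLogits accepts or rejects draft tokens by comparing the draft and target
// distributions, optionally against a randomized acceptance threshold.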
void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
    ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
    ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const finishedVecShape = finishedVec.getShape();
    auto const maxBatchSize = finishedVecShape.d[1];
    auto const batchSlotsShape = batchSlots.getShape();
    auto const batchSize = batchSlotsShape.d[0];
    auto const targetTokenIdsShape = targetTokenIds.getShape();
    auto const beamWidth = targetTokenIdsShape.d[1];
    auto const maxSeqLength = targetTokenIdsShape.d[2];
    auto const maxDraftTokens = draftTokenIds.getShape().d[1];

    TLLM_CHECK_WITH_INFO(beamWidth == 1,
        common::fmtstr("Beam width (" FMT_DIM ") > 1 is not supported for speculative decoding", beamWidth));

    TLLM_CHECK_WITH_INFO(batchSize <= maxBatchSize,
        common::fmtstr("Batch size (" FMT_DIM ") is not smaller than or equal to max batch size (" FMT_DIM ")",
            batchSize, maxBatchSize));

    TLLM_CHECK_WITH_INFO(draftTokenIds.getShape().d[0] == maxBatchSize,
        common::fmtstr("Draft tokens batch size (" FMT_DIM ") is not equal to target batch size (" FMT_DIM ")",
            draftTokenIds.getShape().d[0], maxBatchSize));

    TLLM_CHECK_WITH_INFO(contextLengths.getShape().d[0] == maxBatchSize,
        common::fmtstr("Context length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            contextLengths.getShape().d[0], maxBatchSize));

    TLLM_CHECK_WITH_INFO(numDraftTokens.getShape().d[0] == maxBatchSize,
        common::fmtstr("Num draft tokens batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            numDraftTokens.getShape().d[0], maxBatchSize));

    TLLM_CHECK_WITH_INFO(sequenceLengths.getShape().d[0] == maxBatchSize,
        common::fmtstr("Sequence length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
            sequenceLengths.getShape().d[0], maxBatchSize));

    tensorrt_llm::kernels::invokeAcceptDraftTokensByIds(bufferCast<TokenIdType>(draftTokenIds),
        bufferCast<TokenIdType>(targetTokenIds), bufferCast<SizeType32>(contextLengths),
        bufferCast<SizeType32>(numDraftTokens), bufferCast<SizeType32>(sequenceLengths),
        reinterpret_cast<tensorrt_llm::kernels::FinishedState const*>(
            bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedVec)),
        reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
            bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedFinal)),
        bufferCast<SizeType32>(finishedSum), bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth,
        maxSeqLength, maxDraftTokens, stream->get());

    sync_check_cuda_error();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
    ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
    SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
    curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const draftLogitsShape = draftLogits.getShape();
    auto const maxBatchSize = draftLogitsShape.d[0];
    auto const maxTokensPerStep = draftLogitsShape.d[1];
    auto const batchSlotsShape = batchSlots.getShape();
    auto const batchSize = batchSlotsShape.d[0];
    auto constexpr beamWidth = 1;

    TLLM_CHECK_WITH_INFO(
        beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for speculative decoding", beamWidth));

    TLLM_CHECK(draftLogitsShape.d[2] == vocabSize);

    if (draftLogits.getDataType() == nvinfer1::DataType::kFLOAT)
    {
        tensorrt_llm::kernels::acceptDraftTokensByLogits(bufferCast<float>(draftLogits),
            const_cast<float**>(reinterpret_cast<float const* const*>(bufferCast<int64_t>(targetLogits))),
            bufferCast<float>(draftProbs), bufferCast<float>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
            reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
                bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
            curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
            vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
    }
    else if (draftLogits.getDataType() == nvinfer1::DataType::kHALF)
    {
        tensorrt_llm::kernels::acceptDraftTokensByLogits(bufferCast<half>(draftLogits),
            const_cast<half**>(reinterpret_cast<half const* const*>(bufferCast<int64_t>(targetLogits))),
            bufferCast<half>(draftProbs), bufferCast<half>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
            reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
                bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
            curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
            vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
    }
    else
    {
        TLLM_THROW("Incorrect logits dtype. Only float32 and float16 are supported");
    }

    sync_check_cuda_error();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}