/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/thop/dynamicDecodeOp.h"

#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <memory>

namespace th = torch;
namespace tle = tensorrt_llm::executor;
namespace tr = tensorrt_llm::runtime;
namespace tl = tensorrt_llm::layers;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

template <typename T>
FtDynamicDecode<T>::FtDynamicDecode(size_t const maxBatchSize, size_t const maxBeamWidth, size_t const vocabSize,
    size_t const vocabSizePadded, int const tensorParaSize, int const pipelineParaSize)
{
    TLLM_CHECK_WITH_INFO(vocabSizePadded % tensorParaSize == 0,
        tensorrt_llm::common::fmtstr(
            "vocabSizePadded (%ld) is not a multiple of tensorParaSize (%d).", vocabSizePadded, tensorParaSize));

    auto const decodingDomain = tl::DecoderDomain(maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded);

    // Wrap the current PyTorch CUDA stream so the decode layer runs on the caller's stream.
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    auto const currentDeviceId = c10::cuda::current_device();
    auto cudaStreamPtr = std::make_shared<tr::CudaStream>(stream, currentDeviceId, false);
    auto bufferManager = std::make_shared<tr::BufferManager>(cudaStreamPtr);

    // Pinned host buffer accumulating per-slot finished counts; read back after a stream sync in forward().
    mFinishedSum = bufferManager->pinnedPool(
        tr::ITensor::makeShape({static_cast<tr::SizeType32>(maxBatchSize)}), nvinfer1::DataType::kINT32);

    mDynamicDecodeLayer
        = std::make_shared<tl::DynamicDecodeLayer<T>>(tle::DecodingMode::Auto(), decodingDomain, bufferManager);

    mBatchSlots = tr::getDefaultBatchSlots(maxBatchSize);
    mDecodingWorkspace = std::make_unique<tensorrt_llm::runtime::DecodingLayerWorkspace>(bufferManager, decodingDomain,
        tensorrt_llm::runtime::TRTDataType<T>::value, mDynamicDecodeLayer->getWorkspaceSize());
}

namespace
{

// Copy an optional (CPU) tensor's contents into an optional std::vector setup parameter.
// Leaves `arg` untouched when the tensor is absent.
template <typename T>
void safeInsert(th::optional<th::Tensor>& tensor, std::optional<std::vector<T>>& arg)
{
    if (tensor.has_value())
    {
        auto shape = convert_shape(tensor.value());
        size_t const size = tensorrt_llm::runtime::ITensor::volume(shape);
        auto ptr = get_ptr<T>(tensor.value());
        arg = std::vector<T>(ptr, ptr + size);
    }
}

// Wrap an optional torch tensor as an ITensor view; no-op when the tensor is absent.
template <typename T>
void safeUpdate(th::optional<th::Tensor>& tensor, std::optional<tr::ITensor::SharedPtr>& arg)
{
    if (tensor.has_value())
    {
        arg = convert_tensor<T>(tensor.value());
    }
}

template <typename T>
void safeUpdate(th::optional<th::Tensor>& tensor, std::optional<tr::ITensor::SharedConstPtr>& arg)
{
    if (tensor.has_value())
    {
        arg = convert_tensor<T>(tensor.value());
    }
}

template <typename T>
void safeUpdateScalar(th::optional<th::Tensor>& tensor, std::optional<T>& arg, std::string const& name)
{
    if (tensor.has_value())
    {
        auto accessor = tensor->accessor<T, 1>();
        TLLM_CHECK_WITH_INFO(accessor.size(0) == 1, name + " must be a scalar");
        arg = accessor[0];
    }
}

template <typename T>
void safeUpdatePtr(th::optional<th::Tensor>& tensor, T*& ptr)
{
    if (tensor.has_value())
    {
        ptr = get_ptr<T>(tensor.value());
    }
}

} // namespace
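// Illustrative use of the helpers above (a sketch only; `temperatureOpt` and `temperature`
// are hypothetical caller-side variables, not part of this file):
//
//   th::optional<th::Tensor> temperatureOpt = torch::full({1}, 0.7f); // CPU float tensor
//   std::optional<std::vector<float>> temperature;
//   safeInsert(temperatureOpt, temperature); // temperature == std::vector<float>{0.7f}
//   temperatureOpt.reset();
//   safeInsert(temperatureOpt, temperature); // no-op: absent tensors leave the argument unchanged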
template <typename T>
void FtDynamicDecode<T>::setup(size_t const batch_size, size_t const beam_width,
    th::optional<th::Tensor> runtime_top_k_opt, th::optional<th::Tensor> runtime_top_p_opt,
    th::optional<th::Tensor> temperature_opt, th::optional<th::Tensor> repetition_penalty_opt,
    th::optional<th::Tensor> presence_penalty_opt, th::optional<th::Tensor> frequency_penalty_opt,
    th::optional<th::Tensor> prompt_ignore_length_opt, th::optional<th::Tensor> min_length_opt,
    th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
    th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
    th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
    th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
    th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    mBeamWidth = beam_width;

    auto setupParams = std::make_shared<tl::DynamicDecodeSetupParams>();
    auto penaltyParams = std::make_shared<tl::PenaltySetupParams>();
    auto banWordsParams = std::make_shared<tl::BanWordsSetupParams>();
    safeInsert(temperature_opt, penaltyParams->temperature);
    safeInsert(repetition_penalty_opt, penaltyParams->repetitionPenalty);
    safeInsert(presence_penalty_opt, penaltyParams->presencePenalty);
    safeInsert(frequency_penalty_opt, penaltyParams->frequencyPenalty);
    safeInsert(prompt_ignore_length_opt, penaltyParams->promptIgnoreLength);
    safeInsert(min_length_opt, penaltyParams->minLength);
    safeInsert(no_repeat_ngram_size_opt, banWordsParams->noRepeatNgramSize);

    if (beam_width == 1)
    {
        // Sampling (greedy / top-k / top-p) parameters.
        auto decodingParams = std::make_shared<tl::SamplingSetupParams>();
        safeInsert(runtime_top_k_opt, decodingParams->runtimeTopK);
        safeInsert(runtime_top_p_opt, decodingParams->runtimeTopP);
        safeInsert(top_p_decay_opt, decodingParams->topPDecay);
        safeInsert(top_p_min_opt, decodingParams->topPMin);
        safeInsert(top_p_reset_ids_opt, decodingParams->topPResetIds);
        safeInsert(min_p_opt, decodingParams->runtimeMinP);
        decodingParams->outputLogProbs = std::vector<bool>({output_log_probs});
        decodingParams->cumLogProbs = std::vector<bool>({cum_log_probs});
        safeInsert(random_seed_opt, decodingParams->randomSeed);
        setupParams->decodingParams = decodingParams;
    }
    else
    {
        // Beam-search parameters.
        auto decodingParams = std::make_shared<tl::BeamSearchSetupParams>();
        safeInsert(beam_search_diversity_rate_opt, decodingParams->beamSearchDiversityRate);
        safeInsert(length_penalty_opt, decodingParams->lengthPenalty);
        safeInsert(early_stopping_opt, decodingParams->earlyStopping);
        decodingParams->outputLogProbs = std::vector<bool>({output_log_probs});
        decodingParams->cumLogProbs = std::vector<bool>({cum_log_probs});
        safeInsert(random_seed_opt, decodingParams->randomSeed);
        setupParams->decodingParams = decodingParams;
    }

    // TODO: insert "normalizeLogProbs" and "topKMedusaHeads"
    setupParams->penaltyParams = penaltyParams;
    setupParams->banWordsParams = banWordsParams;

    mDynamicDecodeLayer->setup(
        batch_size, beam_width, tr::ITensor::slice(mBatchSlots, 0, batch_size), setupParams, mDecodingWorkspace);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor const& logits, int const step, int const maxInputLength,
    int const maxAttentionWindow, int const sinkTokenLength, uint64_t const ite, int const localBatchSize,
    th::Tensor endId, th::optional<th::Tensor> embeddingBiasOpt, th::optional<th::Tensor> inputLengthsOpt,
    th::optional<th::Tensor> sequenceLimitLengthOpt, th::optional<th::Tensor> stopWordsListPtrsOpt,
    th::optional<th::Tensor> stopWordsLensOpt, int32_t const maxStopWordsLen,
    th::optional<th::Tensor> badWordsListPtrsOpt, th::optional<th::Tensor> badWordsLensOpt,
    int32_t const maxBadWordsLen, th::optional<th::Tensor> srcCacheIndirectionOpt, th::Tensor& outputTokenIds,
    th::Tensor& newTokens, th::Tensor& shouldStop, th::optional<th::Tensor> finishedInput,
    th::optional<th::Tensor> finishedOutput, th::optional<th::Tensor> sequenceLengthsOpt,
    th::optional<th::Tensor> cumLogProbsOpt, th::optional<th::Tensor> outputLogProbsOpt,
    th::optional<th::Tensor> outputLogProbsTiledOpt, th::optional<th::Tensor> parentIdsOpt,
    th::optional<th::Tensor> tgtCacheIndirectionOpt, th::optional<th::Tensor> beamHypsOutputIdsCbaOpt,
    th::optional<th::Tensor> beamHypsSeqLenCbaOpt, th::optional<th::Tensor> beamHypsCumLogProbsCbaOpt,
    th::optional<th::Tensor> beamHypsNormedScoresCbaOpt, th::optional<th::Tensor> beamHypsLogProbsCbaOpt,
    th::optional<th::Tensor> beamHypsMinNormedScoresOpt, th::optional<th::Tensor> beamHypsNumBeamsOpt,
    th::optional<th::Tensor> beamHypsIsDoneOpt, bool const useBeamHyps)
{
    TLLM_CHECK_WITH_INFO(mBeamWidth.has_value(), "Beam width is not set. setup() must be called before forward()");
    auto const isBeamSearch = mBeamWidth.value() > 1;

    std::shared_ptr<tl::DecodingInputs> forwardParams;
    tr::ITensor::SharedConstPtr batchSlotsSlice = tr::ITensor::slice(mBatchSlots, 0, localBatchSize);
    if (isBeamSearch)
    {
        forwardParams = std::make_shared<tl::DecodingInputs>(convert_tensor<int32_t>(endId), batchSlotsSlice, step,
            static_cast<tr::SizeType32>(ite), localBatchSize, maxAttentionWindow, sinkTokenLength);
    }
    else
    {
        forwardParams = std::make_shared<tl::SamplingInputs>(
            convert_tensor<int32_t>(endId), batchSlotsSlice, step, static_cast<tr::SizeType32>(ite), localBatchSize);
    }

    forwardParams->logits = convert_tensor<T>(logits);

    forwardParams->stopCriteriaInputs = std::make_shared<tl::StopCriteriaDecodingInputs>(localBatchSize);
    forwardParams->banWordsInputs = std::make_shared<tl::BanWordsDecodingInputs>(localBatchSize);

    safeUpdate<T>(embeddingBiasOpt, forwardParams->embeddingBias);
    safeUpdate<int32_t>(inputLengthsOpt, forwardParams->inputLengths);
    safeUpdate<int32_t>(sequenceLimitLengthOpt, forwardParams->stopCriteriaInputs->sequenceLimitLength);
    safeUpdate<int64_t>(stopWordsListPtrsOpt, forwardParams->stopCriteriaInputs->stopWordsPtr);
    safeUpdate<int32_t>(stopWordsLensOpt, forwardParams->stopCriteriaInputs->stopWordsLengths);
    forwardParams->stopCriteriaInputs->maxStopWordsLen = maxStopWordsLen;
    safeUpdate<int64_t>(badWordsListPtrsOpt, forwardParams->banWordsInputs->badWordsPtr);
    safeUpdate<int32_t>(badWordsLensOpt, forwardParams->banWordsInputs->badWordsLengths);
    forwardParams->banWordsInputs->maxBadWordsLen = maxBadWordsLen;
    safeUpdate<int32_t>(srcCacheIndirectionOpt, forwardParams->srcCacheIndirection);

    tr::ITensor::SharedPtr outputIdsConverted = convert_tensor<int32_t>(outputTokenIds);
    std::shared_ptr<tl::BaseDecodingOutputs> outputParams;
    if (isBeamSearch)
    {
        outputParams = std::make_shared<tl::BeamSearchOutputs>(outputIdsConverted);
    }
    else
    {
        outputParams = std::make_shared<tl::BaseDecodingOutputs>(outputIdsConverted);
    }

    outputParams->newTokens = convert_tensor<int32_t>(newTokens);

    safeUpdate<uint8_t>(finishedInput, forwardParams->finished);
    safeUpdate<uint8_t>(finishedOutput, outputParams->finished);
    safeUpdate<int32_t>(sequenceLengthsOpt, outputParams->sequenceLength);
    safeUpdate<float>(cumLogProbsOpt, outputParams->cumLogProbs);
    safeUpdate<float>(outputLogProbsOpt, outputParams->outputLogProbs);
    safeUpdate<float>(outputLogProbsTiledOpt, outputParams->outputLogProbsTiled);
    safeUpdate<int32_t>(parentIdsOpt, outputParams->parentIds);

    tr::SizeType32* finishedSumHost = nullptr;
    if (forwardParams->stopCriteriaInputs->sequenceLimitLength && outputParams->finished.has_value())
    {
        // Skip the initialization and later calculation if there is no limit of sequence length or no finished beam
        outputParams->finishedSum = mFinishedSum;
        finishedSumHost = tr::bufferCast<tr::SizeType32>(*mFinishedSum);
        for (int32_t bi = 0; bi < localBatchSize; ++bi)
        {
            finishedSumHost[bi] = 0;
        }
    }

    if (isBeamSearch)
    {
        auto outputsBeamSearch = std::dynamic_pointer_cast<tl::BeamSearchOutputs>(outputParams);
        TLLM_CHECK_WITH_INFO(tgtCacheIndirectionOpt.has_value(), "tgtCacheIndirection must be set for beam search");
        outputsBeamSearch->tgtCacheIndirection = convert_tensor<int32_t>(tgtCacheIndirectionOpt.value());
        if (useBeamHyps)
        {
            // Additional parameters for beam search
            outputsBeamSearch->beamHypotheses = std::make_unique<tk::BeamHypotheses>();
            safeUpdatePtr(beamHypsIsDoneOpt, outputsBeamSearch->beamHypotheses->batchDones);
            safeUpdatePtr(beamHypsCumLogProbsCbaOpt, outputsBeamSearch->beamHypotheses->cumLogProbsCBA);
            safeUpdatePtr(beamHypsLogProbsCbaOpt, outputsBeamSearch->beamHypotheses->logProbsCBA);
            safeUpdatePtr(beamHypsMinNormedScoresOpt, outputsBeamSearch->beamHypotheses->minNormedScoresCBA);
            safeUpdatePtr(beamHypsNormedScoresCbaOpt, outputsBeamSearch->beamHypotheses->normedScoresCBA);
            safeUpdatePtr(beamHypsNumBeamsOpt, outputsBeamSearch->beamHypotheses->numBeamsCBA);
            safeUpdatePtr(beamHypsOutputIdsCbaOpt, outputsBeamSearch->beamHypotheses->outputIdsCBA);
            safeUpdatePtr(beamHypsSeqLenCbaOpt, outputsBeamSearch->beamHypotheses->sequenceLengthsCBA);
        }
    }

    mDynamicDecodeLayer->forwardAsync(outputParams, forwardParams, mDecodingWorkspace);

    if (finishedSumHost)
    {
        // Wait for the decode step to complete, then count finished sequences on the host.
        TLLM_CUDA_CHECK(::cudaStreamSynchronize(mDynamicDecodeLayer->getStream()));
        uint32_t numRealFinished = 0;
        for (int32_t bi = 0; bi < localBatchSize; ++bi)
        {
            numRealFinished += finishedSumHost[bi];
        }
        auto const numToFinish = outputParams->finished.value()->getSize();
        auto shouldStopAccessor = shouldStop.accessor<bool, 1>();
        shouldStopAccessor[0] = numToFinish == numRealFinished;
    }
}

DynamicDecodeOp::DynamicDecodeOp(int64_t const maxBatchSize, int64_t const maxBeamWidth, int64_t const vocabSize,
    int64_t const vocabSizePadded, int64_t const tensorParaSize, int64_t const pipelineParaSize,
    at::ScalarType const scalarType)
    : maxBatchSize_(static_cast<size_t>(maxBatchSize))
    , maxBeamWidth_(static_cast<size_t>(maxBeamWidth))
    , vocabSize_(static_cast<size_t>(vocabSize))
    , vocabSizePadded_(static_cast<size_t>(vocabSizePadded))
    , tensorParaSize_(static_cast<int>(tensorParaSize))
    , pipelineParaSize_(static_cast<int>(pipelineParaSize))
    , scalarType_(scalarType)
{
    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    createInstance();
}

void DynamicDecodeOp::createInstance()
{
    dynamicDecode_.reset();
    switch (scalarType_)
    {
    case at::ScalarType::Float:
        dynamicDecode_ = std::make_unique<FtDynamicDecode<float>>(
            maxBatchSize_, maxBeamWidth_, vocabSize_, vocabSizePadded_, tensorParaSize_, pipelineParaSize_);
        break;
    case at::ScalarType::Half:
        dynamicDecode_ = std::make_unique<FtDynamicDecode<half>>(
            maxBatchSize_, maxBeamWidth_, vocabSize_, vocabSizePadded_, tensorParaSize_, pipelineParaSize_);
        break;
    default: throw std::runtime_error("Wrong tensor type.");
    }
}
// TODO: add parameters "normalizeLogProbs" and "topKMedusaHeads" CHECK_OPTIONAL_CPU_INPUT(runtimeTopKOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(runtimeTopPOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(temperatureOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(repetitionPenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(presencePenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(frequencyPenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(promptIgnoreLengthOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(minLengthOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(lengthPenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(earlyStoppingOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(beamSearchDiversityRateOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(randomSeedOpt, torch::kInt64); CHECK_OPTIONAL_INPUT(topPDecayOpt, torch::kFloat); CHECK_OPTIONAL_INPUT(topPMinOpt, torch::kFloat); CHECK_OPTIONAL_INPUT(topPResetIdsOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(noRepeatNgramSizeOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(minPOpt, torch::kFloat); dynamicDecode_->setup(static_cast(batchSize), static_cast(beamWidth), runtimeTopKOpt, runtimeTopPOpt, temperatureOpt, repetitionPenaltyOpt, presencePenaltyOpt, frequencyPenaltyOpt, promptIgnoreLengthOpt, minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt, randomSeedOpt, topPDecayOpt, topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs, cumLogProbs); } th::Tensor DynamicDecodeOp::forward( // Inputs BS: batchSize, BM: beamWidth, MSL: maxSeqLength, V: vocabSize, VP: vocabSizePadded th::Tensor const& logits, // [BS, BM, VP], T, variables for input int64_t const step, // int64_t const maxInputLength, // int64_t const maxAttentionWindow, // int64_t const sinkTokenLength, // int64_t const ite, // int64_t const localBatchSize, // th::Tensor const endId, // [BS*BM], int th::optional embeddingBiasOpt, // [VP], T th::optional inputLengthsOpt, // [BS*BM], int, length of input contexts th::optional sequenceLimitLengthOpt, // [BS, 1], int th::optional stopWordsListPtrsOpt, // [BS][2, stopWordsLength], int64 th::optional stopWordsLensOpt, // [BS], int int64_t const maxStopWordsLen, // th::optional badWordsListPtrsOpt, // [BS][2, badWordsLength], int64 th::optional badWordsLensOpt, // [BS], int int64_t const maxBadWordsLen, // th::optional srcCacheIndirectionOpt, // [localBS, BM, MSL], int // Outputs th::Tensor outputTokenIds, // [BS, BM, MSL], variables for output th::Tensor newTokens, // [BS, BM, 1], int th::optional finishedInput, // [BS, BM], uint8 th::optional finishedOutput, // [BS, BM], uint8 th::optional sequenceLengthsOpt, // [BS*BM], int, length of the current sequences th::optional cumLogProbsOpt, // [BS, BM], float th::optional outputLogProbsOpt, // [BS, BM, MSL], float th::optional outputLogProbsTiledOpt, // [MSL, BS, BM], float, transpose of outputLogProbsOpt th::optional parentIdsOpt, // [BS, BM, MSL], int th::optional tgtCacheIndirectionOpt, // [localBS, BM, MSL], int th::optional beamHypsOutputIdsCbaOpt, // [BS, BM*2, MSL], int th::optional beamHypsSeqLenCbaOpt, // [BS, BM*2], int th::optional beamHypsCumLogProbsCbaOpt, // [BS, BM*2], float th::optional beamHypsNormedScoresCbaOpt, // [BS, BM*2], float th::optional beamHypsLogProbsCbaOpt, // [BS, BM*2, MSL], float th::optional beamHypsMinNormedScoresOpt, // [BS], float th::optional beamHypsNumBeamsOpt, // [BS], int th::optional beamHypsIsDoneOpt, // [BS], bool bool const useBeamHyps // ) { CHECK_INPUT(logits, scalarType_); TLLM_CHECK_WITH_INFO(logits.dim() == 
3, "logits is of shape (batchSize, beamWidth, vocabSizePadded), but got dim=%d shape=%s", (int) logits.dim(), tensorrt_llm::runtime::ITensor::toString(convert_shape(logits)).c_str()); TLLM_CHECK_WITH_INFO(static_cast(logits.size(2)) == vocabSizePadded_, "logits is of shape (batchSize, beamWidth, vocabSize(%ld)), but got the last dim=%ld.", vocabSizePadded_, static_cast(logits.size(2))); CHECK_INPUT(endId, torch::kInt32); CHECK_OPTIONAL_INPUT(embeddingBiasOpt, scalarType_); CHECK_OPTIONAL_INPUT(inputLengthsOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(sequenceLimitLengthOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(stopWordsListPtrsOpt, torch::kInt64); CHECK_OPTIONAL_INPUT(stopWordsLensOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(badWordsListPtrsOpt, torch::kInt64); CHECK_OPTIONAL_INPUT(badWordsLensOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(srcCacheIndirectionOpt, torch::kInt32); CHECK_INPUT(outputTokenIds, torch::kInt32); CHECK_INPUT(newTokens, torch::kInt32); CHECK_OPTIONAL_INPUT(finishedInput, torch::kUInt8); CHECK_OPTIONAL_INPUT(finishedOutput, torch::kUInt8); CHECK_OPTIONAL_INPUT(sequenceLengthsOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(cumLogProbsOpt, torch::kFloat32); CHECK_OPTIONAL_INPUT(outputLogProbsOpt, torch::kFloat32); CHECK_OPTIONAL_INPUT(outputLogProbsTiledOpt, torch::kFloat32); CHECK_OPTIONAL_INPUT(parentIdsOpt, torch::kInt32); CHECK_OPTIONAL_INPUT(tgtCacheIndirectionOpt, torch::kInt32); th::Tensor shouldStop = torch::zeros({1}, torch::dtype(torch::kBool).requires_grad(false)); dynamicDecode_->forward( // Inputs logits, static_cast(step), static_cast(maxInputLength), static_cast(maxAttentionWindow), static_cast(sinkTokenLength), static_cast(ite), static_cast(localBatchSize), endId, embeddingBiasOpt, inputLengthsOpt, sequenceLimitLengthOpt, stopWordsListPtrsOpt, stopWordsLensOpt, static_cast(maxStopWordsLen), badWordsListPtrsOpt, badWordsLensOpt, static_cast(maxBadWordsLen), srcCacheIndirectionOpt, // Outputs outputTokenIds, newTokens, shouldStop, finishedInput, finishedOutput, sequenceLengthsOpt, cumLogProbsOpt, outputLogProbsOpt, outputLogProbsTiledOpt, parentIdsOpt, tgtCacheIndirectionOpt, beamHypsOutputIdsCbaOpt, beamHypsSeqLenCbaOpt, beamHypsCumLogProbsCbaOpt, beamHypsNormedScoresCbaOpt, beamHypsLogProbsCbaOpt, beamHypsMinNormedScoresOpt, beamHypsNumBeamsOpt, beamHypsIsDoneOpt, useBeamHyps); return shouldStop; } } // namespace torch_ext static auto trtllmGptContextDecoderTHS = torch::jit::class_("trtllm", "DynamicDecodeOp") .def(torch::jit::init()) .def("setup", &torch_ext::DynamicDecodeOp::setup) .def("forward", &torch_ext::DynamicDecodeOp::forward);