/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "gptKernels.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/common.h"
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

namespace tensorrt_llm
{

namespace kernels
{

//! \brief Bundle of buffers and scalars consumed by invokeGatherTree.
//! All members carry default initializers so a default-constructed instance is fully defined.
struct gatherTreeParam
{
    // TODO rename the parameters
    int32_t* beams = nullptr;           // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
    int32_t* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
    int32_t maxSequenceLengthFinalStep = 0;
    int32_t const* inputLengths = nullptr; // [batchSize, beamWidth]
    // response input lengths (used to slice the ids during postprocessing)
    int32_t* responseInputLengths = nullptr;
    int32_t maxSeqLen = 0;
    int32_t batchSize = 0;
    int32_t beamWidth = 0;
    int32_t const* stepIds = nullptr;   // [maxSeqLen, batchSize, beamWidth]
    int32_t const* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
    int32_t const* endTokens = nullptr; // [batchSize], end token ids of each query
    int32_t* outputIds = nullptr;       // the buffer to put finalized ids
    // CUDA stream. Initialized to nullptr (the default stream) so that an unset field is
    // well-defined instead of holding an indeterminate handle.
    cudaStream_t stream = nullptr;
    float* cumLogProbs = nullptr; // [batchSize, beamWidth]
    float lengthPenalty = 1.0f;
    int earlyStopping = 1;
};

//! \brief Performs gatherTree on beam search results to produce the final output ids.
//! All buffers, sizes and the stream are taken from \p param (see gatherTreeParam).
void invokeGatherTree(gatherTreeParam param);

//! NOTE(review): undocumented upstream; operates on the beam hypotheses \p bh on \p stream.
//! Confirm exact semantics against the implementation.
void invokeInsertUnfinishedPath(BeamHypotheses& bh, cudaStream_t stream);

//! NOTE(review): undocumented upstream; operates on the beam hypotheses \p bh on \p stream.
//! Confirm exact semantics against the implementation.
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream);

//! NOTE(review): undocumented upstream; judging by the signature it initializes finalOutputIds
//! (presumably batchBeam x maxSeqLen) from endIds. Confirm against the implementation.
void invokeInitializeOutput(runtime::TokenIdType* finalOutputIds, runtime::TokenIdType const* endIds,
    runtime::SizeType32 batchBeam, runtime::SizeType32 maxSeqLen, cudaStream_t stream);

//! \brief Copies last numNewTokens (or 1 if numNewTokens == nullptr) tokens from outputIdsPtr
//! to nextStepIds according to sequenceLengths.
//!
//! \param nextStepIds output buffer [maxTokensPerStep, maxBatchSize, maxBeamWidth],
//! destination of the new tokens.
//! \param outputIdsPtr input buffer [maxBatchSize][maxBeamWidth, maxSeqLen],
//! array of pointers to the source of the copy.
//! \param sequenceLengths input buffer [maxBatchSize], sequence length of the request
//! in outputIdsPtr that includes all new tokens. It must be guaranteed that sequenceLengths <= maxSeqLen.
//! \param numNewTokens input buffer [maxBatchSize], optional, number of tokens to be copied.
//! If nullptr, only 1 token is copied. It must be guaranteed that numNewTokens <= sequenceLengths.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth current beam width
//! \param maxSeqLen maximum sequence length
//! \param maxTokensPerStep maximum tokens per step
//! \param stream stream
void invokeCopyNextStepIds(runtime::TokenIdType* nextStepIds, runtime::TokenIdType const* const* outputIdsPtr,
    runtime::SizeType32 const* sequenceLengths, runtime::SizeType32 const* numNewTokens,
    runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
    runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxTokensPerStep,
    cudaStream_t stream);

//! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens
//! for speculative decoding. Target token is accepted if targetToken == draftToken.
//! If number of accepted tokens N < maxDraftTokens, then function accepts N + 1 tokens of target model.
//! sequenceLengths, finishedSum and finishedFinal are modified accordingly.
//!
//! \param draftIds input buffer [batchSize, maxDraftTokens].
//! Indices of the draft tokens.
//! \param targetIds input buffer [batchSize, maxSeqLen]. Indices of the tokens decoded by the target model
//! \param contextLengths input buffer [batchSize]. Context lengths of the requests without draft tokens
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
//! \param sequenceLengths input/output buffer [batchSize] sequence lengths of the requests in batch
//! Modified in-place according to the accepted/rejected tokens
//! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration
//! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens
//! \param finishedSum output buffer [1] total number of requests in batch that finished the execution
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param maxSeqLen maximum sequence length
//! \param maxDraftTokens maximum number of draft tokens
//! \param stream stream
void invokeAcceptDraftTokensByIds(runtime::TokenIdType const* draftIds, runtime::TokenIdType const* targetIds,
    runtime::SizeType32 const* contextLengths, runtime::SizeType32 const* numsDraftTokens,
    runtime::SizeType32* sequenceLengths, FinishedState const* finished, FinishedState* finishedFinal,
    runtime::SizeType32* finishedSum, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
    runtime::SizeType32 maxBatchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen,
    runtime::SizeType32 maxDraftTokens, cudaStream_t stream);

//! \brief Performs probabilistic acceptance of draft tokens based on their probability distributions.
//! Corrects targetLogits for the next to the last accepted token
//! according to https://openreview.net/pdf?id=C9NEblP8vS
//!
//! \tparam T logits/probabilities element type
//! \param draftLogits input/output buffer [draftTokens, batchSize, beamWidth, vocabSize].
//! Initially contains token logits of the draft model.
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
//! Vector of pointers to the logits.
//! Initially contains token logits of the target model.
//! It is modified in-place for next to the last accepted token such as
//! P'(x) = norm(max(0, P_{N+1}(x) - Q_{N+1}(x))), where N < maxDraftTokens is the number of accepted tokens.
//! \param draftProbs output buffer [maxDraftTokens, batchSize, beamWidth, vocabSize].
//! Workspace buffer for token probabilities of the draft model.
//! \param targetProbs output buffer [maxDraftTokens+1, batchSize, beamWidth, vocabSize].
//! Workspace buffer for token probabilities of the target model.
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
//! \param finished output buffer [draftTokens, batchSize, beamWidth].
//! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted
//! \param curandState input buffer [batchSize]. Curand states properly
//! initialized using invokeCurandInitialize per request.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
//! \param maxDraftTokens maximum number of draft tokens
//! \param randomThreshold True if use uniformly sampled threshold for token acceptance
//! \param constantThreshold threshold used to accept tokens if randomThreshold is false
//! \param stream stream
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
    runtime::SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
    runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
    runtime::SizeType32 beamWidth, runtime::SizeType32 vocabSize, runtime::SizeType32 vocabSizePadded,
    runtime::SizeType32 maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);

//! NOTE(review): undocumented upstream; judging by the names it transposes output_log_probs_tiled into
//! output_log_probs using sequence_lengths and batchSlots. Confirm the buffer layouts against the implementation.
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled,
    runtime::SizeType32 const* sequence_lengths, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batch_size,
    runtime::SizeType32 max_batch_size, runtime::SizeType32 beam_width, runtime::SizeType32 max_seq_len,
    cudaStream_t stream);

//! \brief verifies draft medusa tokens given target tokens. Modifies outputIds tensor accordingly filling it with
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
//! to the next after the last accepted token.
//!
//! \tparam T medusa logits element type
//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
//! \param draftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens
//! \param targetIds input buffer [maxBatchSize, maxDraftTokens], tokens predicted from the target medusa head
//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
//! Incremented according to the accepted length
//! \param acceptedLengths output buffer [maxBatchSize], number of accepted tokens
//! \param finishedFinal input buffer [maxBatchSize], finished states per request
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param paths input buffer [maxBatchSize, maxTokensPerStep, maxNumHeads+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
//! \param endIds input buffer [maxBatchSize], EOS ids per request
//! \param medusaLogits input buffer [maxNumHeads, maxBatchSize, maxTokensPerStep, vocabSize], pointer
//! to the logits from medusa heads
//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
//! respective rows of the medusaLogits for the next after the accepted token
//! \param curTokensPerStep current tokens to compute per step will be updated to
//! targetTokensPerStep if curTokensPerStep == 1
//! \param targetTokensPerStep target values of tokens to compute per step
//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param vocabSize vocab size
//! \param maxDraftTokens maximum sequence length of the sequence containing draft tokens
//! \param maxSeqLen maximum sequence length of output ids
//! \param maxNumHeads maximum number of medusa heads
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param stream stream
template <typename T>
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
    runtime::TokenIdType const* targetIds, runtime::SizeType32* sequenceLengths, runtime::SizeType32* acceptedLengths,
    FinishedState* finishedFinal, runtime::SizeType32 const* batchSlots, runtime::SizeType32 const* paths,
    runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
    runtime::SizeType32* curTokensPerStep, runtime::SizeType32 const* targetTokensPerStep,
    runtime::SizeType32* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
    runtime::SizeType32 vocabSize, runtime::SizeType32 maxDraftTokens, runtime::SizeType32 maxSeqLen,
    runtime::SizeType32 maxNumHeads, runtime::SizeType32 maxTokensPerStep, cudaStream_t stream);

//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
//!
//! \param treeDraftIds output buffer [maxBatchSize, maxDraftTokens], output draft tokens
//! scattered from sourceDraftIds according to treeIds
//! \param sourceDraftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens saved linearly after
//! sampling from Medusa heads with TopK.
//! \param treeIds input buffer [maxBatchSize, maxDraftTokens], address map from sourceDraftIds to treeDraftIds
//! [0, uniqueDraftTokens] -> [0, maxDraftTokens], where uniqueDraftTokens = sum(MedusaHeadsTopK)
//! uniqueDraftTokens <= maxDraftTokens
//! \param tokensPerStep input buffer [maxBatchSize], number of output draft tokens
//! \param batchSlots input buffer [maxBatchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param maxDraftTokens maximum number of tokens per step configured in the system
//! \param batchSize current batch size
//! \param stream cuda stream
void scatterMedusaDraftTokens(runtime::TokenIdType* treeDraftIds, runtime::TokenIdType const* sourceDraftIds,
    runtime::SizeType32 const* treeIds, runtime::SizeType32 const* tokensPerStep, runtime::SizeType32 const* batchSlots,
    runtime::SizeType32 maxDraftTokens, runtime::SizeType32 batchSize, cudaStream_t stream);

//! \brief Linearly packs accepted paths in memory according to the acceptedLengths and bestPathIds
//!
//! \param acceptedLengthsCumSum output buffer [maxBatchSize + 1], exclusive sum of accepted lengths
//! (indexed linearly in memory).
//! \param pathsOffsets output buffer [maxBatchSize * maxNumDraftTokens], slices of accepted paths packed in memory
//! \param acceptedLengths input buffer [maxBatchSize], length of the data accepted tokens
//! \param bestPathIds input buffer [maxBatchSize], indices of the selected paths
//! \param paths input buffer [maxBatchSize, maxTokensPerStep, maxNumHeads+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param maxNumDraftTokens maximum sequence length of the sequence containing draft tokens
//! \param stream stream
void invokePackAcceptedPaths(runtime::SizeType32* acceptedLengthsCumSum, runtime::SizeType32* pathsOffsets,
    runtime::SizeType32 const* acceptedLengths, runtime::SizeType32 const* bestPathIds,
    runtime::SizeType32 const* paths, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
    runtime::SizeType32 maxTokensPerStep, runtime::SizeType32 maxNumDraftTokens, cudaStream_t stream);

} // namespace kernels

} // namespace tensorrt_llm