Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "gptKernels.h"
|
|
#include "tensorrt_llm/kernels/beamSearchKernels.h"
|
|
#include "tensorrt_llm/kernels/decodingCommon.h"
|
|
#include "tensorrt_llm/runtime/common.h"
|
|
#include <cuda_fp16.h>
|
|
#include <cuda_runtime.h>
|
|
#include <curand_kernel.h>
|
|
|
|
namespace tensorrt_llm
{

namespace kernels
{

struct gatherTreeParam
{
    // TODO rename the parameters
    int32_t* beams = nullptr;           // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
    int32_t* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
    int32_t maxSequenceLengthFinalStep = 0;
    int32_t const* inputLengths = nullptr; // [batchSize, beamWidth]
    // response input lengths (used to slice the ids during postprocessing)
    int32_t* responseInputLengths = nullptr;
    int32_t maxSeqLen = 0;
    int32_t batchSize = 0;
    int32_t beamWidth = 0;
    int32_t const* stepIds = nullptr;   // [maxSeqLen, batchSize, beamWidth]
    int32_t const* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
    int32_t const* endTokens = nullptr; // [batchSize], end token ids of each query
    int32_t* outputIds = nullptr;       // the buffer to put finalized ids
    cudaStream_t stream;                // CUDA stream used to launch the kernels
    float* cumLogProbs = nullptr;       // [batchSize, beamWidth]
    float lengthPenalty = 1.0f;         // length penalty applied when ranking finished beams
    int earlyStopping = 1;              // beam search early-stopping mode
};

/*
Do gatherTree on beam search to get final result.
*/
void invokeGatherTree(gatherTreeParam param);
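
// A minimal usage sketch (illustrative, not part of the original header). The device buffers below
// (workspaceIds, seqLens, inputLens, stepIdsBuf, parentIdsBuf, endIdsBuf, finalIdsBuf) and the
// scalar sizes are assumed to exist on the caller's side with the shapes documented in
// gatherTreeParam; only the field names and the call itself come from this header.
//
//   gatherTreeParam param;
//   param.beams = workspaceIds;        // [batchSize, beamWidth, maxSeqLen] workspace
//   param.sequenceLengths = seqLens;   // [batchSize, beamWidth]
//   param.inputLengths = inputLens;    // [batchSize, beamWidth]
//   param.stepIds = stepIdsBuf;        // [maxSeqLen, batchSize, beamWidth]
//   param.parentIds = parentIdsBuf;    // [maxSeqLen, batchSize, beamWidth]
//   param.endTokens = endIdsBuf;       // [batchSize]
//   param.outputIds = finalIdsBuf;     // destination for the finalized ids
//   param.maxSeqLen = maxSeqLen;
//   param.batchSize = batchSize;
//   param.beamWidth = beamWidth;
//   param.stream = stream;
//   invokeGatherTree(param);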

// Inserts the still-unfinished beams into the BeamHypotheses buffers so they are considered
// when the search is finalized.
void invokeInsertUnfinishedPath(BeamHypotheses& bh, cudaStream_t stream);

// Selects the best hypotheses per request and writes the finalized output ids, sequence lengths
// and log probs into the corresponding BeamHypotheses buffers.
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream);

// Fills finalOutputIds ([batchBeam, maxSeqLen]) with the corresponding end token from endIds.
void invokeInitializeOutput(runtime::TokenIdType* finalOutputIds, runtime::TokenIdType const* endIds,
    runtime::SizeType32 batchBeam, runtime::SizeType32 maxSeqLen, cudaStream_t stream);
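
// Illustrative call (a sketch; finalOutputIds, endIds, batchSize, beamWidth, maxSeqLen and stream
// are hypothetical caller-side names, not part of this header):
//
//   invokeInitializeOutput(finalOutputIds, endIds, batchSize * beamWidth, maxSeqLen, stream);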

//! \brief Copies last numNewTokens (or 1 if numNewTokens == nullptr) tokens from outputIdsPtr
//! to nextStepIds according to sequenceLengths.
//!
//! \param nextStepIds output buffer [maxTokensPerStep, maxBatchSize, maxBeamWidth],
//! destination of the new tokens.
//! \param outputIdsPtr input buffer [maxBatchSize][maxBeamWidth, maxSeqLen],
//! array of pointers to the source of the copy.
//! \param sequenceLengths input buffer [maxBatchSize], sequence length of the request
//! in outputIdsPtr that includes all new tokens. It must be guaranteed that sequenceLengths <= maxSeqLen.
//! \param numNewTokens input buffer [maxBatchSize], optional, number of tokens to be copied.
//! If nullptr, only 1 token is copied. It must be guaranteed that numNewTokens <= sequenceLengths.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth current beam width
//! \param maxSeqLen maximum sequence length
//! \param maxTokensPerStep maximum tokens per step
//! \param stream stream
void invokeCopyNextStepIds(runtime::TokenIdType* nextStepIds, runtime::TokenIdType const* const* outputIdsPtr,
    runtime::SizeType32 const* sequenceLengths, runtime::SizeType32 const* numNewTokens,
    runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
    runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxTokensPerStep,
    cudaStream_t stream);
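
// Illustrative call (a sketch, not from the original header). The buffers and sizes are assumed to
// follow the shapes documented above; numNewTokens is passed as nullptr, so only the single latest
// token of each sequence is copied.
//
//   invokeCopyNextStepIds(nextStepIds, outputIdsPtrs, sequenceLengths, /*numNewTokens=*/nullptr,
//       batchSlots, batchSize, maxBatchSize, beamWidth, maxSeqLen, maxTokensPerStep, stream);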

// Transposes the per-step (tiled) log probs in output_log_probs_tiled into the per-request layout
// expected in output_log_probs.
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled,
    runtime::SizeType32 const* sequence_lengths, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batch_size,
    runtime::SizeType32 max_batch_size, runtime::SizeType32 beam_width, runtime::SizeType32 max_seq_len,
    cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm