Update TensorRT-LLM (#2460)

Kaiyu Xie 2024-11-19 18:30:34 +08:00 committed by GitHub
parent c629546ce4
commit 535c9cc673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
244 changed files with 6967 additions and 3534 deletions

View File

@ -8,7 +8,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.6.2-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.6.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.15.0.dev-green)](./tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-0.16.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture/overview.md)   |   [Results](./docs/source/performance/perf-overview.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/)   |   [Roadmap](https://docs.google.com/presentation/d/1gycPmtdh7uUcH6laOvW65Dbp9F1McUkGDIcAyjicBZs/edit?usp=sharing)
@ -18,12 +18,18 @@ TensorRT-LLM
## Latest News
* [2024/11/09] 🚀🚀🚀 3x Faster AllReduce with NVSwitch and TensorRT-LLM MultiShot
[➡️ link](https://developer.nvidia.com/blog/3x-faster-allreduce-with-nvswitch-and-tensorrt-llm-multishot/)
<div align="center">
<img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/08/HGX-H200-tech-blog-1920x1080-1.jpg" width="50%">
<div align="left">
* [2024/11/09] ✨ NVIDIA advances the AI ecosystem with the AI model of LG AI Research 🙌
[➡️ link](https://blogs.nvidia.co.kr/blog/nvidia-lg-ai-research/)
* [2024/11/02] 🌟🌟🌟 NVIDIA and LlamaIndex Developer Contest
🙌 Enter for a chance to win prizes including an NVIDIA® GeForce RTX™ 4080 SUPER GPU, DLI credits, and more🙌
[➡️ link](https://developer.nvidia.com/llamaindex-developer-contest)
<div align="center">
<img src="docs/source/media/image-11-02-2024.png" width="50%">
<div align="left">
* [2024/10/28] 🏎️🏎️🏎️ NVIDIA GH200 Superchip Accelerates Inference by 2x in Multiturn Interactions with Llama Models
[➡️ link](https://developer.nvidia.com/blog/nvidia-gh200-superchip-accelerates-inference-by-2x-in-multiturn-interactions-with-llama-models/)

View File

@ -664,7 +664,7 @@ class ExecutorServer
{
public:
ExecutorServer(std::optional<std::filesystem::path> const& decoderTrtEnginePath,
std::optional<std::filesystem::path> const& encoderTrtEnginePath, TrtGptModelType modelType,
std::optional<std::filesystem::path> const& encoderTrtEnginePath, texec::BatchingType batchingType,
int32_t maxBeamWidth, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
BenchmarkParams const& benchmarkParams, std::shared_ptr<Recorder> recorder, std::chrono::milliseconds waitSleep,
bool logIterationData, texec::ModelType executorModelType)
@ -692,8 +692,7 @@ public:
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
executorConfig.setPeftCacheConfig(peftCacheConfig);
executorConfig.setBatchingType(
modelType == TrtGptModelType::V1 ? texec::BatchingType::kSTATIC : texec::BatchingType::kINFLIGHT);
executorConfig.setBatchingType(batchingType);
if (benchmarkParams.maxBatchSize)
{
executorConfig.setMaxBatchSize(benchmarkParams.maxBatchSize.value());
@ -947,6 +946,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::nullopt, // embeddingBias
std::nullopt, // speculativeDecoding
std::nullopt, // pTuning
std::nullopt, // mRopeConfig
loraConfig, // loraConfig
lookaheadConfig, // lookaheadConfig
std::nullopt, // kvCacheRetentionConfig
@ -955,7 +955,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
}
void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
std::optional<std::filesystem::path> const& encoderEngineDir, TrtGptModelType modelType,
std::optional<std::filesystem::path> const& encoderEngineDir, texec::BatchingType batchingType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
texec::CapacitySchedulerPolicy capacitySchedulerPolicy, std::chrono::milliseconds waitSleep,
@ -977,16 +977,17 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
TLLM_CHECK_WITH_INFO(
decoderEngineDir.has_value(), "decoder models require a path to decoder engine in executor benchmark.");
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
executorServer
= std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, batchingType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
}
else if (executorModelType == texec::ModelType::kENCODER_DECODER)
{
TLLM_CHECK_WITH_INFO(encoderEngineDir.has_value(),
"encoder-decoder models require a path to encoder engine in executor benchmark.");
executorServer
= std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(),
batchingType, beamWidth, capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData,
executorModelType);
try
{
std::ifstream decoderJsonConfigPath(decoderEngineDir.value() / "config.json");
@ -1011,8 +1012,9 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
TLLM_CHECK_WITH_INFO(
encoderEngineDir.has_value(), "encoder models require a path to encoder engine in executor benchmark.");
executorServer = std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
executorServer
= std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), batchingType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
}
else
{
@ -1219,8 +1221,9 @@ int main(int argc, char* argv[])
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
options.add_options()(
"api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("executor"));
options.add_options()("type", "Batching type: IFB, UIFB (unfused IFB) or V1 (non-IFB) batching.",
cxxopts::value<std::string>()->default_value("IFB"));
options.add_options()("type",
"Batching type: choose between inflight/static. (IFB/V1 options are going to be deprecated)",
cxxopts::value<std::string>()->default_value("inflight"));
options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
cxxopts::value<std::string>()->default_value(""));
options.add_options()(
@ -1332,18 +1335,22 @@ int main(int argc, char* argv[])
// Argument: Batching Type
auto const type = result["type"].as<std::string>();
TrtGptModelType modelType{TrtGptModelType::V1};
if (type == "V1")
texec::BatchingType batchingType{texec::BatchingType::kINFLIGHT};
if (type == "V1" || type == "static")
{
modelType = TrtGptModelType::V1;
if (type == "V1")
{
TLLM_LOG_WARNING("type option \"V1\" is going to be renamed to \"static\".");
}
batchingType = texec::BatchingType::kSTATIC;
}
else if (type == "UIFB")
else if (type == "IFB" || type == "inflight")
{
modelType = TrtGptModelType::InflightBatching;
}
else if (type == "IFB")
{
modelType = TrtGptModelType::InflightFusedBatching;
if (type == "IFB")
{
TLLM_LOG_WARNING("type option \"IFB\" is going to be renamed to \"inflight\".");
}
batchingType = texec::BatchingType::kINFLIGHT;
}
else
{
@ -1604,7 +1611,7 @@ int main(int argc, char* argv[])
{
TLLM_CHECK_WITH_INFO(api == "executor", "encoder-decoder only support executor api.");
TLLM_CHECK_WITH_INFO(
modelType == TrtGptModelType::InflightFusedBatching, "encoder-decoder only support inflight batching.");
batchingType == texec::BatchingType::kINFLIGHT, "encoder-decoder only support inflight batching.");
executorModelType = texec::ModelType::kENCODER_DECODER;
encoderEngineDir = result["encoder_engine_dir"].as<std::string>();
decoderEngineDir = result["decoder_engine_dir"].as<std::string>();
@ -1621,7 +1628,7 @@ int main(int argc, char* argv[])
}
try
{
benchmarkExecutor(decoderEngineDir, encoderEngineDir, modelType, datasetPath, opCsvFile, maxNumSamples,
benchmarkExecutor(decoderEngineDir, encoderEngineDir, batchingType, datasetPath, opCsvFile, maxNumSamples,
beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, capacitySchedulerPolicy,
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData,
maxPromptLen, executorModelType);
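
With this change the benchmark passes a texec::BatchingType straight through to the executor instead of a TrtGptModelType, and the --type flag prefers the spellings "inflight" and "static" while keeping "IFB"/"V1" as deprecated aliases. A condensed sketch of the new mapping follows; the header path for BatchingType is an assumption:

#include <stdexcept>
#include <string>

#include "tensorrt_llm/executor/types.h" // assumed location of texec::BatchingType

namespace texec = tensorrt_llm::executor;

// Condensed sketch of the new --type handling in the benchmark:
// "static"/"V1" select static batching, "inflight"/"IFB" select in-flight batching.
texec::BatchingType parseBatchingType(std::string const& type)
{
    if (type == "static" || type == "V1")
    {
        return texec::BatchingType::kSTATIC;
    }
    if (type == "inflight" || type == "IFB")
    {
        return texec::BatchingType::kINFLIGHT;
    }
    throw std::invalid_argument("Unexpected batching type: " + type);
}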

View File

@ -50,7 +50,7 @@ def print_dataset(input_ids, output_lens):
for i, input_tokens in enumerate(input_ids):
d = {
"task_id": i,
"logits": input_tokens,
"input_ids": input_tokens,
"output_tokens": output_lens[i]
}
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))

View File

@ -85,7 +85,7 @@ public:
std::optional<SizeType32> sinkTokenLength;
std::optional<float> freeGpuMemoryFraction;
bool enableBlockReuse;
static constexpr auto kDefaultGpuMemFraction = 0.9f;
static constexpr auto kDefaultGpuMemFraction = 0.9F;
bool useUvm;
std::optional<size_t> hostCacheSize;
bool onboardBlocks;

View File

@ -835,7 +835,7 @@ public:
* 2 * modelConfig.getSizePerHead();
}
[[nodiscard]] static std::tuple<SizeType32, SizeType32> const calculateMaxNumBlocks(KvCacheConfig const& config,
[[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks(KvCacheConfig const& config,
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);

View File

@ -92,6 +92,8 @@ public:
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@ -131,6 +133,8 @@ public:
, mPositionIds(std::move(positionIds))
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
, mPromptVocabSize(promptVocabSize)
, mMropeRotarySinCos(std::move(mropeRotarySinCos))
, mMropePositionDeltas(std::move(mropePositionDeltas))
, mLoraTaskId(loraTaskId)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
@ -188,6 +192,8 @@ public:
, mPositionIds(std::nullopt)
, mPromptEmbeddingTable(std::nullopt)
, mPromptVocabSize(std::nullopt)
, mMropeRotarySinCos(std::nullopt)
, mMropePositionDeltas(std::nullopt)
, mLoraTaskId(std::nullopt)
, mLoraWeights(std::nullopt)
, mLoraConfig(std::nullopt)
@ -285,6 +291,12 @@ public:
= std::make_shared<VecTokenExtraIds>(pTuningConfig->getInputTokenExtraIds().value());
}
}
auto mRopeConfig = req.getMropeConfig();
if (mRopeConfig)
{
mMropeRotarySinCos = executor::detail::toITensor(mRopeConfig.value().getMRopeRotarySinCos());
mMropePositionDeltas = mRopeConfig.value().getMRopePositionDeltas();
}
auto loraConfig = req.getLoraConfig();
if (loraConfig)
@ -447,16 +459,6 @@ public:
mContextPhaseParams = std::move(contextPhaseParams);
}
[[nodiscard]] bool isLayerWiseKvCacheEnabled() const
{
return isContextOnlyRequest() && mLayerWiseKvCacheEnabled;
}
void setLayerWiseKvCacheEnabled(bool enabled)
{
mLayerWiseKvCacheEnabled = enabled;
}
/// @brief Get the state params of the context
/// @return The state params of the context
[[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
@ -798,6 +800,16 @@ public:
return mPromptVocabSize;
}
[[nodiscard]] std::optional<TensorPtr> getMropeRotarySinCos() const
{
return mMropeRotarySinCos;
}
[[nodiscard]] std::optional<SizeType32> getMropePositionDeltas() const
{
return mMropePositionDeltas;
}
[[nodiscard]] std::optional<LoraTaskIdType> getLoraTaskId() const
{
return mLoraTaskId;
@ -1604,6 +1616,8 @@ protected:
std::optional<TensorPtr> mPromptEmbeddingTable;
std::optional<SizeType32> mPromptVocabSize;
std::optional<TensorPtr> mMropeRotarySinCos;
std::optional<SizeType32> mMropePositionDeltas;
std::optional<LoraTaskIdType> mLoraTaskId;
std::optional<TensorPtr> mLoraWeights;
@ -1654,7 +1668,6 @@ protected:
std::optional<TensorPtr> mCrossAttentionMask; // Input cross attention mask
LlmRequestType mLlmRequestType;
std::optional<executor::ContextPhaseParams> mContextPhaseParams;
bool mLayerWiseKvCacheEnabled = false;
std::optional<std::shared_ptr<VecTokenExtraIds>> mInputTokenExtraIds;
BeamUniqueTokens mUniqueTokens;
@ -1819,6 +1832,8 @@ public:
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@ -1840,7 +1855,8 @@ public:
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotarySinCos),
std::move(mropePositionDeltas), loraTaskId, std::move(loraWeights), std::move(loraConfig),
std::move(lookaheadConfig), std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits,
returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput,
std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, std::move(encoderInputTokens),
@ -1857,6 +1873,8 @@ public:
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@ -1879,7 +1897,8 @@ public:
std::move(stopWordsList),
positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotarySinCos),
std::move(mropePositionDeltas), loraTaskId, std::move(loraWeights), std::move(loraConfig),
std::move(lookaheadConfig), std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits,
returnGenerationLogits,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))

View File

@ -35,11 +35,11 @@
#include <mpi.h>
#else
// Dummy defines to avoid #if in wider places.
typedef int MPI_Datatype;
typedef int MPI_Comm;
typedef int MPI_Request;
typedef int MPI_Message;
typedef int MPI_Op;
typedef void* MPI_Datatype;
typedef void* MPI_Comm;
typedef void* MPI_Request;
typedef void* MPI_Message;
typedef void* MPI_Op;
typedef struct MPI_Status
{

View File

@ -250,6 +250,24 @@ private:
std::optional<VecTokenExtraIds> mInputTokenExtraIds;
};
/// @brief Configuration for mrope
class MropeConfig
{
public:
explicit MropeConfig(Tensor mropeRotarySinCos, SizeType32 mropePositionDeltas);
[[nodiscard]] Tensor getMRopeRotarySinCos() const;
[[nodiscard]] SizeType32 getMRopePositionDeltas() const;
private:
friend class Serialization;
/// @brief The mrope rotary sin and cos cache. Expected shape: [maxPositionEmbeddings * rotaryEmbeddingDim]. Data type
/// must be float32
Tensor mMRopeRotarySinCos;
/// @brief The mrope position deltas
SizeType32 mMRopePositionDeltas;
};
/// @brief Configuration for LoRA
class LoraConfig
{
@ -330,9 +348,10 @@ public:
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state);
ContextPhaseParams(ContextPhaseParams const&);
ContextPhaseParams(ContextPhaseParams&&);
ContextPhaseParams(ContextPhaseParams&&) noexcept;
ContextPhaseParams& operator=(ContextPhaseParams const&);
ContextPhaseParams& operator=(ContextPhaseParams&&);
ContextPhaseParams& operator=(ContextPhaseParams&&) noexcept;
~ContextPhaseParams();
[[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
@ -511,7 +530,7 @@ public:
std::optional<Tensor> embeddingBias = std::nullopt,
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<MropeConfig> mRopeConfig = std::nullopt, std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
std::optional<std::string> logitsPostProcessorName = std::nullopt,
@ -548,6 +567,7 @@ public:
[[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
[[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
[[nodiscard]] std::optional<MropeConfig> getMropeConfig() const;
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadConfig() const;
[[nodiscard]] std::optional<KvCacheRetentionConfig> getKvCacheRetentionConfig() const;
@ -576,6 +596,7 @@ public:
void setEmbeddingBias(Tensor const& embeddingBias);
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
void setMropeConfig(MropeConfig const& mRopeConfig);
void setLoraConfig(LoraConfig const& loraConfig);
void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig);
void setKvCacheRetentionConfig(KvCacheRetentionConfig const& kvCacheRetentionConfig);
@ -648,7 +669,10 @@ struct Result
/// @brief The params of the context phase.
std::optional<ContextPhaseParams> contextPhaseParams;
/// @brief The decoding iterations it takes.
/// @brief The number of decoding iterations used to generate the result.
/// In autoregressive decoding, it is equal to the maximum length of the beam in outputTokenIds.
/// In speculative decoding, it might be less than the maximum length of the beam in outputTokenIds, as more
/// than one token can be generated per iteration. Used for speculative decoding statistics.
SizeType32 decodingIter{0};
/// @brief The index of the output sequence of this result where 0 <= sequenceIndex < numReturnSequences.
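
The new MropeConfig above carries per-request multimodal rotary embedding (mrope) data. A minimal sketch of attaching it to an executor Request follows; it assumes the two-argument Request constructor (input token ids and maximum number of new tokens) and leaves the model-specific computation of the sin/cos cache to the caller:

#include <utility>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Sketch: attach mrope data to a request. The rotary sin/cos cache is a float32
// tensor of shape [maxPositionEmbeddings * rotaryEmbeddingDim] (see MropeConfig above).
texec::Request makeMropeRequest(texec::VecTokens inputTokens, texec::SizeType32 maxNewTokens,
    texec::Tensor rotarySinCos, texec::SizeType32 positionDeltas)
{
    texec::Request request(std::move(inputTokens), maxNewTokens); // assumed two-argument constructor
    request.setMropeConfig(texec::MropeConfig(std::move(rotarySinCos), positionDeltas));
    return request;
}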

View File

@ -55,6 +55,11 @@ public:
static void serialize(PromptTuningConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(PromptTuningConfig const& config);
// MropeConfig
[[nodiscard]] static MropeConfig deserializeMropeConfig(std::istream& is);
static void serialize(MropeConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(MropeConfig const& config);
// LoraConfig
[[nodiscard]] static LoraConfig deserializeLoraConfig(std::istream& is);
static void serialize(LoraConfig const& config, std::ostream& os);
@ -206,6 +211,33 @@ public:
static void serialize(IterationStats const& iterStats, std::ostream& os);
static std::vector<char> serialize(IterationStats const& iterStats);
static size_t serializedSize(IterationStats const& iterStats);
static std::vector<char> serialize(std::vector<IterationStats> const& iterStatsVec);
static std::vector<IterationStats> deserializeIterationStatsVec(std::vector<char>& buffer);
// DisServingStats
[[nodiscard]] static DisServingRequestStats deserializeDisServingRequestStats(std::istream& is);
static void serialize(DisServingRequestStats const& state, std::ostream& os);
[[nodiscard]] static size_t serializedSize(DisServingRequestStats const& state);
// RequestStage
[[nodiscard]] static RequestStage deserializeRequestStage(std::istream& is);
static void serialize(RequestStage const& state, std::ostream& os);
[[nodiscard]] static size_t serializedSize(RequestStage const& state);
// RequestStats
[[nodiscard]] static RequestStats deserializeRequestStats(std::istream& is);
static void serialize(RequestStats const& state, std::ostream& os);
[[nodiscard]] static size_t serializedSize(RequestStats const& state);
// RequestStatsPerIteration
[[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::istream& is);
[[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::vector<char>& buffer);
static void serialize(RequestStatsPerIteration const& state, std::ostream& os);
[[nodiscard]] static std::vector<char> serialize(RequestStatsPerIteration const& state);
[[nodiscard]] static size_t serializedSize(RequestStatsPerIteration const& state);
[[nodiscard]] static std::vector<char> serialize(std::vector<RequestStatsPerIteration> const& requestStatsVec);
[[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
std::vector<char>& buffer);
// String
static std::string deserializeString(std::istream& is);
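
The new Serialization entry points make MropeConfig round-trippable like the other request options. A minimal sketch, assuming the header lives at tensorrt_llm/executor/serialization.h:

#include <sstream>

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h" // assumed header path

namespace texec = tensorrt_llm::executor;

// Serialize an MropeConfig to a byte stream and read it back.
texec::MropeConfig roundTripMropeConfig(texec::MropeConfig const& config)
{
    std::stringstream buffer;
    texec::Serialization::serialize(config, buffer);
    return texec::Serialization::deserializeMropeConfig(buffer);
}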

View File

@ -77,6 +77,8 @@ public:
TensorPtr eagleNetGenContextLengthsHost;
//! [maxBatchSize] or [numSequences]
TensorPtr eagleNetGenPastKeyValueLengthsHost;
//! [maxBatchSize * maxDecodingTokens] or [numSequences * maxDecodingTokens]
TensorPtr inputGenTokensHost;
void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);

View File

@ -67,7 +67,7 @@ public:
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream,
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
};
template <typename T>
@ -110,7 +110,7 @@ private:
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream,
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule)
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
{
switch (dtype)
{
@ -128,10 +128,9 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode c
/// @brief Helper function to produce batch slots [0, 1, ..., batchSize - 1] for paths that do not explicitly provide
/// batch slots to the decoder.
inline runtime::ITensor::SharedConstPtr getDefaultBatchSlots(
runtime::SizeType32 batchSize, runtime::BufferManager const& bufferManager)
inline runtime::ITensor::SharedConstPtr getDefaultBatchSlots(runtime::SizeType32 batchSize)
{
auto defaultBatchSlots = bufferManager.pinnedPool(
auto defaultBatchSlots = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({batchSize}), runtime::TRTDataType<runtime::SizeType32>::value);
auto range = runtime::BufferRange<runtime::SizeType32>(*defaultBatchSlots);
std::iota(range.begin(), range.end(), 0);

View File

@ -36,6 +36,7 @@ public:
// users tune that, we can support that by writing and reading the
// points in `config.json`.
static constexpr std::array kOPT_PROFILES_SPLIT_POINTS{64, 128, 256, 512, 1024};
static constexpr SizeType32 kDEFAULT_NUM_TOKENS_PER_BLOCK = 64;
enum class ModelVariant : std::int32_t
{
@ -80,18 +81,16 @@ public:
{
return KVCacheType::kCONTINUOUS;
}
else if (value == "PAGED")
if (value == "PAGED")
{
return KVCacheType::kPAGED;
}
else if (value == "DISABLED")
if (value == "DISABLED")
{
return KVCacheType::kDISABLED;
}
else
{
throw std::invalid_argument("Invalid KV cache type: " + value);
}
throw std::invalid_argument("Invalid KV cache type: " + value);
}
enum class ManageWeightsType : std::int32_t
@ -113,7 +112,7 @@ public:
, mUseGptAttentionPlugin(false)
, mUseMambaConv1dPlugin(false)
, mInputPacked{false}
, mTokensPerBlock{64}
, mTokensPerBlock{kDEFAULT_NUM_TOKENS_PER_BLOCK}
, mQuantMode{common::QuantMode::none()}
, mMaxBatchSize(0)
, mMaxBeamWidth(0)
@ -124,6 +123,9 @@ public:
, mComputeGenerationLogits(false)
, mModelVariant(ModelVariant::kGpt)
, mMaxPromptEmbeddingTableSize(0)
, mUseMrope{false}
, mMaxPositionEmbeddings(0)
, mRotaryEmbeddingDim(0)
, mContextFMHA(false)
, mPagedContextFMHA(false)
, mUseXQA{false}
@ -396,6 +398,36 @@ public:
return mMaxPromptEmbeddingTableSize > 0;
}
[[nodiscard]] bool constexpr useMrope() const noexcept
{
return mUseMrope;
}
void constexpr setUseMrope(bool useMrope) noexcept
{
mUseMrope = useMrope;
}
[[nodiscard]] SizeType32 constexpr getMaxPositionEmbeddings() const noexcept
{
return mMaxPositionEmbeddings;
}
void constexpr setMaxPositionEmbeddings(SizeType32 maxPositionEmbeddings) noexcept
{
mMaxPositionEmbeddings = maxPositionEmbeddings;
}
[[nodiscard]] SizeType32 constexpr getRotaryEmbeddingDim() const noexcept
{
return mRotaryEmbeddingDim;
}
void constexpr setRotaryEmbeddingDim(SizeType32 rotaryEmbeddingDim) noexcept
{
mRotaryEmbeddingDim = rotaryEmbeddingDim;
}
[[nodiscard]] SizeType32 constexpr getMaxPromptEmbeddingTableSize() const noexcept
{
return mMaxPromptEmbeddingTableSize;
@ -622,14 +654,12 @@ public:
{
return nvinfer1::DataType::kFP8;
}
else if (getQuantMode().hasInt8KvCache())
if (getQuantMode().hasInt8KvCache())
{
return nvinfer1::DataType::kINT8;
}
else
{
return getDataType();
}
return getDataType();
}
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
@ -733,6 +763,7 @@ public:
: mNumKvHeadsPerAttentionLayer.cbegin() + numPrevAttnLayers;
auto const numLocalAttentionLayers
= countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
TLLM_LOG_TRACE("%s stop: %d", __PRETTY_FUNCTION__);
return std::make_pair(firstLocalAttentionLayerIt, firstLocalAttentionLayerIt + numLocalAttentionLayers);
}
@ -797,6 +828,9 @@ private:
ModelVariant mModelVariant;
SizeType32 mMaxPromptEmbeddingTableSize;
bool mUseMrope;
SizeType32 mMaxPositionEmbeddings;
SizeType32 mRotaryEmbeddingDim;
bool mContextFMHA;
bool mPagedContextFMHA;
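
ModelConfig now records whether mrope is used together with the maximum position embeddings and the rotary embedding dimension. A minimal sketch of populating the new fields; the numeric values are illustrative placeholders, not defaults from the source:

#include "tensorrt_llm/runtime/modelConfig.h"

using tensorrt_llm::runtime::ModelConfig;

// Enable mrope on a model config; real values would normally come from the engine's config.json.
void enableMrope(ModelConfig& modelConfig)
{
    modelConfig.setUseMrope(true);
    modelConfig.setMaxPositionEmbeddings(32768); // illustrative value
    modelConfig.setRotaryEmbeddingDim(128);      // illustrative value
}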

View File

@ -26,4 +26,7 @@ bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::stri
bool tensorHasNan(
size_t M, size_t K, nvinfer1::DataType type, void const* data, cudaStream_t stream, std::string const& infoStr);
int stallStream(
char const* name, std::optional<cudaStream_t> stream = std::nullopt, std::optional<int> delay = std::nullopt);
} // namespace tensorrt_llm::runtime::utils

View File

@ -35,6 +35,7 @@ struct TreeNode
SizeType32 initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32>& topKs,
ITensor::SharedPtr generationInputLengths, ITensor::SharedPtr positionOffsets, ITensor::SharedPtr treeIds,
ITensor::SharedPtr paths, ITensor::SharedPtr packedMask);
ITensor::SharedPtr paths, ITensor::SharedPtr packedMask,
std::optional<SizeType32> maxNonLeafNodesPerLayer = std::nullopt);
} // namespace tensorrt_llm::runtime::utils

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:748a53a5f70813f0ddb5bb54a56cd07a4b9146917c12ec34504dc4384b00610b
size 5882210
oid sha256:93114cc9b3f67d302800ef751a71a87f549ad1fb436d7983cedea7edaf3cdc34
size 6001292

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2350b7f07b5f30179ebf24f6e103dc17d4a656c95c171eaca684529120ca245a
size 6001974
oid sha256:c65e18c28264cf19543f94e5e529e00e1cfda12e6d02c7a2141960f11e1020e8
size 6121836

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b28f05452036c1722a37ac625921cf4902cfb6c04fb01b9d958b9f40ff9be0b
oid sha256:be60169ba0f4d8a526427d942ddb7e657a075f82b9bde186f339d92e5baefedd
size 1958384

View File

@ -1,2 +1,2 @@
0066a5a67ec747f565158bbbc398cca9 libtensorrt_llm_ucx_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
f0f55c8f4b75991abd3ff2fce878cbea libtensorrt_llm_ucx_wrapper.so
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0132b1d4544101465ac37993ae20324c0c49ae978b0a3c8c95a03a08a17b5b36
size 5692876
oid sha256:7cb62faee8fdf912738a7e144f55fbbc8348dcee342224a6e07890b5ec3cac05
size 5793796

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15ff5d0aeae4d3e776fdf3bb68af0cc5896b14f435b66a11fecc2111668fd089
size 5659602
oid sha256:85ff3b9cefda7fc15bcd8a92fbcfb6a7200f13fbfa47e133caa103c9c6a77e4c
size 5763878

View File

@ -1,2 +1,2 @@
1598761c1df1fd35b2180b599ad34f58 libtensorrt_llm_ucx_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f975b781b240c8489a48243a94dfdf0be6bfe6b862cf6ec6cbeacd5c66fae7af
size 36139148
oid sha256:89b13b0625a17c038545a6a2c00e1e752d366c0afddcceb411c7903459e45911
size 36266224

View File

@ -1,2 +1,2 @@
f9557afc965818430dcae14ae7542adf tensorrt_llm_batch_manager_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit
32ed5e6a9704a91e56cff30c1ebe7211 tensorrt_llm_batch_manager_static.lib
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -38,6 +38,13 @@ static std::optional<int32_t> getIntEnv(char const* name)
return {val};
};
// Returns true if the env variable exists and is set to "1"
static bool getBoolEnv(char const* name)
{
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
@ -143,15 +150,8 @@ bool getEnvEnablePDL()
// PDL only available when arch >= 90
if (getSMVersion() >= 90)
{
char const* enable_pdl = std::getenv("TRTLLM_ENABLE_PDL");
if (enable_pdl)
{
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
if (enable_pdl[0] == '1' && enable_pdl[1] == '\0')
{
enablePDL = true;
}
}
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
}
}
return enablePDL;
@ -159,22 +159,7 @@ bool getEnvEnablePDL()
bool getEnvUseUCXKvCache()
{
static bool init = false;
static bool useUCXKVCache = false;
if (!init)
{
init = true;
{
char const* use_ucx_kv_cache = std::getenv("TRTLLM_USE_UCX_KVCACHE");
if (use_ucx_kv_cache)
{
if (use_ucx_kv_cache[0] == '1' && use_ucx_kv_cache[1] == '\0')
{
useUCXKVCache = true;
}
}
}
}
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
return useUCXKVCache;
}
@ -195,4 +180,11 @@ std::string getEnvUCXInterface()
}
return ucxInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
return disaggLayerwise;
}
} // namespace tensorrt_llm::common
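
The env helpers now share a single getBoolEnv and cache the result in a function-local static, so the environment is read only once per process. A minimal sketch of the same pattern, using a hypothetical flag name:

#include <cstdlib>

// Returns true if the env variable exists and is set to "1" (same check as getBoolEnv above).
static bool readBoolEnv(char const* name)
{
    char const* env = std::getenv(name);
    return env && env[0] == '1' && env[1] == '\0';
}

// Hypothetical flag, for illustration only: evaluated once on first call, then cached.
bool getEnvExampleFlag()
{
    static bool const value = readBoolEnv("TRTLLM_EXAMPLE_FLAG");
    return value;
}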

View File

@ -45,4 +45,6 @@ bool getEnvUseUCXKvCache();
std::string getEnvUCXInterface();
bool getEnvDisaggLayerwise();
} // namespace tensorrt_llm::common

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <algorithm>
#include <numeric>
#include <unordered_set>
@ -101,7 +100,7 @@ std::recursive_mutex mpiMutex;
MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
MPI_Comm localComm;
MPI_Comm localComm = nullptr;
MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
MpiComm localSession{localComm, false};
#else
@ -115,14 +114,16 @@ MpiComm initLocalSession()
std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
MPI_Group group, worldGroup;
MPI_Group group = nullptr;
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPICHECK(MPI_Comm_group(comm, &group));
int groupSize;
int groupSize = 0;
MPICHECK(MPI_Group_size(group, &groupSize));
std::vector<int> ranks(groupSize), worldRanks(groupSize);
std::vector<int> ranks(groupSize);
std::vector<int> worldRanks(groupSize);
std::iota(ranks.begin(), ranks.end(), 0);
MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
@ -152,7 +153,7 @@ void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
if (!initialized)
{
TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
int providedMode;
int providedMode = 0;
auto requiredMode = static_cast<int>(threadMode);
MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
@ -287,7 +288,7 @@ MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
MpiComm MpiComm::split(int color, int key) const
{
MPI_Comm splitComm;
MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
@ -431,11 +432,11 @@ void MpiComm::refreshLocalSession()
}
}
MPI_Group worldGroup;
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPI_Group localGroup;
MPI_Group localGroup = nullptr;
MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
MPI_Comm localComm;
MPI_Comm localComm = nullptr;
MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
MpiComm::mutableLocalSession().mFreeComm = true;
MpiComm::mutableLocalSession() = MpiComm{localComm, false};

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33f66dba2f3024d979e38cf1aae4d10802c5a1fb0f4c801108c35824339eae5d
size 2419566
oid sha256:ee2f55f3882f75eec0e91ea8392899356b67feecf415945b5e8fee80045a1c97
size 2493128

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d224780476ce5f398f30ffbfa0d61bbd0aae5cb1538c8d4c0a16cdf8945ba5d3
size 2449532
oid sha256:6ff7a82a0e772a5e9cfc9c02ca012bae35d0b98d00841df22df6745232c8bd96
size 2523762

View File

@ -1,3 +1,3 @@
ee532edbf35321d4ac0aadf8a3c6a3a5 libtensorrt_llm_executor_static.a
0bf468a19d4c353dcf421fc3e05a9d7d libtensorrt_llm_executor_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
e31839682c2e28f1ee65d0a4ea5bdbde libtensorrt_llm_executor_static.a
544b31144b0cb90d4256e83b0c10bdfe libtensorrt_llm_executor_static.pre_cxx11.a
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b21e2488bdb5c1e18e7aa129acb18087d031eea4f5b063910081ca09a3041a5
size 3494984
oid sha256:d07532b2d05dc3a69ad98cf8fdfab870a67a57863eb812721d74ff3fd4c740dc
size 3563598

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94964aa02020e38e869bf9ca18385ae379c8b9d1819ad02e10b23d8175cc9d82
size 3412104
oid sha256:5b7c5ab992090c7456adc23c38bbdd85ad0a5992bb2231763bda7fa2211ddadd
size 3485776

View File

@ -1,3 +1,3 @@
ba01eba908f38eb582c22c1f822cfedf libtensorrt_llm_executor_static.a
ffe68ec0af94d364ec8db50a24ae0e8c libtensorrt_llm_executor_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
a3216154712c98a922de02d68fea8456 libtensorrt_llm_executor_static.a
9d0a93f148854b0fbd76aa36e7d54d8e libtensorrt_llm_executor_static.pre_cxx11.a
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67f59341edab284c309d39f2a0ad39e91f8afe198c4cf6ba838ae7adb54ad01d
size 23192460
oid sha256:abfb1b75d8675abba0a8fc59b74d3f94e2ba3eaea4f28e0c4b2beea9cc182316
size 23865270

View File

@ -1,2 +1,2 @@
e3cd49147c73b0066dcb759df9556191 tensorrt_llm_executor_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit
8e6a97333e79684aca515112982f2624 tensorrt_llm_executor_static.lib
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -17,18 +17,18 @@
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"
#include "tensorrt_llm/common/assert.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include <cuda_runtime_api.h>
#include <set>

View File

@ -16,10 +16,10 @@
#pragma once
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
// clang-format off
#include "cutlass/cutlass.h"
@ -29,9 +29,9 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
// clang-format on
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
using namespace cute;

View File

@ -22,10 +22,10 @@
#pragma once
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@ -43,9 +43,9 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
using namespace cute;

View File

@ -22,10 +22,10 @@
#pragma once
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@ -33,9 +33,9 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/gemm_configs.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "fp8_rowwise_gemm.h"
#include "fp8_rowwise_gemm_kernel_template_sm89.h"

View File

@ -87,6 +87,9 @@ protected:
static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 16;
static constexpr int MIN_N_TILE = 64;
static constexpr int MAX_M_TILE_SM90 = 128;
static constexpr int MAX_N_TILE_SM90 = 256;
};
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp,

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass_extensions/compute_occupancy.h"
@ -29,9 +29,9 @@
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm_configs.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -568,6 +568,27 @@ CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, Bia
int const m, int const n, int const k)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// For Hopper, we have to allocate a large workspace in case stream-K is used
if (sm_ == 90)
{
// https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892
// The above lines say sk_tiles = output_tiles - (static_cast<uint32_t>(output_tiles / ctas_per_wave) - 1) *
// ctas_per_wave. This means sk_tiles is at most 2 * ctas_per_wave, which is 2 * multi_processor_count_.
int const max_sk_tiles = 2 * multi_processor_count_;
// https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L939
// The above line says uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units);
// That means sk_units is at most ctas_per_sk_wave, which is multi_processor_count_
int const max_sk_units = multi_processor_count_;
// https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L505
// The above line scales sk_tiles by a factor of static_cast<uint32_t>(sk_units / sk_tiles + 2).
// That means the final sk_tiles is at most 2 * max_sk_tiles + max_sk_units.
int const max_sk_tiles_with_seperate_reduction = 2 * max_sk_tiles + max_sk_units;
return static_cast<size_t>(
max_sk_tiles_with_seperate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float));
}
// These are the min tile sizes for each config, which would launch the maximum number of blocks
int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
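
For intuition, a worked example of the sm90 bound above, assuming a GPU with 132 SMs (the SM count is an assumption, not taken from the source):

#include <cstddef>

// Stream-K workspace bound on Hopper for an assumed 132-SM GPU:
// at most 5 * SMs scratch tiles of MAX_M_TILE_SM90 x MAX_N_TILE_SM90 floats.
constexpr int kSmCount = 132;                                                   // assumed
constexpr int kMaxSkTiles = 2 * kSmCount;                                       // 264
constexpr int kMaxSkUnits = kSmCount;                                           // 132
constexpr int kMaxSkTilesWithSeparateReduction = 2 * kMaxSkTiles + kMaxSkUnits; // 660
constexpr std::size_t kWorkspaceBytes
    = std::size_t(kMaxSkTilesWithSeparateReduction) * 128 * 256 * sizeof(float); // 86507520 bytes, ~82.5 MiB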

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
@ -34,9 +34,9 @@
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm_configs.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -161,8 +161,11 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using TileScheduler = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{}, cutlass::gemm::PersistentScheduler,
cutlass::gemm::StreamKScheduler>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue>;
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
if (occupancy != nullptr)
{

View File

@ -16,10 +16,10 @@
#pragma once
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@ -38,9 +38,9 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
using namespace cute;

View File

@ -16,10 +16,10 @@
#pragma once
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@ -27,9 +27,9 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/gemm_configs.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "fused_gated_gemm.h"
#include "fused_gated_gemm_kernel_template_sm90.h"

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#endif // __GNUC__
// clang-format off
#include <cutlass/gemm/device/default_gemm_configuration.h>
@ -36,9 +36,9 @@
#include "cutlass_extensions/gemm/kernel/default_int8_traits.h"
#include "cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h"
#ifndef _WIN32
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#endif // __GNUC__
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"

View File

@ -42,8 +42,6 @@
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#pragma GCC diagnostic pop
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"

View File

@ -15,8 +15,10 @@
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
@ -44,7 +46,9 @@
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Restore GCC-specific diagnostics
#pragma GCC diagnostic pop
#endif
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

View File

@ -15,8 +15,10 @@
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
@ -44,7 +46,9 @@
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

View File

@ -197,6 +197,7 @@ struct Multihead_attention_params_base
int* block_counter = nullptr;
int const* memory_length_per_sample = nullptr;
int32_t const* mrope_position_deltas = nullptr;
};
template <typename T, bool USE_CROSS_ATTENTION = false>

View File

@ -75,7 +75,8 @@ inline size_t smem_size_in_bytes(Multihead_attention_params<T, DO_CROSS_ATTENTIO
size_t transpose_rotary_size = 0;
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE)
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M)
{
assert(params.rotary_embedding_dim > 0);
transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk);
@ -416,7 +417,8 @@ void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_
assert((params.rotary_embedding_dim != 0)
== (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE));
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M));
if (params.beam_width == 1)
{
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false, BLOCK_SPARSE_ATTN,

View File

@ -1500,7 +1500,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
int const beam0_context_length
= HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
// The position of the current timestep, and it is used to apply the position embedding
int const current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
int current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
// The offset in the Q and K buffer also accounts for the batch.
auto const qk_vec_idx = tidx * QK_VEC_SIZE;
@ -1667,6 +1667,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
break;
}
case PositionEmbeddingType::kLONG_ROPE:
case PositionEmbeddingType::kROPE_M:
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
bool const do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
@ -1680,6 +1681,10 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
assert(half_rotary_dim % QK_VEC_SIZE == 0);
if (params.position_embedding_type == PositionEmbeddingType::kROPE_M)
{
current_pos_idx += params.mrope_position_deltas[batch_idx];
}
if (do_rotary)
{

View File

@ -221,8 +221,9 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
xqaParams.spec_decoding_generation_lengths, xqaParams.sequence_lengths, /* encoder_seqlens */ nullptr,
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, /* cu_kv_seqlens */ nullptr,
launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant,
xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length,
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
xqaParams.spec_decoding_position_offsets, xqaParams.mrope_rotary_sin_cos, xqaParams.mrope_position_deltas,
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
/*remove_padding*/ true, /*cross_attention*/ false, xqaParams.num_q_heads, xqaParams.num_kv_heads,
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,

View File

@ -73,7 +73,7 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
return false;
}
if (!contains({PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX,
PositionEmbeddingType::kLONG_ROPE},
PositionEmbeddingType::kROPE_M, PositionEmbeddingType::kLONG_ROPE},
xqaParams.position_embedding_type))
{
return false;

View File

@ -1,2 +1,2 @@
90df70c216d9aa2c85b8b097c853e4ba libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,2 +1,2 @@
232f492424a31204a2be2e67be299aef libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bed2713947315cf941533dd12b5b98270a2aabd584cc33bc2092be6dbf879959
oid sha256:ed682d2c566aa703cfd1276f1329d3db3b291aed7e325e521bbe6f3406d7cd84
size 1128448

View File

@ -1,3 +1,3 @@
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
aaa20992c207e46eab50dd90bcf3c405 tensorrt_llm_nvrtc_wrapper.dll
1c2eb102257f836cd50faf985e693241d7a84dbe commit
6ed8db839864ffdd8f22fff0d627ee9d tensorrt_llm_nvrtc_wrapper.dll
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -204,9 +204,10 @@ public:
xqaParams.sequence_lengths, /* encoder_seqlens */ nullptr,
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr,
/* cu_kv_seqlens */ nullptr, launchParams.rotary_inv_freq_buf, (float2 const*) nullptr,
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, int(batch_beam_size),
xqaParams.generation_input_length, xqaParams.timestep, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, xqaParams.mrope_rotary_sin_cos,
xqaParams.mrope_position_deltas, int(batch_beam_size), xqaParams.generation_input_length,
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
/*remove_padding*/ true, /*cross_attention*/ false, xqaParams.num_q_heads, xqaParams.num_kv_heads,
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,

View File

@ -54,6 +54,8 @@ struct XQAParams
int const* spec_decoding_generation_lengths; // variable input lengths.
bool spec_decoding_is_generation_length_variable; // whether the generation lengths actually vary
int32_t spec_decoding_max_generation_length; // max possible input length
float2 const* mrope_rotary_sin_cos = nullptr;
int32_t const* mrope_position_deltas = nullptr;
// almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.

View File

@ -59,6 +59,7 @@ enum class PositionEmbeddingType : int8_t
kRELATIVE = 6,
kCHATGLM = 7,
kYARN = 8,
kROPE_M = 9,
};
enum class RotaryScalingType : int8_t

View File

@ -1,3 +1,3 @@
f1820d73fc5cac7fa324d71933e5412a libtensorrt_llm_internal_cutlass_kernels_static.a
a8785db1cc11e3b571bc071d7abec1a8 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b6917794ec6e67989fdcd0af3cc4d84713f3d8d4dcd822d2df2272117c66d6b
oid sha256:14fa77c4e77a1a6c6955539f1721a773663785d628181e3bb30fedcccb676dfc
size 36626184

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6d2f3c25a8ce88917ba512eba804f14827703fab6f9ac8d63043e2d95b6b281
oid sha256:da7b96feeabce735db2a0d0524c4d057407bc7e7201b92383d4f7693db4aba7f
size 36080026

View File

@ -1,3 +1,3 @@
9e6ff6d826caeea1e6e19c71f5d0986b libtensorrt_llm_internal_cutlass_kernels_static.a
2fca4e76d5f21089f00b1d39f624b80a libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
5b0663234aecc8eb59d51f1e0a9206d6 libtensorrt_llm_internal_cutlass_kernels_static.a
5da2c92d3a859d10cf65bf4edab96c62 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2f34df6d47b7b2b6629358bb03b33eb193db067188e8b980598027b0ff85392
size 2669968
oid sha256:17d5d559c2a9748cf03b24d82270739430eee53c1d4dc41442ae05745724af84
size 2669962

View File

@ -1,2 +1,2 @@
95c2f50347d4de94e2e09cbf0cf99582 tensorrt_llm_internal_cutlass_kernels_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit
133f764ad845f62a4641bffca9c436b2 tensorrt_llm_internal_cutlass_kernels_static.lib
0397a251a647dd5d25f0de5279170ba35d82c50d commit

View File

@ -26,8 +26,10 @@
#include <sstream>
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@ -41,7 +43,9 @@
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/dataType.h"

View File

@ -235,8 +235,23 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;
// cum log prob is not supported with returnAllSelectedTokens
if (!returnAllSelectedTokens)
if (returnAllSelectedTokens)
{
// 'outputLogProbs' is the probability induced by the top-k sampling:
// NOT normalized (same way as OpenAI does):
// log_prob = log P(i | i is in vocab) = log(expLogit)
// normalized:
// log_prob = log P(i | i is in top-k) = log(expLogit / sum)
if (outputLogProbs != nullptr)
{
// outputLogProbs shape: [maxBatchSize, maxTopK]
auto logProb = logf(expLogit);
auto const normalizedProb = normalizeLogProbs ? logProb - logf(sSum) : logProb;
outputLogProbs[batchSlot * maxTopK + ki] = normalizedProb;
}
}
else
{
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
@ -247,17 +262,14 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
}
if (outputLogProbs != nullptr)
{
// 'outputLogProbs' is the probability induced by the top-k sampling:
// NOT normalized (same way as OpenAI does):
// log_prob = log P(i | i is in vocab) = log(expLogit)
// normalized:
// log_prob = log P(i | i is in top-k) = log(expLogit / sum)
outputLogProbs[curSeqLen * maxBatchSize + batchSlot]
= normalizeLogProbs ? logProb - logf(sSum) : logProb;
auto const normalizedProb = normalizeLogProbs ? logProb - logf(sSum) : logProb;
// outputLogProbs shape: [maxSeqLen, maxBatchSize]
outputLogProbs[curSeqLen * maxBatchSize + batchSlot] = normalizedProb;
}
}
break;
}
if (returnAllSelectedTokens && randNum <= 0.0f)
{
if (ki < k - 1)

View File

@ -74,7 +74,9 @@ struct TopKSamplingKernelParams
//! input/output buffer [maxBatchSize], optional.
//! Cumulative log probability of selected tokens. Ignored if nullptr
float* cumLogProbs{nullptr};
//! output buffer [maxBatchSize]. Log probs is the probability induced by the top-k sampling.
//! output buffer
//! [maxBatchSize, maxTopK] when returnAllSelectedTokens, otherwise [maxSeqLen, maxBatchSize]
//! Log probs is the probability induced by the top-k sampling.
//! If normalizeLogProbs is true, we normalize the probability 'expLogit' of the selected token
//! by the probability 's_sum' of a set of top-k tokens, meaning the logProb is the probability
//! of the selected token, conditioned on the event that it is selected,
@ -137,8 +139,13 @@ struct TopKSamplingKernelParams
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || endIds);
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllSelectedTokens);
TLLM_CHECK(maxTokensPerStep == 1);
if (cumLogProbs != nullptr)
{
TLLM_CHECK(!returnAllSelectedTokens);
}
}
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
TLLM_CHECK(0 < maxTopP && maxTopP <= 1.f);
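To make the two outputLogProbs layouts and the normalization above concrete, here is a minimal standalone sketch (not part of this change; all sizes and probe values are made up):
#include <cmath>
#include <cstdio>

int main()
{
    // Hypothetical sizes and values, for illustration only.
    int const maxBatchSize = 8, maxTopK = 4;
    int const batchSlot = 2, ki = 1, curSeqLen = 5;

    // returnAllSelectedTokens == true: outputLogProbs has shape [maxBatchSize, maxTopK].
    int const idxAllSelected = batchSlot * maxTopK + ki;
    // returnAllSelectedTokens == false: outputLogProbs has shape [maxSeqLen, maxBatchSize].
    int const idxDefault = curSeqLen * maxBatchSize + batchSlot;

    // Normalization as described above: raw log prob of the selected token vs. the
    // probability conditioned on the token being inside the top-k set (divide by sSum).
    float const expLogit = 0.20f, sSum = 0.50f;
    float const logProb = logf(expLogit);          // log P(i | i is in vocab)
    float const normalized = logProb - logf(sSum); // log P(i | i is in top-k)

    printf("all-selected index %d, default index %d, raw %.4f, normalized %.4f\n",
        idxAllSelected, idxDefault, logProb, normalized);
    return 0;
}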

View File

@ -98,11 +98,11 @@ namespace
template <int BLOCK_SIZE>
__global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengths, SizeType32* eagleNetContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* hiddenStatesIndices, SizeType32* lastTokenIndices,
SizeType32* numOutputTokens, SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts,
TokenIdType const* inputIds, SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts, TokenIdType const* inputIds,
SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
TokenIdType const* acceptedTokens, SizeType32 const* acceptedLens, SizeType32 const* prevDraftLens,
SizeType32 const* prevPaths, SizeType32 const* bestPathIds, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens)
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage tempStorage;
@ -135,6 +135,11 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
}
}
for (SizeType32 ii = bid; ii < maxNonLeavesPerLayer * batchSize; ii += BLOCK_SIZE)
{
lastTokenIndices[ii] = 1;
}
SizeType32 outputStartPos{0};
SizeType32 inputIndexBase{0};
SizeType32 lastTokenIndex{0};
@ -215,9 +220,9 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
// The last thread writes number of flattened tokens.
if (bid == BLOCK_SIZE - 1)
{
numOutputTokens[0] = outputStartPos + numDecodingTokens;
// After the first EagleNet we predict exactly one set of logits per request.
numLastTokenIndices[0] = batchSize;
// Set last hiddenSizeBatchLevelStarts.
hiddenSizeBatchLevelStarts[batchSize] = batchSize;
}
@ -226,19 +231,20 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
void invokePrepareCtxEagleNetInputs(SizeType32* eagleNetSequenceLengths, SizeType32* eagleNetContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* hiddenStatesIndices, SizeType32* lastTokenIndices,
SizeType32* numOutputTokens, SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts,
TokenIdType const* inputIds, SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts, TokenIdType const* inputIds,
SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
TokenIdType const* acceptedTokens, SizeType32 const* acceptedLens, SizeType32 const* prevDraftLens,
SizeType32 const* prevPaths, SizeType32 const* bestPathIds, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens, cudaStream_t stream)
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
TLLM_CHECK_WITH_INFO(
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", BLOCK_SIZE);
prepareCtxEagleNetInputsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(eagleNetSequenceLengths,
eagleNetContextLengths, outputIds, positionIds, hiddenStatesIndices, lastTokenIndices, numOutputTokens,
numLastTokenIndices, hiddenSizeBatchLevelStarts, inputIds, baseNetSequenceLengths, baseNetContextLengths,
acceptedTokens, acceptedLens, prevDraftLens, prevPaths, bestPathIds, batchSize, maxPathLen, maxDecodingTokens);
eagleNetContextLengths, outputIds, positionIds, hiddenStatesIndices, lastTokenIndices, numLastTokenIndices,
hiddenSizeBatchLevelStarts, inputIds, baseNetSequenceLengths, baseNetContextLengths, acceptedTokens,
acceptedLens, prevDraftLens, prevPaths, bestPathIds, batchSize, maxPathLen, maxDecodingTokens,
maxNonLeavesPerLayer);
}
namespace
@ -431,13 +437,13 @@ template <int BLOCK_SIZE>
__global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths, SizeType32* nextContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* specDecodingGenLengths,
SizeType32* specDecodingPositionOffsets, SizeType32* specDecodingPackedMasks, SizeType32* hiddenStatesIndices,
SizeType32* lastTokenIndices, SizeType32* numOutputTokens, SizeType32* numLastTokenIndices,
SizeType32* outputHiddenSizeBatchStartsPerLevel, SizeType32* cumSumGenerationLengths,
SizeType32* maxGenerationLength, TokenIdType const* nextDraftIds, SizeType32 const* selectedDraftIndices,
SizeType32 const* selectedDraftPosIds, SizeType32 const* numSelectedDraftIndices,
SizeType32 const* eagleNet0SequenceLengths, SizeType32 const* prevContextLengths,
SizeType32 const* inputHiddenSizeBatchStartsPerLevel, SizeType32 const* parentNonLeafInLevelOffset,
SizeType32 levelIdx, SizeType32 batchSize, SizeType32 maxPathLen, SizeType32 maxDecodingTokens)
SizeType32* lastTokenIndices, SizeType32* numLastTokenIndices, SizeType32* outputHiddenSizeBatchStartsPerLevel,
SizeType32* cumSumGenerationLengths, SizeType32* maxGenerationLength, TokenIdType const* nextDraftIds,
SizeType32 const* selectedDraftIndices, SizeType32 const* selectedDraftPosIds,
SizeType32 const* numSelectedDraftIndices, SizeType32 const* eagleNet0SequenceLengths,
SizeType32 const* prevContextLengths, SizeType32 const* inputHiddenSizeBatchStartsPerLevel,
SizeType32 const* parentNonLeafInLevelOffset, SizeType32 levelIdx, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
typedef cub::BlockReduce<SizeType32, BLOCK_SIZE> BlockReduce;
@ -551,7 +557,7 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
auto const lastStart = inputHiddenSizeBatchStartsPerLevel[(levelIdx - 1) * batchSize + batchSize];
// Set new layer idx.
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + bid] = lastStart + outputLastIndicesBase;
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + bid] = lastStart + bid * maxNonLeavesPerLayer;
}
__syncthreads();
@ -559,12 +565,11 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// The last valid thread fills the number of tokens.
if (bid == batchSize - 1)
{
numOutputTokens[0] = outputIndexBase + nextDraftLen;
// Set the total number of logits needed after the next iteration.
numLastTokenIndices[0] = lastIndices;
// Set last outputHiddenSizeBatchStartsPerLevel.
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize]
= outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize - 1] + numNextLogits;
= outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize - 1] + maxNonLeavesPerLayer;
}
}
@ -774,12 +779,12 @@ void invokePrepareGenEagleNetInputs(PrepareGenEagleNetInputsParams const& params
prepareGenEagleNetInputsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, params.stream>>>(params.nextSequenceLengths,
params.nextContextLengths, params.outputIds, params.positionIds, params.specDecodingGenLengths,
params.specDecodingPositionOffsets, params.specDecodingPackedMasks, params.hiddenStatesIndices,
params.lastTokenIndices, params.numOutputTokens, params.numLastTokenIndices,
params.outputHiddenSizeBatchStartsPerLevel, params.cumSumGenerationLengths, params.maxGenerationLength,
params.nextDraftIds, params.selectedDraftIndices, params.selectedDraftPosOffsets,
params.numSelectedDraftIndices, params.eagleNet0SequenceLengths, params.prevContextLengths,
params.inputHiddenSizeBatchStartsPerLevel, params.parentNonLeafInLevelOffset, params.levelIdx,
params.batchSize, params.maxPathLen, params.maxDecodingTokens);
params.lastTokenIndices, params.numLastTokenIndices, params.outputHiddenSizeBatchStartsPerLevel,
params.cumSumGenerationLengths, params.maxGenerationLength, params.nextDraftIds,
params.selectedDraftIndices, params.selectedDraftPosOffsets, params.numSelectedDraftIndices,
params.eagleNet0SequenceLengths, params.prevContextLengths, params.inputHiddenSizeBatchStartsPerLevel,
params.parentNonLeafInLevelOffset, params.levelIdx, params.batchSize, params.maxPathLen,
params.maxDecodingTokens, params.maxNonLeavesPerLayer);
sync_check_cuda_error();
}
@ -803,11 +808,13 @@ namespace
template <typename T>
__global__ void assembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, TokenIdType** outputIdsPtrs,
TokenIdType* outputIds, SizeType32 numInputLogits, SizeType32 maxDecodingDraftTokens, SizeType32 vocabSizePadded)
TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits, SizeType32 batchSize,
SizeType32 maxDecodingDraftTokens, SizeType32 vocabSizePadded)
{
auto const tix = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
auto const isValid{tix < numValidLogits[0]};
if (tix < numInputLogits)
if (isValid)
{
// logits: [numInputLogits, vocab_size]
// logitsPtrs: [numInputLogits][1, vocab_size]
@ -817,29 +824,34 @@ __global__ void assembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits
// outputIdsPtrs: [numInputLogits][maxDecodingDraftTokens]
outputIdsPtrs[tix] = outputIds + tix * maxDecodingDraftTokens;
}
skipDecode[tix] = !isValid;
}
} // namespace
template <typename T>
void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, runtime::TokenIdType** outputIdsPtrs,
runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens,
runtime::TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits,
runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
runtime::SizeType32 vocabSizePadded, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
assembleDraftLogitsOffsets<T><<<divUp(numInputLogits, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(
logitsPtrs, logits, outputIdsPtrs, outputIds, numInputLogits, maxDecodingDraftTokens, vocabSizePadded);
assembleDraftLogitsOffsets<T><<<divUp(numInputLogits, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(logitsPtrs, logits,
outputIdsPtrs, outputIds, skipDecode, numValidLogits, batchSize, maxDecodingDraftTokens, vocabSizePadded);
sync_check_cuda_error();
}
template void invokeAssembleDraftLogitsOffsets(float const** logitsPtrs, float const* logits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, bool* skipDecode,
runtime::SizeType32 const* numValidLogits, runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
template void invokeAssembleDraftLogitsOffsets(__half const** logitsPtrs, __half const* logits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, bool* skipDecode,
runtime::SizeType32 const* numValidLogits, runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -970,8 +982,8 @@ __global__ void extracTopKsFromSuccessorsArray(SizeType32* topKs, SizeType32* to
// Extract topKs from paths and layerId
void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeType32* topKs,
runtime::SizeType32* topKOffset, runtime::SizeType32* numSuccessorsForEachNode, runtime::SizeType32 layerId,
runtime::SizeType32 batchSize, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingTokens,
runtime::SizeType32 maxPathLen, cudaStream_t stream)
runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(
@ -999,8 +1011,8 @@ namespace
{
__global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 const* topKs,
SizeType32 const* topKOffset, TokenIdType const* pluginInputDraftIdsPtrs, SizeType32 const* pluginInputDraftLens,
TokenIdType* pluginOutputDraftIdsPtrs, SizeType32* pluginOutputDraftLens, SizeType32 layerId, SizeType32 batchSize,
SizeType32 numInputLogits, SizeType32 maxDecodingDraftTokens)
SizeType32 const* numValidLogits, TokenIdType* pluginOutputDraftIdsPtrs, SizeType32* pluginOutputDraftLens,
SizeType32 layerId, SizeType32 batchSize, SizeType32 maxDecodingDraftTokens)
{
// tmpOutputIdsPtrs: shape [numInputLogits][maxDecodingDraftTokens]
// topKs: shape [numInputLogits]
@ -1030,7 +1042,7 @@ __global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 c
// Compute the topK offset
SizeType32 startTopKOffset = topKOffset[tix];
SizeType32 endTopkOffset = tix + 1 < batchSize ? topKOffset[tix + 1] : numInputLogits;
SizeType32 endTopkOffset = tix + 1 < batchSize ? topKOffset[tix + 1] : numValidLogits[0];
for (SizeType32 ii = startTopKOffset; ii < endTopkOffset; ii++)
{
@ -1051,15 +1063,16 @@ __global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 c
// Copy output draft token ids from temporary buffer to plugin output buffer, also update the draft token length
void invokeCopyOutputTokensIds(runtime::TokenIdType** tmpOutputIdsPtrs, runtime::SizeType32 const* topKs,
runtime::SizeType32 const* topKOffset, runtime::TokenIdType const* pluginInputDraftIdsPtrs,
runtime::SizeType32 const* pluginInputDraftLens, runtime::TokenIdType* pluginOutputDraftIdsPtrs,
runtime::SizeType32* pluginOutputDraftLens, runtime::SizeType32 layerId, runtime::SizeType32 batchSize,
runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens, cudaStream_t stream)
runtime::SizeType32 const* pluginInputDraftLens, runtime::SizeType32 const* numValidLogits,
runtime::TokenIdType* pluginOutputDraftIdsPtrs, runtime::SizeType32* pluginOutputDraftLens,
runtime::SizeType32 layerId, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
copyOutputTokensIds<<<divUp(batchSize, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(tmpOutputIdsPtrs, topKs, topKOffset,
pluginInputDraftIdsPtrs, pluginInputDraftLens, pluginOutputDraftIdsPtrs, pluginOutputDraftLens, layerId,
batchSize, numInputLogits, maxDecodingDraftTokens);
pluginInputDraftIdsPtrs, pluginInputDraftLens, numValidLogits, pluginOutputDraftIdsPtrs, pluginOutputDraftLens,
layerId, batchSize, maxDecodingDraftTokens);
}
namespace

View File

@ -41,21 +41,25 @@ void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
//! FIXME: We may get rid of this kernel in future optimization
//! \brief Set the logitsPtrs[numInputLogits][1, vocabSizePadded] from logits [numInputLogits * vocabSizePadded]
//! and outputIdsPtrs[numInputLogits][maxDecodingDraftTokens] from outputIds[numInputLogits * maxDecodingDraftTokens]
//! \brief Set the logitsPtrs[numValidLogits][1, vocabSizePadded] from logits [numInputLogits * vocabSizePadded]
//! and outputIdsPtrs[numValidLogits][maxDecodingDraftTokens] from outputIds[numInputLogits * maxDecodingDraftTokens]
//! Can be merged into other kernels
//! \param logitsPtrs [numInputLogits][1, vocabSizePadded], on GPU. The logits pointer array that will be used in topK
//! \param logitsPtrs [numValidLogits][1, vocabSizePadded], on GPU. The logits pointer array that will be used in topK
//! sampling.
//! \param logits [numInputLogits * vocabSizePadded], on GPU. Flatten logits, generated by the EagleNet.
//! \param outputIdsPtrs [numInputLogits][maxDecodingDraftTokens], on GPU. The output buffer of the topK sampling.
//! \param outputIdsPtrs [numValidLogits][maxDecodingDraftTokens], on GPU. The output buffer of the topK sampling.
//! \param outputIds [numInputLogits * maxDecodingDraftTokens], on GPU. The flatten output buffer.
//! \param skipDecode [batchSize * maxNonLeavesPerLayer], on GPU. Flag whether to skip decoding or not.
//! The first sum(numValidLogitsPerRequest[:]) (i.e. numValidLogits[0]) entries are set to false (decoded);
//! the remaining entries are set to true (skipped).
//! \param numValidLogits [1], on GPU. Number of valid logits.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param maxDecodingDraftTokens maximum number of decoding draft tokens per step per request
//! \param vocabSizePadded vocab size of the logits
//! \param stream cuda stream
template <typename T>
void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, runtime::TokenIdType** outputIdsPtrs,
runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens,
runtime::TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits,
runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
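As a standalone illustration of the skipDecode contract documented above (not part of this change; the counts below are made up), the padded logit slots beyond numValidLogits[0] end up flagged so that the subsequent top-K sampling skips them:
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical: 6 logit slots are laid out for the batch, but only 4 are valid this step.
    int const numInputLogits = 6;
    int const numValidLogits = 4;

    // Host-side equivalent of the per-slot assignment in assembleDraftLogitsOffsets:
    // valid slots keep their logits/outputIds pointers and are decoded,
    // padded slots are flagged so that top-K sampling skips them.
    std::vector<bool> skipDecode(numInputLogits);
    for (int tix = 0; tix < numInputLogits; ++tix)
    {
        skipDecode[tix] = !(tix < numValidLogits);
    }

    for (int tix = 0; tix < numInputLogits; ++tix)
    {
        printf("logit slot %d -> %s\n", tix, skipDecode[tix] ? "skip" : "decode");
    }
    return 0;
}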
//! \brief Prepares data for ctx stage EagleNet (EagleNet0).
@ -74,12 +78,11 @@ void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, run
//! E.g. With 3 requests where the first two are context requests with lengths 5 and 3 respectively and the 3rd
//! is gen request with draftDecodingTokens=8 and acceptedLength=3 and the best path is [0, 2, 5].
//! hiddenStatesIndices equals to [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 13].
//! \param lastTokenIndices output buffer [numLastTokenIndices],
//! \param numOutputTokens output buffer [1], single number equals to the total
//! number of select tokens summed over all batches.
//! \param numLastTokenIndices output buffer [1], single number equals to the total
//! number of logits predicted by the next EagleNet iteration tokens summed over all batches.
//! For EagleNet0 it is equal to the batchSize.
//! \param lastTokenIndices output buffer [batchSize * maxNonLeavesPerLayer],
//! Indices (starting with 1) of the logits of interest after the EagleNet prediction.
//! Used for index_select of the hidden_states at the end of the EagleNet. Padded to maxNonLeavesPerLayer with 1s.
//! \param numLastTokenIndices output buffer [1], number of logits predicted by the next EagleNet
//! iteration. For EagleNet0 it equals batchSize.
//! \param hiddenSizeBatchLevelStarts output buffer [batchSize * maxDraftPathLen + 1]
//! Exclusive sum of the hidden states produced per batch per level.
//! For EagleNet0 it is just cum sum of 1s for batchSize.
@ -96,17 +99,18 @@ void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, run
//! \param batchSize batch size
//! \param maxPathLen Max number of accepted tokens per step
//! \param maxDecodingTokens Max number of draft tokens + 1
//! \param maxNonLeavesPerLayer Maximum number of non-leaf nodes per layer
//! \param stream cuda stream.
void invokePrepareCtxEagleNetInputs(runtime::SizeType32* eagleNetSequenceLengths,
runtime::SizeType32* eagleNetContextLengths, runtime::TokenIdType* outputIds, runtime::SizeType32* positionIds,
runtime::SizeType32* hiddenStatesIndices, runtime::SizeType32* lastTokenIndices,
runtime::SizeType32* numOutputTokens, runtime::SizeType32* numLastTokenIndices,
runtime::SizeType32* hiddenSizeBatchLevelStarts, runtime::TokenIdType const* inputIds,
runtime::SizeType32 const* baseNetSequenceLengths, runtime::SizeType32 const* baseNetContextLengths,
runtime::TokenIdType const* acceptedTokens, runtime::SizeType32 const* acceptedLens,
runtime::SizeType32 const* prevDraftLens, runtime::SizeType32 const* prevPaths,
runtime::SizeType32 const* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxPathLen,
runtime::SizeType32 maxDecodingTokens, cudaStream_t stream);
runtime::SizeType32* numLastTokenIndices, runtime::SizeType32* hiddenSizeBatchLevelStarts,
runtime::TokenIdType const* inputIds, runtime::SizeType32 const* baseNetSequenceLengths,
runtime::SizeType32 const* baseNetContextLengths, runtime::TokenIdType const* acceptedTokens,
runtime::SizeType32 const* acceptedLens, runtime::SizeType32 const* prevDraftLens,
runtime::SizeType32 const* prevPaths, runtime::SizeType32 const* bestPathIds, runtime::SizeType32 batchSize,
runtime::SizeType32 maxPathLen, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxNonLeavesPerLayer,
cudaStream_t stream);
struct PrepareGenEagleNetInputsParams
{
@ -136,14 +140,11 @@ struct PrepareGenEagleNetInputsParams
//! output buffer [numOutputTokens]
//! Indices of the hidden states for selected tokens for the next EagleNet iteration.
runtime::SizeType32* hiddenStatesIndices{nullptr};
//! output buffer [numLastTokenIndices]
//! output buffer [batchSize * maxNonLeavesPerLayer]
//! Indices of the hidden states where to sample logits from after the next EagleNet iteration.
runtime::SizeType32* lastTokenIndices{nullptr};
//! output buffer [1]
//! Single number equals to the total number of select tokens summed over all batches.
runtime::SizeType32* numOutputTokens{nullptr};
//! output buffer [1]
//! Single number equals to the total number of logits to be predicted by the next EagleNet summed over all batches.
//! Number of logits predicted by the next EagleNet iteration.
runtime::SizeType32* numLastTokenIndices{nullptr};
//! input buffer [(maxPathLen - 1) * batchSize + 1]
//! Exclusive sum of the hidden states produced per batch per level.
@ -205,6 +206,8 @@ struct PrepareGenEagleNetInputsParams
runtime::SizeType32 maxPathLen{0};
//! Max number of draft tokens + 1
runtime::SizeType32 maxDecodingTokens{0};
//! Maximum number of non-leaf nodes per layer
runtime::SizeType32 maxNonLeavesPerLayer{0};
cudaStream_t stream;
void checkParams()
@ -218,7 +221,6 @@ struct PrepareGenEagleNetInputsParams
TLLM_CHECK(specDecodingPackedMasks);
TLLM_CHECK(hiddenStatesIndices);
TLLM_CHECK(lastTokenIndices);
TLLM_CHECK(numOutputTokens);
TLLM_CHECK(numLastTokenIndices);
TLLM_CHECK(outputHiddenSizeBatchStartsPerLevel);
@ -242,6 +244,7 @@ struct PrepareGenEagleNetInputsParams
TLLM_CHECK(maxPathLen > 0);
TLLM_CHECK(maxDecodingTokens > 0);
TLLM_CHECK(0 < levelIdx && levelIdx < maxPathLen - 1);
TLLM_CHECK(maxNonLeavesPerLayer > 0);
}
};
@ -511,8 +514,8 @@ void invokeGetPackedMaskFromPath(int32_t* specDecodingPackedMasks, runtime::Size
//! \param stream cuda stream.
void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeType32* topKs,
runtime::SizeType32* topKOffset, runtime::SizeType32* numSuccessorsForEachNode, runtime::SizeType32 layerId,
runtime::SizeType32 batchSize, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingTokens,
runtime::SizeType32 maxPathLen, cudaStream_t stream);
runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
cudaStream_t stream);
//! \brief Copy the output draft token from input buffer (generated from previous EagleNets)
//! and new draft tokens generated by this layers to the output buffer of this plugin
@ -526,6 +529,7 @@ void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeT
//! which contains draft tokens generated by previous EagleNets.
//! \param pluginInputDraftLens [batchSize], on GPU. The
//! plugin's input buffer, which contains the draft length from previous EagleNets.
//! \param numValidLogits [1], on GPU. The number of valid logits.
//! \param pluginOutputDraftIdsPtrs [batchSize * maxDecodingDraftTokens], on GPU. The plugin's output buffer,
//! which will contains all the draft tokens generated by this and previous EagleNets.
//! \param pluginOutputDraftLens [batchSize], on GPU. The plugin's input buffer,
@ -533,13 +537,13 @@ void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeT
//! \param layerId SizeType32. The layerId of the EagleNet. Will
//! be used to traverse a specific level of the tree.
//! \param batchSize SizeType32. Batch size.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param maxDecodingDraftTokens maximum number of decoding draft tokens per step per request.
//! \param stream cuda stream.
void invokeCopyOutputTokensIds(runtime::TokenIdType** tmpOutputIdsPtrs, runtime::SizeType32 const* topKs,
runtime::SizeType32 const* topKOffset, runtime::TokenIdType const* pluginInputDraftIdsPtrs,
runtime::SizeType32 const* pluginInputDraftLens, runtime::TokenIdType* pluginOutputDraftIdsPtrs,
runtime::SizeType32* pluginOutputDraftLens, runtime::SizeType32 layerId, runtime::SizeType32 batchSize,
runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens, cudaStream_t stream);
runtime::SizeType32 const* pluginInputDraftLens, runtime::SizeType32 const* numValidLogits,
runtime::TokenIdType* pluginOutputDraftIdsPtrs, runtime::SizeType32* pluginOutputDraftLens,
runtime::SizeType32 layerId, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
cudaStream_t stream);
} // namespace tensorrt_llm::kernels::speculative_decoding
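The padded layouts described above are easier to picture with numbers. A small standalone sketch follows (not part of this change; batchSize, maxNonLeavesPerLayer, the selected indices, and the per-request slot layout are all assumptions for illustration):
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical sizes for illustration.
    int const batchSize = 3;
    int const maxNonLeavesPerLayer = 2;

    // lastTokenIndices: 1-based indices of the logits of interest, padded with 1s.
    // Assume each request selects one logit, at (1-based) positions 5, 8, 11, and that
    // each request owns a contiguous block of maxNonLeavesPerLayer slots.
    std::vector<int> lastTokenIndices(batchSize * maxNonLeavesPerLayer, 1);
    int const selected[batchSize] = {5, 8, 11};
    for (int bi = 0; bi < batchSize; ++bi)
    {
        lastTokenIndices[bi * maxNonLeavesPerLayer] = selected[bi];
    }

    // hiddenSizeBatchLevelStarts for EagleNet0: each request produces exactly one hidden
    // state, so the exclusive sum is simply 0, 1, 2, ..., batchSize.
    std::vector<int> hiddenSizeBatchLevelStarts(batchSize + 1);
    for (int bi = 0; bi <= batchSize; ++bi)
    {
        hiddenSizeBatchLevelStarts[bi] = bi;
    }

    for (int v : lastTokenIndices) printf("%d ", v);
    printf("| ");
    for (int v : hiddenSizeBatchLevelStarts) printf("%d ", v);
    printf("\n");
    return 0;
}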

View File

@ -1411,6 +1411,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
break;
}
case PositionEmbeddingType::kLONG_ROPE:
case PositionEmbeddingType::kROPE_M:
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
bool const do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
@ -1531,6 +1532,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T cons
dim3 grid(token_num, std::max(head_num, 1));
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| position_embedding_type == PositionEmbeddingType::kROPE_M
? 2 * rotary_embedding_dim * sizeof(T)
: 0);
// NOTE: add offset for rotary embedding
@ -1923,6 +1925,7 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa
break;
}
case PositionEmbeddingType::kLONG_ROPE:
case PositionEmbeddingType::kROPE_M:
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
bool const do_rotary = vec_size * tidx < rotary_embedding_dim;
@ -1983,6 +1986,7 @@ void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const&
dim3 grid(token_num_in_k, kv_head_num, batch_beam);
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| position_embedding_type == PositionEmbeddingType::kROPE_M
? 2 * rotary_embedding_dim * sizeof(T)
: 0);

View File

@ -106,6 +106,9 @@ struct QKVPreprocessingParams
float const* kvScaleOrigQuant{nullptr};
int const* spec_decoding_position_offsets{nullptr};
float2 const* mrope_rotary_sin_cos{nullptr};
int32_t const* mrope_position_deltas{nullptr};
// Scalars.
int batch_size{0};
int max_input_seq_len{0};

View File

@ -383,10 +383,12 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
// NOTE: only spec decoding needs the position offsets.
// In the generation phase, we assume all sequences should have the same input length.
int const rotary_position = params.spec_decoding_position_offsets != nullptr
? (params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
+ cache_seq_len - actual_seq_len)
: token_idx_in_seq;
int const rotary_position
= (params.spec_decoding_position_offsets != nullptr ? (
params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
+ cache_seq_len - actual_seq_len)
: token_idx_in_seq)
+ (params.mrope_position_deltas != nullptr ? params.mrope_position_deltas[batch_idx] : 0);
if (!valid_token)
{
@ -753,8 +755,18 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
}
// Cos/sin cache.
[[maybe_unused]] float2 const* rotary_coef_cache_buffer
= params.rotary_coef_cache_buffer + static_cast<size_t>(rotary_position) * params.half_rotary_dim;
[[maybe_unused]] float2 const* rotary_coef_cache_buffer = nullptr;
if (params.mrope_rotary_sin_cos != nullptr)
{
rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions
+ static_cast<size_t>(rotary_position) * params.half_rotary_dim;
}
else
{
rotary_coef_cache_buffer
= params.rotary_coef_cache_buffer + static_cast<size_t>(rotary_position) * params.half_rotary_dim;
}
if constexpr (ROTARY_TYPE == RotaryPositionEmbeddingType::GPT_NEOX)
{
rotary_coef_cache_buffer += gptneox_rotary_dim_idx;
@ -892,7 +904,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
grid.z = std::min(int(divUp(params.multi_processor_count * WARPS_PER_SM, grid.x * grid.y)), \
int(divUp(params.batch_size, MIN_SEQUENCES_PER_WARP))); \
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE \
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M) \
{ \
applyBiasRopeUpdateKVCache<T, TCache, Dh_MAX, ADD_BIAS, STORE_QKV, KVCacheBuffer, \
RotaryPositionEmbeddingType::GPT_NEOX, DYNAMIC_ROTARY_SCALING, FP8_OUTPUT> \
@ -946,7 +959,8 @@ void kernelDispatchHeadSize(QKVPreprocessingParams<T, KVCacheBuffer> params, cud
constexpr int VEC_SIZE = Rotary_vec_t<T, Dh_MAX>::size;
// Make sure we have multiple of paired vectors so that the access is aligned.
TLLM_CHECK_WITH_INFO((params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE
&& params.position_embedding_type != PositionEmbeddingType::kROPE_M)
|| params.half_rotary_dim % VEC_SIZE == 0,
"Rotary dim size is not supported.");
@ -1002,7 +1016,8 @@ void kernelV1Dispatch(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStrea
dim3 block(BLOCK_SIZE); \
dim3 grid(int(divUp(params.max_input_seq_len, tokens_per_cuda_block)), params.batch_size, params.head_num); \
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE \
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M) \
{ \
applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, KVCacheBuffer, \
RotaryPositionEmbeddingType::GPT_NEOX><<<grid, block, 0, stream>>>(params); \
@ -1279,7 +1294,8 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(QKVPreprocessingParams<T, KVCacheB
bool const has_sink_tokens = params.sink_token_len > 0;
// V2 implementation requires multiple of paired 16 bytes for gpt-neox rotation.
bool const support_rotary_for_v2 = (params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE
&& params.position_embedding_type != PositionEmbeddingType::kROPE_M)
|| params.rotary_embedding_dim % 16 == 0;
// Use v2 kernel for absolute_position_embedding.
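As a standalone illustration of the M-RoPE handling added above (not part of this change; all values are hypothetical), the per-request delta shifts the rotary position, and the per-batch sin/cos cache is indexed the same way the kernel's pointer arithmetic does:
#include <cstdio>

int main()
{
    // Hypothetical per-request values.
    int const batch_idx = 1;
    int const token_idx_in_seq = 7;        // default rotary position for this token
    int const mrope_position_delta = -3;   // from mrope_position_deltas[batch_idx]
    int const half_rotary_dim = 64;
    int const rotary_embedding_max_positions = 4096;

    // As in applyBiasRopeUpdateKVCache: the M-RoPE delta is added on top of the
    // default (or spec-decoding) position.
    int const rotary_position = token_idx_in_seq + mrope_position_delta;

    // When mrope_rotary_sin_cos is provided, the cos/sin cache is per batch entry,
    // mirroring the kernel: base + batch_idx * rotary_embedding_max_positions
    //                            + rotary_position * half_rotary_dim.
    long const cache_offset = (long) batch_idx * rotary_embedding_max_positions
        + (long) rotary_position * half_rotary_dim;

    printf("rotary_position %d, sin/cos cache offset %ld (in float2 elements)\n",
        rotary_position, cache_offset);
    return 0;
}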

View File

@ -37,14 +37,20 @@ namespace tensorrt_llm::layers
template <typename T>
ExternalDraftTokensLayer<T>::ExternalDraftTokensLayer(executor::DecodingMode const& mode,
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager, bool isDeterministic,
bool isAirTopP)
: BaseLayer(decoderDomain, bufferManager)
, mDecodingMode(mode)
, mIsDeterministic(isDeterministic)
, mIsAirTopP(isAirTopP)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "ExternalDraftTokensLayer does not support Beam search mode");
auto const deviceId = getDevice();
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&mDeviceProp, deviceId));
allocateBuffer(decoderDomain.getBatchSize());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -58,11 +64,15 @@ void ExternalDraftTokensLayer<T>::allocateBuffer(SizeType32 batchSize)
// top k workspace size
auto workspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// top p workspace size
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// multinomial (top p == 1) workspace size
workspaceSize = getTopPWorkspaceSize<float>(batchSize, mDecoderDomain.getVocabSizePadded());
// top p and multinomial (top p == 1) workspace size
if (!mIsAirTopP)
{
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
}
else
{
workspaceSize = getAirTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded(), mIsDeterministic);
}
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// batchsize here is maxBatchSize
@ -169,6 +179,20 @@ void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
{topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKHost)}, {}, //
skipDecodeTopKHostPtr, skipDecodeTopPHostPtr, batchSlotsHostPtr, false);
if (mIsAirTopP)
{
auto smCnt = mDeviceProp.multiProcessorCount;
if (smCnt <= 0)
{
auto const deviceId = getDevice();
cudaDeviceProp prop{};
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceId));
smCnt = prop.multiProcessorCount;
}
mAirTopPBlockNum
= calcAirTopPBlockNum<T>(batchSize, mDecoderDomain.getVocabSizePadded(), smCnt, mIsDeterministic);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -330,7 +354,17 @@ void ExternalDraftTokensLayer<T>::multinomialSampling(std::shared_ptr<BaseDecodi
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
invokeBatchTopPSampling<T>(params, getStream());
if (!mIsAirTopP)
{
invokeBatchTopPSampling<T>(params, getStream());
}
else
{
params.blockNum = mAirTopPBlockNum;
params.isDeterministic = mIsDeterministic;
invokeBatchAirTopPSampling<T>(params, getStream());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -423,7 +457,17 @@ void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutput
params.returnAllSelectedTokens = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();
invokeBatchTopPSampling<T>(params, getStream());
if (!mIsAirTopP)
{
invokeBatchTopPSampling<T>(params, getStream());
}
else
{
params.blockNum = mAirTopPBlockNum;
params.isDeterministic = mIsDeterministic;
invokeBatchAirTopPSampling<T>(params, getStream());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
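The new flags steer both the workspace sizing and the sampling call. A condensed standalone sketch of that dispatch pattern follows (not part of this change; the workspace-size helpers below are simplified stand-ins, not the real formulas):
#include <cstddef>
#include <cstdio>

// Stand-ins for the real TensorRT-LLM helpers referenced above; the returned
// byte counts are placeholders, not the actual sizing logic.
static size_t getTopPWorkspaceSize(int batchSize, int vocabSizePadded)
{
    return (size_t) batchSize * vocabSizePadded;
}

static size_t getAirTopPWorkspaceSize(int batchSize, int vocabSizePadded, bool deterministic)
{
    return (size_t) batchSize * vocabSizePadded / (deterministic ? 1 : 2);
}

int main()
{
    bool const isAirTopP = true;       // mIsAirTopP
    bool const isDeterministic = true; // mIsDeterministic
    int const batchSize = 8, vocabSizePadded = 32000;

    // Same shape as allocateBuffer above: pick the workspace for the sampling path in use.
    size_t const workspaceSize = !isAirTopP
        ? getTopPWorkspaceSize(batchSize, vocabSizePadded)
        : getAirTopPWorkspaceSize(batchSize, vocabSizePadded, isDeterministic);

    // Later, the same flag selects invokeBatchTopPSampling vs. invokeBatchAirTopPSampling,
    // with the AirTopP path additionally receiving blockNum and isDeterministic.
    printf("workspace bytes (illustrative): %zu, path: %s\n", workspaceSize,
        isAirTopP ? "AirTopP" : "TopP");
    return 0;
}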

View File

@ -35,7 +35,7 @@ public:
using Base = BaseLayer;
ExternalDraftTokensLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
std::shared_ptr<runtime::BufferManager> bufferManager);
std::shared_ptr<runtime::BufferManager> bufferManager, bool isDeterministic = true, bool isAirTopP = true);
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, TensorConstPtr batchSlots,
std::shared_ptr<BaseSetupParams> const& setupParams,
@ -74,6 +74,12 @@ private:
TensorPtr mTargetLogits;
// AirTopP
cudaDeviceProp mDeviceProp;
runtime::SizeType32 mAirTopPBlockNum{0};
bool mIsDeterministic{true};
bool mIsAirTopP{false};
private:
void allocateBuffer(runtime::SizeType32 batchSize);
void acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,

View File

@ -54,11 +54,12 @@ protected:
TensorPtr mSkipDecodeDevice;
TensorPtr mSkipDecodeHost;
runtime::SizeType32 mAirTopPBlockNum{0};
size_t mWorkspaceSize{0};
size_t mSetupWorkspaceSize{0};
// AirTopP
cudaDeviceProp mDeviceProp;
runtime::SizeType32 mAirTopPBlockNum{0};
bool mIsDeterministic{true};
bool mIsAirTopP{false};

View File

@ -58,7 +58,8 @@ set(PLUGIN_LISTS
lowLatencyGemmPlugin
eaglePlugin
lowLatencyGemmSwigluPlugin
qserveGemmPlugin)
qserveGemmPlugin
cudaStreamPlugin)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})

View File

@ -38,6 +38,7 @@
#include "tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
#endif // ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.h"
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h"
@ -234,6 +235,7 @@ extern "C"
static tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator eagleDecodeDraftTokensPluginCreator;
static tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator
eagleSampleAndAcceptDraftTokensPluginCreator;
static tensorrt_llm::plugins::CudaStreamPluginCreator cudaStreamPluginCreator;
static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@ -268,7 +270,9 @@ extern "C"
creatorPtr(lowLatencyGemmPluginCreator),
creatorPtr(eagleDecodeDraftTokensPluginCreator),
creatorPtr(eagleSampleAndAcceptDraftTokensPluginCreator),
creatorPtr(lowLatencyGemmSwigluPluginCreator) };
creatorPtr(lowLatencyGemmSwigluPluginCreator),
creatorPtr(cudaStreamPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();
}

View File

@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)

View File

@ -0,0 +1,294 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cudaStreamPlugin.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <cassert>
using namespace nvinfer1;
using tensorrt_llm::plugins::CudaStreamPluginCreator;
using tensorrt_llm::plugins::CudaStreamPlugin;
static char const* CUDA_STREAM_PLUGIN_VERSION{"1"};
static char const* CUDA_STREAM_PLUGIN_NAME{"CudaStream"};
PluginFieldCollection CudaStreamPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> CudaStreamPluginCreator::mPluginAttributes;
CudaStreamPlugin::CudaStreamPlugin(int sideStreamId, int nbInputs, nvinfer1::DataType type)
: mSideStreamId(sideStreamId)
, mNbInputs(nbInputs)
, mType(type)
{
init();
}
CudaStreamPlugin::CudaStreamPlugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mSideStreamId);
read(d, mNbInputs);
read(d, mType);
init();
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT-LLM version to build "
"engine and run engine.",
(int) length, (int) (d - a));
}
CudaStreamPlugin::CudaStreamPlugin(CudaStreamPlugin const& other)
: mSideStreamId(other.mSideStreamId)
, mNbInputs(other.mNbInputs)
, mType(other.mType)
{
init();
}
void CudaStreamPlugin::init()
{
mSideStreamPtr = nullptr;
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* CudaStreamPlugin::clone() const noexcept
{
auto* plugin = new CudaStreamPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs CudaStreamPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
assert(outputIndex == 0);
return inputs[outputIndex];
}
bool CudaStreamPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
TLLM_CHECK_WITH_INFO(nbInputs == mNbInputs, "CudaStreamPlugin only accepts %d inputs", mNbInputs);
TLLM_CHECK_WITH_INFO(nbOutputs == 1, "CudaStreamPlugin only accepts 1 output");
auto const& desc = inOut[pos];
if (desc.format != TensorFormat::kLINEAR)
{
return false;
}
if (pos > 0 && pos < nbInputs)
{
return true;
}
return desc.type == mType;
}
void CudaStreamPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
size_t CudaStreamPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
}
int CudaStreamPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
if (!mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
nvinfer1::pluginInternal::SideStream side_stream{};
mSideStreamPtr = reinterpret_cast<nvinfer1::pluginInternal::SideStream*>(
getPluginRegistry()->acquirePluginResource(resource_name, &side_stream));
}
mSideStreamPtr->waitSideStreamOnMainStream(stream);
size_t count = 1;
for (int i = 0; i < inputDesc[0].dims.nbDims; ++i)
{
count *= inputDesc[0].dims.d[i];
}
count *= tensorrt_llm::runtime::BufferDataType(inputDesc[0].type).getSize();
TLLM_CUDA_CHECK(cudaMemcpyAsync(outputs[0], inputs[0], count, cudaMemcpyDeviceToDevice, stream));
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType CudaStreamPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index == 0);
return mType;
}
// IPluginV2 Methods
char const* CudaStreamPlugin::getPluginType() const noexcept
{
return CUDA_STREAM_PLUGIN_NAME;
}
char const* CudaStreamPlugin::getPluginVersion() const noexcept
{
return CUDA_STREAM_PLUGIN_VERSION;
}
int CudaStreamPlugin::getNbOutputs() const noexcept
{
return 1;
}
int CudaStreamPlugin::initialize() noexcept
{
return 0;
}
void CudaStreamPlugin::terminate() noexcept
{
if (mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
getPluginRegistry()->releasePluginResource(resource_name);
mSideStreamPtr = nullptr;
}
}
size_t CudaStreamPlugin::getSerializationSize() const noexcept
{
return sizeof(mSideStreamId) + sizeof(mNbInputs) + sizeof(mType);
}
void CudaStreamPlugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
write(d, mSideStreamId);
write(d, mNbInputs);
write(d, mType);
assert(d == a + getSerializationSize());
}
void CudaStreamPlugin::destroy() noexcept
{
delete this;
}
///////////////
CudaStreamPluginCreator::CudaStreamPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("side_stream_id", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("num_inputs", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* CudaStreamPluginCreator::getPluginName() const noexcept
{
return CUDA_STREAM_PLUGIN_NAME;
}
char const* CudaStreamPluginCreator::getPluginVersion() const noexcept
{
return CUDA_STREAM_PLUGIN_VERSION;
}
PluginFieldCollection const* CudaStreamPluginCreator::getFieldNames() noexcept
{
return &mFC;
}
IPluginV2* CudaStreamPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int sideStreamId;
int nbInputs;
int type;
// Read configurations from each fields
struct MapPair
{
char const* key;
int& field;
bool optional = false;
bool set = false;
};
std::array input_map{
MapPair{"side_stream_id", std::ref(sideStreamId)},
MapPair{"num_inputs", std::ref(nbInputs)},
MapPair{"type_id", std::ref(type)},
};
for (int i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
for (auto& item : input_map)
{
if (!strcmp(item.key, attrName))
{
TLLM_CHECK(fields[i].type == nvinfer1::PluginFieldType::kINT32);
TLLM_CHECK_WITH_INFO(!item.set, "Parameter %s was set twice", item.key);
item.field = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
item.set = true;
}
}
}
for (auto& item : input_map)
{
TLLM_CHECK_WITH_INFO(item.set || item.optional, "Parameter %s is required but not set", item.key);
}
try
{
auto* obj = new CudaStreamPlugin(sideStreamId, nbInputs, static_cast<nvinfer1::DataType>(type));
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
IPluginV2* CudaStreamPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call CudaStreamPlugin::destroy()
try
{
auto* obj = new CudaStreamPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
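The deserializing constructor, serialize() and getSerializationSize() above rely on a symmetric read/write walk over the same fields in the same order. A standalone sketch of that pattern (not part of this change; the local read/write helpers are simplified stand-ins for the ones in the plugin common header):
#include <cassert>
#include <cstring>
#include <vector>

// Simplified stand-ins for the read/write helpers used by the plugin.
template <typename T>
void write(char*& d, T const& value)
{
    std::memcpy(d, &value, sizeof(T));
    d += sizeof(T);
}

template <typename T>
void read(char const*& d, T& value)
{
    std::memcpy(&value, d, sizeof(T));
    d += sizeof(T);
}

int main()
{
    int const sideStreamId = 1, nbInputs = 3, type = 0;
    std::vector<char> buffer(sizeof(sideStreamId) + sizeof(nbInputs) + sizeof(type));

    // serialize(): write the fields in a fixed order.
    char* w = buffer.data();
    write(w, sideStreamId);
    write(w, nbInputs);
    write(w, type);

    // Deserializing constructor: read the fields back in the same order and check
    // that exactly getSerializationSize() bytes were consumed.
    int a = 0, b = 0, c = 0;
    char const* r = buffer.data();
    read(r, a);
    read(r, b);
    read(r, c);
    assert(r == buffer.data() + buffer.size());
    return (a == sideStreamId && b == nbInputs && c == type) ? 0 : 1;
}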

View File

@ -0,0 +1,266 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "NvInferPlugin.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/runtime/cudaMemPool.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <memory>
#include <string>
#include <vector>
namespace nvinfer1
{
namespace pluginInternal
{
class SideWorkspace
{
public:
SideWorkspace(cudaStream_t stream)
: mWorkspaceSize{0}
, mWorkspacePtr{nullptr}
, mStream{stream}
{
}
~SideWorkspace()
{
if (mWorkspacePtr)
{
TLLM_CUDA_CHECK(cudaFreeAsync(mWorkspacePtr, mStream));
}
}
void* get(size_t workspaceSize)
{
if (mWorkspacePtr && mWorkspaceSize < workspaceSize)
{
TLLM_CUDA_CHECK(cudaFreeAsync(mWorkspacePtr, mStream));
mWorkspacePtr = nullptr;
}
if (!mWorkspacePtr)
{
mWorkspaceSize = workspaceSize;
auto pool_ptr
= tensorrt_llm::runtime::CudaMemPool::getPrimaryPoolForDevice(tensorrt_llm::common::getDevice());
TLLM_CUDA_CHECK(cudaMallocFromPoolAsync(&mWorkspacePtr, mWorkspaceSize, pool_ptr->getPool(), mStream));
}
return mWorkspacePtr;
}
private:
size_t mWorkspaceSize;
void* mWorkspacePtr;
cudaStream_t mStream;
};
class SideStream : public IPluginResource
{
public:
SideStream(bool init = false)
: mStream{}
, mMainEvent{}
, mSideEvent{}
, mWorkspace{}
, mInit{init}
{
// The object passed to acquirePluginResource should use the default value init=false
if (init)
{
TLLM_CUDA_CHECK(cudaStreamCreate(&mStream));
TLLM_CUDA_CHECK(cudaEventCreateWithFlags(&mMainEvent, cudaEventDisableTiming));
TLLM_CUDA_CHECK(cudaEventCreateWithFlags(&mSideEvent, cudaEventDisableTiming));
mWorkspace = std::make_shared<SideWorkspace>(mStream);
}
}
void free()
{
if (mInit)
{
mWorkspace = nullptr;
TLLM_CUDA_CHECK(cudaStreamSynchronize(mStream));
TLLM_CUDA_CHECK(cudaStreamDestroy(mStream));
TLLM_CUDA_CHECK(cudaEventDestroy(mMainEvent));
TLLM_CUDA_CHECK(cudaEventDestroy(mSideEvent));
mInit = false;
}
}
int32_t release() noexcept override
{
try
{
free();
}
catch (std::exception const& e)
{
return -1;
}
return 0;
}
IPluginResource* clone() noexcept override
{
// An object is cloned only when calling acquirePluginResource for the first time for each key
std::unique_ptr<SideStream> cloned{};
try
{
if (!mInit)
{
cloned = std::make_unique<SideStream>(/* init */ true);
}
else
{
return nullptr;
}
}
catch (std::exception const& e)
{
return nullptr;
}
return cloned.release();
}
~SideStream() override
{
free();
}
void* getWorkspacePtr(size_t workspaceSize)
{
return mWorkspace->get(workspaceSize);
}
cudaStream_t getStream() const
{
return mStream;
}
void waitMainStreamOnSideStream(cudaStream_t const stream) const
{
TLLM_CUDA_CHECK(cudaEventRecord(mMainEvent, stream));
TLLM_CUDA_CHECK(cudaStreamWaitEvent(mStream, mMainEvent));
}
void waitSideStreamOnMainStream(cudaStream_t const stream) const
{
TLLM_CUDA_CHECK(cudaEventRecord(mSideEvent, mStream));
TLLM_CUDA_CHECK(cudaStreamWaitEvent(stream, mSideEvent));
}
void stallMainStream(char const* name, cudaStream_t const stream, std::optional<int> delay = std::nullopt) const
{
tensorrt_llm::runtime::utils::stallStream(name, stream, delay);
}
void stallSideStream(char const* name, std::optional<int> delay = std::nullopt) const
{
tensorrt_llm::runtime::utils::stallStream(name, mStream, delay);
}
static char const* getResourceKey(int const stream_id)
{
// Keep the backing string alive beyond this call: returning c_str() of a local
// std::string would hand the caller a dangling pointer.
static thread_local std::string keyString;
keyString = "side_stream_" + std::to_string(stream_id);
return keyString.c_str();
}
private:
cudaStream_t mStream;
cudaEvent_t mMainEvent;
cudaEvent_t mSideEvent;
std::shared_ptr<SideWorkspace> mWorkspace;
bool mInit;
};
} // namespace pluginInternal
} // namespace nvinfer1
namespace tensorrt_llm::plugins
{
class CudaStreamPlugin : public BasePlugin
{
public:
CudaStreamPlugin(int sideStreamId, int nbInputs, nvinfer1::DataType type);
CudaStreamPlugin(void const* data, size_t length);
CudaStreamPlugin(CudaStreamPlugin const&);
void init();
~CudaStreamPlugin() override = default;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
// IPluginV2 Methods
char const* getPluginType() const noexcept override;
char const* getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
private:
const std::string mLayerName;
int mSideStreamId;
int mNbInputs;
nvinfer1::DataType mType;
nvinfer1::pluginInternal::SideStream* mSideStreamPtr;
};
class CudaStreamPluginCreator : public BaseCreator
{
public:
CudaStreamPluginCreator();
char const* getPluginName() const noexcept override;
char const* getPluginVersion() const noexcept override;
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
} // namespace tensorrt_llm::plugins
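The SideStream helpers above reduce to a standard event handshake between two CUDA streams. A minimal standalone sketch of the two directions used by waitMainStreamOnSideStream and waitSideStreamOnMainStream (not part of this change; error handling is collapsed into a small macro):
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDA(call)                                                                     \
    do                                                                                       \
    {                                                                                        \
        cudaError_t err_ = (call);                                                           \
        if (err_ != cudaSuccess)                                                             \
        {                                                                                    \
            printf("%s failed: %s\n", #call, cudaGetErrorString(err_));                      \
            return 1;                                                                        \
        }                                                                                    \
    } while (0)

int main()
{
    cudaStream_t mainStream, sideStream;
    cudaEvent_t mainEvent, sideEvent;
    CHECK_CUDA(cudaStreamCreate(&mainStream));
    CHECK_CUDA(cudaStreamCreate(&sideStream));
    CHECK_CUDA(cudaEventCreateWithFlags(&mainEvent, cudaEventDisableTiming));
    CHECK_CUDA(cudaEventCreateWithFlags(&sideEvent, cudaEventDisableTiming));

    // waitMainStreamOnSideStream: the side stream waits until the main stream reaches this point.
    CHECK_CUDA(cudaEventRecord(mainEvent, mainStream));
    CHECK_CUDA(cudaStreamWaitEvent(sideStream, mainEvent));

    // ... work launched on the side stream would go here ...

    // waitSideStreamOnMainStream: the main stream waits until the side stream reaches this point.
    CHECK_CUDA(cudaEventRecord(sideEvent, sideStream));
    CHECK_CUDA(cudaStreamWaitEvent(mainStream, sideEvent));

    CHECK_CUDA(cudaStreamSynchronize(mainStream));
    CHECK_CUDA(cudaEventDestroy(mainEvent));
    CHECK_CUDA(cudaEventDestroy(sideEvent));
    CHECK_CUDA(cudaStreamDestroy(mainStream));
    CHECK_CUDA(cudaStreamDestroy(sideStream));
    return 0;
}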

View File

@ -71,7 +71,7 @@ nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(outputIndex < 2);
TLLM_CHECK(nbInputs == 5);
TLLM_CHECK(nbInputs == 6);
auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[0];
auto const maxDecodingTokensExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[1];
auto const maxDecodingDraftTokensExpr
@ -102,7 +102,7 @@ nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
bool EagleDecodeDraftTokensPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
TLLM_CHECK(nbInputs == 5 && nbOutputs == 2);
TLLM_CHECK(nbInputs == 6 && nbOutputs == 2);
TLLM_CHECK(pos < nbInputs + nbOutputs);
if (pos == getIdx(InputIdxEntry::LOGITS)) // logits (input)
@ -164,7 +164,12 @@ size_t EagleDecodeDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensor
// [batchSize * maxDecodingTokens]
auto const numSuccessorsForEachNodeSize = batchSize * maxDecodingTokens * sizeof(SizeType32);
SizeType32 constexpr NUM_BUFFERS{7};
// 7. Flag whether to do decoding or not. SamplingTopK is done for numInputLogits tokens.
// But only sum(numValidLogitsPerRequest[:]) of them are valid.
    // [numInputLogits]
auto const skipDecodeSize = numInputLogits * sizeof(bool);
SizeType32 constexpr NUM_BUFFERS{8};
size_t workspaces[NUM_BUFFERS];
workspaces[0] = draftTokenSamplingWorkspaceSize;
workspaces[1] = topKsSize;
@ -173,6 +178,7 @@ size_t EagleDecodeDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensor
workspaces[4] = outputIdsPtrsSize;
workspaces[5] = outputIdsSize;
workspaces[6] = numSuccessorsForEachNodeSize;
workspaces[7] = skipDecodeSize;
workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
else
@ -221,6 +227,7 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con
// Plugin inputs
auto logits = static_cast<T const*>(inputs[getIdx(InputIdxEntry::LOGITS)]);
auto numValidLogits = static_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::NUM_VALID_LOGITS)]);
auto randSample = static_cast<float const*>(inputs[getIdx(InputIdxEntry::RAND_SAMPLE)]);
auto paths = static_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
auto pluginInputDraftTokenIdsPtrs
@ -269,12 +276,16 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con
SizeType32* numSuccessorsForEachNode = reinterpret_cast<SizeType32*>(
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(SizeType32)));
// Workspace 7: skip decoding mask [numInputLogits]
bool* skipDecode
= reinterpret_cast<bool*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, numInputLogits * sizeof(bool)));
// Fill logitsPtrs from logits, fill outputIdsPtrs from outputIdsFlatten and fill decodingTokens
invokeAssembleDraftLogitsOffsets(logitsPtrs, logits, outputIdsPtrs, outputIdsFlatten, numInputLogits,
maxDecodingDraftTokens, vocabSizePadded, stream);
invokeAssembleDraftLogitsOffsets(logitsPtrs, logits, outputIdsPtrs, outputIdsFlatten, skipDecode, numValidLogits,
numInputLogits, batchSize, maxDecodingDraftTokens, vocabSizePadded, stream);
sync_check_cuda_error();
invokeExtractTopKsFromPath(paths, topKs, topKOffset, numSuccessorsForEachNode, mLayerIdx, batchSize, numInputLogits,
invokeExtractTopKsFromPath(paths, topKs, topKOffset, numSuccessorsForEachNode, mLayerIdx, batchSize,
maxDecodingTokens, maxPathLen, stream);
sync_check_cuda_error();
@ -289,13 +300,14 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con
params.maxTokensPerStep = 1;
params.vocabSizePadded = vocabSizePadded;
params.returnAllSelectedTokens = true;
params.skipDecode = skipDecode;
invokeBatchTopKSampling(params, stream);
sync_check_cuda_error();
// Copy output token id from outputIdsPtrs to the plugin output buffer
invokeCopyOutputTokensIds(outputIdsPtrs, topKs, topKOffset, pluginInputDraftTokenIdsPtrs, pluginInputDraftLens,
pluginOutputDraftTokenIdsPtrs, pluginOutputDraftLens, mLayerIdx, batchSize, numInputLogits,
numValidLogits, pluginOutputDraftTokenIdsPtrs, pluginOutputDraftLens, mLayerIdx, batchSize,
maxDecodingDraftTokens, stream);
sync_check_cuda_error();
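
The hunks above grow the plugin's scratch space from seven to eight sub-buffers (the new `skipDecode` mask) and carve each one out of a single workspace allocation at enqueue time. The sketch below shows that partitioning pattern in isolation; the 256-byte alignment and the helper names only mirror the `tc::calculateTotalWorkspaceSize` / `tc::nextWorkspacePtr` utilities and are assumptions for illustration, not the actual implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch, assuming every sub-buffer is padded to a fixed alignment.
constexpr std::size_t kAlignment = 256;

std::size_t alignUp(std::size_t x)
{
    return (x + kAlignment - 1) / kAlignment * kAlignment;
}

// Analogue of calculateTotalWorkspaceSize: total bytes to request for all sub-buffers.
std::size_t totalWorkspaceSize(std::vector<std::size_t> const& sizes)
{
    std::size_t total = 0;
    for (auto s : sizes)
    {
        total += alignUp(s);
    }
    return total;
}

// Analogue of nextWorkspacePtr: carve the next aligned sub-buffer and advance the offset.
void* nextWorkspacePtr(std::uint8_t* base, std::size_t& offset, std::size_t size)
{
    void* ptr = base + offset;
    offset += alignUp(size);
    return ptr;
}
```

With this layout, adding the eighth buffer only means appending `skipDecodeSize` to the size list and issuing one more carve-out call in the same order inside `doTopKSampling`.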

View File

@ -70,6 +70,8 @@ private:
RAND_SAMPLE,
// [batch_size, max_decoding_tokens, max_path_len]
PATHS,
// [1]
NUM_VALID_LOGITS,
// [batch_size, max_decoding_draft_tokens]
INPUT_DRAFT_TOKEN_IDS,

View File

@ -36,8 +36,11 @@ static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME{"EaglePrepareDrafter
PluginFieldCollection EaglePrepareDrafterInputsPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EaglePrepareDrafterInputsPluginCreator::mPluginAttributes;
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(int32_t layerIdx)
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(
int32_t layerIdx, int32_t numLayers, int32_t maxNonLeavesPerLayer)
: mLayerIdx(layerIdx)
, mNumLayers(numLayers)
, mMaxNonLeavesPerLayer(maxNonLeavesPerLayer)
{
}
@ -45,6 +48,9 @@ void EaglePrepareDrafterInputsPlugin::initFieldsToSerialize()
{
mDataToSerialize.clear();
mDataToSerialize.emplace_back(PluginField("layer_idx", &mLayerIdx, PluginFieldType::kINT32, 1));
mDataToSerialize.emplace_back(PluginField("num_layers", &mNumLayers, PluginFieldType::kINT32, 1));
mDataToSerialize.emplace_back(
PluginField("max_non_leaves_per_layer", &mMaxNonLeavesPerLayer, PluginFieldType::kINT32, 1));
mFCToSerialize.nbFields = mDataToSerialize.size();
mFCToSerialize.fields = mDataToSerialize.data();
}
@ -99,7 +105,7 @@ char const* EaglePrepareDrafterInputsPlugin::getPluginNamespace() const noexcept
// IPluginV3OneBuild methods
int32_t EaglePrepareDrafterInputsPlugin::getNbOutputs() const noexcept
{
return 12;
return 11;
}
int32_t EaglePrepareDrafterInputsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
@ -136,11 +142,13 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
nvinfer1::DimsExprs const* shapeInputs, int32_t nbShapeInputs, nvinfer1::DimsExprs* outputs, int32_t nbOutputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(nbOutputs == 12);
TLLM_CHECK(nbInputs == 12);
TLLM_CHECK(nbOutputs == 11);
TLLM_CHECK(nbInputs == 14);
TLLM_CHECK(nbShapeInputs == 0);
auto const numTokens = inputs[getIdx(InputIdxEntry::INPUT_IDS)].d[0];
auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[0];
auto const numGenRequestsExpr = inputs[getIdx(InputIdxEntry::SPEC_DECODING_GENERATION_LENGTHS)].d[0];
auto const numInputGenTokensExpr = inputs[getIdx(InputIdxEntry::INPUT_GEN_TOKENS)].d[0];
auto const maxDecodingLenExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[1];
auto const maxPathLenExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[2];
@ -171,14 +179,29 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
|| outputIndex == getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)
|| (mLayerIdx == 0 && outputIndex == getIdx(OutputIdxEntry::POSITION_IDS)))
{
auto optValue = exprBuilder.operation(DimensionOperation::kPROD, *maxDecodingLenExpr, *batchSizeExpr);
auto upperBound = exprBuilder.operation(DimensionOperation::kSUM, *optValue, *numTokens);
SizeType32 outSizeIndex = getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS);
auto outSizeTensor = exprBuilder.declareSizeTensor(outSizeIndex, *optValue, *upperBound);
outputs[outputIndex].nbDims = 1;
outputs[outputIndex].d[0] = outSizeTensor;
if (mLayerIdx == 0)
{
// We have at most numGenRequests * (mNumLayers + 1) accepted tokens per step for gen requests and
// input_ids - numGenTokens tokens for context requests.
auto numOutputGenTokensExpr = exprBuilder.operation(
DimensionOperation::kPROD, *numGenRequestsExpr, *exprBuilder.constant(mNumLayers + 1));
auto numInputCtxTokensExpr
= exprBuilder.operation(DimensionOperation::kSUB, *numTokens, *numInputGenTokensExpr);
outputs[outputIndex].nbDims = 1;
outputs[outputIndex].d[0] = exprBuilder.operation(DimensionOperation::kMAX, *exprBuilder.constant(1),
*exprBuilder.operation(DimensionOperation::kSUM, *numOutputGenTokensExpr, *numInputCtxTokensExpr));
}
else
{
// At most we have mMaxNonLeavesPerLayer non-leaves at this layer.
// And in total we pass all non-leaves + all their preceding nodes.
// batchSize * mMaxNonLeavesPerLayer * layerIdx
outputs[outputIndex].nbDims = 1;
outputs[outputIndex].d[0] = exprBuilder.operation(DimensionOperation::kPROD,
*exprBuilder.operation(DimensionOperation::kPROD, *exprBuilder.constant(mLayerIdx),
*exprBuilder.constant(mMaxNonLeavesPerLayer)),
*batchSizeExpr);
}
}
else if (mLayerIdx > 0 && outputIndex == getIdx(OutputIdxEntry::POSITION_IDS))
{
@ -187,20 +210,14 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
}
else if (outputIndex == getIdx(OutputIdxEntry::LAST_TOKEN_INDICES))
{
auto upperBound = exprBuilder.operation(DimensionOperation::kPROD, *maxDecodingLenExpr, *batchSizeExpr);
auto optValue = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *upperBound, *exprBuilder.constant(2));
SizeType32 outSizeIndex = getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES);
auto outSizeTensor = exprBuilder.declareSizeTensor(outSizeIndex, *optValue, *upperBound);
outputs[outputIndex].nbDims = 1;
outputs[outputIndex].d[0] = outSizeTensor;
outputs[outputIndex].d[0] = exprBuilder.operation(
DimensionOperation::kPROD, *exprBuilder.constant(mMaxNonLeavesPerLayer), *batchSizeExpr);
}
else if (outputIndex == getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)
|| outputIndex == getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES))
else if (outputIndex == getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES))
{
// size tensors must be declared as 0-D
outputs[outputIndex].nbDims = 0;
outputs[outputIndex].nbDims = 1;
outputs[outputIndex].d[0] = exprBuilder.constant(1);
}
else if (outputIndex == getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS))
{
@ -265,6 +282,11 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputDesc[getIdx(InputIdxEntry::SEQUENCE_LENGTHS)].dims.d[0];
auto const numTokens = inputDesc[getIdx(InputIdxEntry::INPUT_IDS)].dims.d[0];
auto const numGenRequests = inputDesc[getIdx(InputIdxEntry::SPEC_DECODING_GENERATION_LENGTHS)].dims.d[0];
auto const numInputGenTokens = inputDesc[getIdx(InputIdxEntry::INPUT_GEN_TOKENS)].dims.d[0];
auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::ACCEPTED_TOKENS)].dims.d[1];
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::NEXT_DRAFT_PATHS)].dims.d[1];
@ -274,7 +296,6 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
auto positionIds = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::POSITION_IDS)]);
auto hiddenStatesIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)]);
auto lastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::LAST_TOKEN_INDICES)]);
auto numOutputTokens = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)]);
auto numLastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES)]);
auto hiddenSizeBatchLevelStarts
= reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS)]);
@ -288,12 +309,14 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
auto prevPaths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)]);
auto bestPathIds = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::ACCEPTED_PATHS)]);
invokePrepareCtxEagleNetInputs(eagleNetSequenceLengths, eagleNetContextLengths, outputIds, positionIds,
hiddenStatesIndices, lastTokenIndices, numOutputTokens, numLastTokenIndices, hiddenSizeBatchLevelStarts,
inputIds, baseNetSequenceLengths, baseNetContextLengths, acceptedTokens, acceptedLens, prevDraftLens, prevPaths,
bestPathIds, batchSize, maxPathLen, maxDecodingTokens, stream);
auto const numOutputTokens = (numTokens - numInputGenTokens) + (numGenRequests * (mNumLayers + 1));
cudaMemsetAsync(positionIds, 0, numOutputTokens * sizeof(SizeType32), stream);
cudaMemsetAsync(hiddenStatesIndices, 0, numOutputTokens * sizeof(SizeType32), stream);
auto const numTokens = inputDesc[getIdx(InputIdxEntry::INPUT_IDS)].dims.d[0];
invokePrepareCtxEagleNetInputs(eagleNetSequenceLengths, eagleNetContextLengths, outputIds, positionIds,
hiddenStatesIndices, lastTokenIndices, numLastTokenIndices, hiddenSizeBatchLevelStarts, inputIds,
baseNetSequenceLengths, baseNetContextLengths, acceptedTokens, acceptedLens, prevDraftLens, prevPaths,
bestPathIds, batchSize, maxPathLen, maxDecodingTokens, mMaxNonLeavesPerLayer, stream);
sync_check_cuda_error();
@ -322,7 +345,6 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
= reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::SPEC_DECODING_PACKED_MASK)]);
auto hiddenStatesIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)]);
auto lastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::LAST_TOKEN_INDICES)]);
auto numOutputTokens = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)]);
auto numLastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES)]);
auto outputHiddenSizeBatchStartsPerLevel
= reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS)]);
@ -357,6 +379,7 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
SizeType32* maxGenerationLength
= reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, 1 * sizeof(SizeType32)));
cudaMemsetAsync(hiddenStatesIndices, 0, batchSize * mMaxNonLeavesPerLayer * mLayerIdx * sizeof(SizeType32), stream);
cudaMemsetAsync(selectedMasks, 0, batchSize * maxDecodingTokens * maxDecodingTokens * sizeof(int8_t), stream);
// Prefill mask setting all to leaves.
cudaMemsetAsync(isLeafMask, 1, batchSize * maxDecodingTokens * sizeof(int8_t), stream);
@ -371,7 +394,6 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
params.specDecodingPackedMasks = specDecodingPackedMasks;
params.hiddenStatesIndices = hiddenStatesIndices;
params.lastTokenIndices = lastTokenIndices;
params.numOutputTokens = numOutputTokens;
params.numLastTokenIndices = numLastTokenIndices;
params.outputHiddenSizeBatchStartsPerLevel = outputHiddenSizeBatchStartsPerLevel;
@ -395,6 +417,7 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
params.batchSize = batchSize;
params.maxPathLen = maxPathLen;
params.maxDecodingTokens = maxDecodingTokens;
params.maxNonLeavesPerLayer = mMaxNonLeavesPerLayer;
params.stream = stream;
params.checkParams();
@ -459,7 +482,9 @@ EaglePrepareDrafterInputsPluginCreator::EaglePrepareDrafterInputsPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("num_layers", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("max_non_leaves_per_layer", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
@ -485,6 +510,8 @@ nvinfer1::IPluginV3* EaglePrepareDrafterInputsPluginCreator::createPlugin(
try
{
int32_t layerIdx{0};
int32_t numLayers{0};
int32_t maxNonLeavesPerLayer{0};
// Read configurations from each fields
for (int i = 0; i < fc->nbFields; ++i)
{
@ -494,8 +521,18 @@ nvinfer1::IPluginV3* EaglePrepareDrafterInputsPluginCreator::createPlugin(
TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
layerIdx = *static_cast<int32_t const*>(fc->fields[i].data);
}
else if (!strcmp(attrName, "num_layers"))
{
TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
numLayers = *static_cast<int32_t const*>(fc->fields[i].data);
}
else if (!strcmp(attrName, "max_non_leaves_per_layer"))
{
TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
maxNonLeavesPerLayer = *static_cast<int32_t const*>(fc->fields[i].data);
}
}
return new EaglePrepareDrafterInputsPlugin(layerIdx);
return new EaglePrepareDrafterInputsPlugin(layerIdx, numLayers, maxNonLeavesPerLayer);
}
catch (std::exception const& e)
{

View File

@ -33,7 +33,7 @@ class EaglePrepareDrafterInputsPlugin : public nvinfer1::IPluginV3,
public:
EaglePrepareDrafterInputsPlugin(EaglePrepareDrafterInputsPlugin const& p) = default;
EaglePrepareDrafterInputsPlugin(int32_t layerIdx);
EaglePrepareDrafterInputsPlugin(int32_t layerIdx, int32_t numLayers, int32_t maxNonLeavesPerLayer);
nvinfer1::IPluginV3* clone() noexcept override;
@ -98,6 +98,10 @@ private:
PREV_DRAFT_PATHS,
//! [(max_path_len - 1) * batch_size + 1]
HIDDEN_SIZE_BATCH_LEVEL_STARTS,
//! [num_gen_tokens]
INPUT_GEN_TOKENS,
//! [num_gen_requests]
SPEC_DECODING_GENERATION_LENGTHS,
};
enum class OutputIdxEntry : int32_t
@ -112,17 +116,18 @@ private:
SPEC_DECODING_POSITION_OFFSETS,
//! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
SPEC_DECODING_PACKED_MASK,
//! [NUM_OUTPUT_TOKENS]
//! [batchSize * mMaxNonLeavesPerLayer * layerIdx] for layerIdx > 0
//! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
OUTPUT_IDS,
//! [NUM_OUTPUT_TOKENS]
//! [batchSize] for layerIdx > 0
//! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
POSITION_IDS,
//! [NUM_OUTPUT_TOKENS]
//! [batchSize * mMaxNonLeavesPerLayer * layerIdx] for layerIdx > 0
//! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
HIDDEN_STATES_INDICES,
//! [NUM_LAST_TOKEN_INDICES]
//! [batchSize * mMaxNonLeavesPerLayer]
LAST_TOKEN_INDICES,
//! [1]
NUM_OUTPUT_TOKENS,
//! [1]
NUM_LAST_TOKEN_INDICES,
//! [(max_path_len - 1) * batch_size + 1]
HIDDEN_SIZE_BATCH_LEVEL_STARTS,
@ -148,7 +153,9 @@ private:
cudaStream_t stream) noexcept;
private:
int32_t mLayerIdx;
int32_t mLayerIdx{0};
int32_t mNumLayers{0};
int32_t mMaxNonLeavesPerLayer{0};
std::vector<nvinfer1::PluginField> mDataToSerialize;
nvinfer1::PluginFieldCollection mFCToSerialize;
};
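
The shape annotations above encode the static upper bounds that replace the former size-tensor outputs: for the first EagleNet layer the bound is all context tokens plus up to (num_layers + 1) accepted tokens per generation request, while deeper layers are bounded by the non-leaf nodes and their preceding nodes. A small self-contained sketch of that arithmetic, with example values chosen only for illustration:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main()
{
    // Illustrative values only; the real numbers come from the engine inputs.
    std::int32_t const batchSize = 4;
    std::int32_t const numTokens = 1024;        // length of the packed input_ids
    std::int32_t const numInputGenTokens = 64;  // tokens belonging to generation requests
    std::int32_t const numGenRequests = 8;
    std::int32_t const numLayers = 3;           // EagleNet layers
    std::int32_t const maxNonLeavesPerLayer = 4;
    std::int32_t const layerIdx = 2;            // some layer > 0

    // Layer 0: all context tokens plus at most (numLayers + 1) accepted tokens per gen request.
    std::int32_t const layer0Bound
        = std::max<std::int32_t>(1, (numTokens - numInputGenTokens) + numGenRequests * (numLayers + 1));

    // Layer > 0: every non-leaf node of this layer plus all of its preceding nodes.
    std::int32_t const deeperLayerBound = batchSize * maxNonLeavesPerLayer * layerIdx;

    // LAST_TOKEN_INDICES: at most one entry per non-leaf per request.
    std::int32_t const lastTokenBound = batchSize * maxNonLeavesPerLayer;

    std::cout << layer0Bound << ' ' << deeperLayerBound << ' ' << lastTokenBound << '\n';
    return 0;
}
```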

View File

@ -116,6 +116,7 @@ struct FusedQKVMaskedAttentionDispatchParams
int max_distance = 0;
bool block_sparse_attention = false;
BlockSparseParams block_sparse_params;
int32_t const* mrope_position_deltas;
};
template <typename T, typename KVCacheBuffer>
@ -251,6 +252,9 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
= generationsParams.spec_decoding_is_generation_length_variable;
xqaParams.spec_decoding_max_generation_length = generationsParams.spec_decoding_max_generation_length;
xqaParams.mrope_rotary_sin_cos = generationsParams.mrope_rotary_sin_cos;
xqaParams.mrope_position_deltas = generationsParams.mrope_position_deltas;
xqaParams.total_num_input_tokens = generationsParams.total_num_input_tokens;
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
return true;
@ -376,6 +380,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
// cross attn
params.memory_length_per_sample = input_params.memory_length_per_sample;
params.mrope_position_deltas = input_params.mrope_position_deltas;
sync_check_cuda_error();
masked_multihead_attention(params, input_params.kv_block_array, input_params.shift_k_cache_buffer, stream);
@ -1248,6 +1254,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T> const& para
preprocessingParams.cu_kv_seq_lens = cu_kv_seqlens;
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
preprocessingParams.mrope_rotary_sin_cos = params.mrope_rotary_sin_cos;
preprocessingParams.mrope_position_deltas = params.mrope_position_deltas;
preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
preprocessingParams.spec_decoding_position_offsets = nullptr;
@ -1861,6 +1869,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(EnqueueGenerationParams<T> const
dispatch_params.memory_length_per_sample = params.encoder_input_lengths;
dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE;
dispatch_params.block_sparse_params = mBlockSparseParams;
dispatch_params.mrope_position_deltas = params.mrope_position_deltas;
using DataType = typename SATypeConverter<T>::Type;
if (!isCrossAttention())

View File

@ -135,6 +135,9 @@ protected:
int32_t max_blocks_per_sequence;
int32_t const* host_context_lengths;
void* workspace;
float2 const* mrope_rotary_sin_cos = nullptr;
int32_t const* mrope_position_deltas = nullptr;
// optional when relative position
T const* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
@ -237,6 +240,9 @@ protected:
int32_t* semaphores;
void* workspace;
int32_t const* host_past_key_value_lengths;
float2 const* mrope_rotary_sin_cos = nullptr;
int32_t const* mrope_position_deltas = nullptr;
// optional when relative position
T const* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
@ -293,7 +299,8 @@ protected:
{
return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPTJ
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE;
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_M;
}
bool isLongRoPE() const
@ -306,6 +313,11 @@ protected:
return !mEnableContextFMHA && mCrossAttention;
}
bool isMRoPE() const
{
return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_M;
}
bool isCrossAttention() const
{
return mCrossAttention;

View File

@ -162,6 +162,8 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
case IdxEntry::SPEC_DECODING_GENERATION_LENGTHS: return mIsSpecDecodingEnabled;
case IdxEntry::SPEC_DECODING_PACKED_MASK: return mIsSpecDecodingEnabled;
case IdxEntry::SPEC_DECODING_POSITION_OFFSETS: return mIsSpecDecodingEnabled;
case IdxEntry::MROPE_ROTARY_SIN_COS: return isMRoPE();
case IdxEntry::MROPE_POSITION_DELTAS: return isMRoPE();
case IdxEntry::HOST_RUNTIME_PERF_KNOBS: return true;
case IdxEntry::HOST_CONTEXT_PROGRESS: return true;
case IdxEntry::MLA_FUSED_Q_PROJ_TENSOR: return mIsMLAEnabled;
@ -266,6 +268,14 @@ bool GPTAttentionPlugin::supportsFormatCombination(
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_ROTARY_SIN_COS)))
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_POSITION_DELTAS)))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (pos == getIdx(IdxEntry::HOST_RUNTIME_PERF_KNOBS) || pos == getIdx(IdxEntry::HOST_CONTEXT_PROGRESS))
{
posCaseLine = __LINE__;
@ -411,13 +421,15 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
int const num_requests = 256;
int const sink_token_length = 0;
EnqueueGenerationParams<T> enqueueParams{/*attention_input=*/nullptr,
EnqueueGenerationParams<T> enqueueParams{
/*attention_input=*/nullptr,
/*qkv_bias=*/nullptr,
/*attention_mask*/ nullptr,
/*rotary_inv_freq*/ nullptr,
/*input_seq_length=*/0,
/*sequence_lengths=*/nullptr,
/*past_kv_length=*/0, beamWidth,
/*past_kv_length=*/0,
beamWidth,
/*context_lengths=*/nullptr,
/*kv_scale_orig_quant=*/nullptr,
/*kv_scale_quant_orig=*/nullptr,
@ -428,12 +440,19 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
/*block_offsets=*/nullptr,
/*host_primary_pool_pointer=*/nullptr,
/*host_secondary_pool_pointer=*/nullptr,
/*attention_mask_stride*/ 0, max_attention_window_size, cyclic_attention_window_size, sink_token_length,
/*attention_mask_stride*/ 0,
max_attention_window_size,
cyclic_attention_window_size,
sink_token_length,
num_requests,
/*max_blocks_per_sequence=*/0,
/*cache_indir=*/nullptr,
/*workspace=*/nullptr,
/*max_context_kv_len_list=*/nullptr};
/*max_context_kv_len_list=*/nullptr,
/*mrope_rotary_sin_cos*/ nullptr,
/*mrope_position_deltas*/ nullptr,
};
prepareEnqueueGeneration<T, KVCacheBuffer>(enqueueParams);
@ -699,6 +718,12 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
rotary_cos_sin = reinterpret_cast<float2 const*>(inputs[getIdx(IdxEntry::ROTARY_COS_SIN)]);
}
auto const mrope_rotary_sin_cos
= isMRoPE() ? reinterpret_cast<float2 const*>(inputs[getIdx(IdxEntry::MROPE_ROTARY_SIN_COS)]) : nullptr;
auto const mrope_position_deltas
= isMRoPE() ? reinterpret_cast<int32_t const*>(inputs[getIdx(IdxEntry::MROPE_POSITION_DELTAS)]) : nullptr;
if (mUnfuseQkvGemm)
{
int const max_seqlen = inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[mRemovePadding ? 0 : 1];
@ -932,7 +957,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
cyclic_attention_window_size, sink_token_length, context_q_lengths, sequence_kv_length, kv_scale_orig_quant,
kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
block_offsets, host_block_offsets, host_primary_pool_pointer, host_secondary_pool_pointer, batch_size,
localNbTokens, max_blocks_per_sequence, host_context_lengths, workspace};
localNbTokens, max_blocks_per_sequence, host_context_lengths, workspace, mrope_rotary_sin_cos,
mrope_position_deltas};
enqueue_params.runtime_perf_knobs = runtime_perf_knobs;
if (isRelativePosition())
{
@ -1007,7 +1034,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
block_offsets, host_primary_pool_pointer, host_secondary_pool_pointer, attention_mask_stride,
max_attention_window_size, cyclic_attention_window_size, sink_token_length, num_requests,
max_blocks_per_sequence, cache_indir, mMultiBlockSemaphores.get(), workspace, max_context_kv_len_list};
max_blocks_per_sequence, cache_indir, mMultiBlockSemaphores.get(), workspace, max_context_kv_len_list,
mrope_rotary_sin_cos, mrope_position_deltas};
enqueue_params.host_context_lengths = host_context_lengths;
enqueue_params.runtime_perf_knobs = runtime_perf_knobs;
if (isRelativePosition())

View File

@ -220,6 +220,8 @@ private:
SPEC_DECODING_GENERATION_LENGTHS,
SPEC_DECODING_PACKED_MASK,
SPEC_DECODING_POSITION_OFFSETS,
MROPE_ROTARY_SIN_COS,
MROPE_POSITION_DELTAS,
HOST_RUNTIME_PERF_KNOBS,
HOST_CONTEXT_PROGRESS,
MLA_FUSED_Q_PROJ_TENSOR,
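
The two new `MROPE_*` entries above are optional inputs that only exist when `isMRoPE()` is true, which is why `isEntryUsed` and `getIdx` show up throughout the .cpp changes. The sketch below shows one common way such optional-input indexing works — the runtime tensor index of an entry counts only the entries that are actually used — with a trimmed, hypothetical enum; it is an assumption about the mechanism, not the plugin's actual `getIdx`.

```cpp
#include <cstdint>
#include <functional>

enum class IdxEntry : std::int32_t
{
    QKV_TENSOR,
    MROPE_ROTARY_SIN_COS,
    MROPE_POSITION_DELTAS,
    HOST_RUNTIME_PERF_KNOBS,
    COUNT
};

// The runtime index of an entry is the number of *used* entries preceding it, so
// optional inputs (here the two mRoPE tensors) only shift later indices when enabled.
std::int32_t getIdx(IdxEntry entry, std::function<bool(IdxEntry)> const& isEntryUsed)
{
    std::int32_t idx = 0;
    for (std::int32_t i = 0; i < static_cast<std::int32_t>(entry); ++i)
    {
        if (isEntryUsed(static_cast<IdxEntry>(i)))
        {
            ++idx;
        }
    }
    return idx;
}
```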

View File

@ -18,6 +18,8 @@
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <numeric>
using namespace nvinfer1;
@ -42,8 +44,8 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
nvinfer1::DataType type, nvinfer1::DataType weight_type, nvinfer1::DataType output_type, QuantMode quant_mode,
bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
LoraPluginProfilerPtr lora_profiler, int max_low_rank)
int side_stream_id, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora,
nvinfer1::DataType lora_type, LoraPluginProfilerPtr lora_profiler, int max_low_rank)
: mRemoveInputPadding(remove_input_padding)
, mNumExperts(number_of_experts)
, mK(top_k)
@ -60,6 +62,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
, mNormalizationMode(normalization_mode)
, mSparseMixerEpsilon(sparse_mixer_epsilon)
, mUseDeterministicKernels(force_determinism)
, mSideStreamId(side_stream_id)
, mGemmProfiler(std::move(gemm_profiler_ptr))
, mUseLora(use_lora)
, mLoraType(lora_type)
@ -90,6 +93,7 @@ tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(MixtureOfE
, mGemmId2(other.mGemmId2)
, mSparseMixerEpsilon(other.mSparseMixerEpsilon)
, mUseDeterministicKernels(other.mUseDeterministicKernels)
, mSideStreamId(other.mSideStreamId)
, mGemmProfiler(other.mGemmProfiler)
, mUseLora(other.mUseLora)
, mLoraType(other.mLoraType)
@ -111,8 +115,8 @@ size_t MixtureOfExpertsPlugin::getSerializationSize() const noexcept
+ sizeof(mExpertInterSize) + sizeof(mActivationType) + sizeof(mType) + sizeof(mWeightType) + sizeof(mOutputType)
+ sizeof(QuantMode::BaseType) + sizeof(mUseFinished) + sizeof(mUseBias) + sizeof(mParallelismConfig)
+ sizeof(mNormalizationMode) + sizeof(mSparseMixerEpsilon) + sizeof(mDims) + sizeof(mUseDeterministicKernels)
+ mGemmProfiler->getSerializationSize(mGemmId1) + mGemmProfiler->getSerializationSize(mGemmId2)
+ sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);
+ sizeof(mSideStreamId) + mGemmProfiler->getSerializationSize(mGemmId1)
+ mGemmProfiler->getSerializationSize(mGemmId2) + sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);
if (hasLora())
{
@ -149,6 +153,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(void const* data, size_t length,
read(d, mSparseMixerEpsilon);
read(d, mDims);
read(d, mUseDeterministicKernels);
read(d, mSideStreamId);
read(d, mUseLora);
read(d, mLoraType);
read(d, mMaxLowRank);
@ -193,6 +198,7 @@ void MixtureOfExpertsPlugin::serialize(void* buffer) const noexcept
write(d, mSparseMixerEpsilon);
write(d, mDims);
write(d, mUseDeterministicKernels);
write(d, mSideStreamId);
write(d, mUseLora);
write(d, mLoraType);
write(d, mMaxLowRank);
@ -308,6 +314,9 @@ void MixtureOfExpertsPlugin::init()
TLLM_CUDA_CHECK(cudaEventCreate(&mMemcpyEvent));
}
mSideStreamPtr = nullptr;
mDebugStallMain = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_MAIN");
mDebugStallSide = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_SIDE");
}
// IPluginV2DynamicExt Methods
@ -321,7 +330,7 @@ nvinfer1::IPluginV2DynamicExt* MixtureOfExpertsPlugin::clone() const noexcept
nvinfer1::DimsExprs MixtureOfExpertsPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
assert(outputIndex == getOutputTensorIndex());
assert(outputIndex == getOutputTensorIndex() || outputIndex == getOutputDummyTensorIndex());
return inputs[getInputTensorIndex()];
}
@ -359,6 +368,10 @@ bool MixtureOfExpertsPlugin::supportsFormatCombination(
{
return inOut[pos].type == mOutputType;
}
else if (useSideStream() && pos == nbInputs + getOutputDummyTensorIndex())
{
return inOut[pos].type == mType;
}
else if (hasExpertFp8QuantScales() && getExpertFP8Dequant1Index() <= pos && pos <= getExpertFP8QuantFinalIndex())
{
return inOut[pos].type == DataType::kFLOAT;
@ -514,6 +527,10 @@ size_t MixtureOfExpertsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const
TLLM_CHECK_WITH_INFO(nbInputs == getNbInputs(), "Required input to plugin is missing");
TLLM_CHECK_WITH_INFO(nbOutputs == getNbOutputs(), "Required output to plugin is missing");
if (useSideStream())
{
return 0;
}
int const num_tokens = getNumTokens(inputs);
int const num_lora_reqs = getNumLoraRequests(inputs);
return setupWorkspace(nullptr, num_tokens, num_lora_reqs).size;
@ -661,6 +678,36 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
int64_t const num_reqs = getNumLoraRequests(inputDesc);
int64_t const num_not_finished = num_tokens; // TODO Take this as an input
if (useSideStream())
{
// Prepare the side stream
if (!mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
nvinfer1::pluginInternal::SideStream side_stream{};
mSideStreamPtr = reinterpret_cast<nvinfer1::pluginInternal::SideStream*>(
getPluginRegistry()->acquirePluginResource(resource_name, &side_stream));
}
// Debug the code with the main stream stalled (only executed when the environment variable
// TLLM_DEBUG_MOE_STALL_MAIN is set and has a positive value)
mSideStreamPtr->stallMainStream("TLLM_DEBUG_MOE_STALL_MAIN", stream, mDebugStallMain);
// The side stream waits for the inputs managed by the main stream to be ready
mSideStreamPtr->waitMainStreamOnSideStream(stream);
// Provide data dependency for the shared experts running after this plugin by copying inputs on the main stream
size_t count = 1;
for (int i = 0; i < inputDesc[getInputTensorIndex()].dims.nbDims; ++i)
{
count *= inputDesc[getInputTensorIndex()].dims.d[i];
}
count *= tensorrt_llm::runtime::BufferDataType(inputDesc[getInputTensorIndex()].type).getSize();
TLLM_CUDA_CHECK(cudaMemcpyAsync(outputs[getOutputDummyTensorIndex()], inputs[getInputTensorIndex()], count,
cudaMemcpyDeviceToDevice, stream));
// Switch from the main stream to the side stream
stream = mSideStreamPtr->getStream();
// The workspace is managed by the side stream (otherwise, the lifetime of workspace may be incorrect)
auto const workspace_size = setupWorkspace(nullptr, num_tokens, num_reqs).size;
workspace_ptr = mSideStreamPtr->getWorkspacePtr(workspace_size);
}
auto workspace = setupWorkspace(workspace_ptr, num_tokens, num_reqs);
auto w1_desc = inputDesc[getExpertWeights1Index()];
@ -728,6 +775,13 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
static_cast<int*>(workspace.selected_experts), mSparseMixerEpsilon, mParallelismConfig, mNormalizationMode,
hasLora(), lora_params, stream);
if (useSideStream())
{
// Debug the code with the side stream stalled (only executed when the environment variable
// TLLM_DEBUG_MOE_STALL_SIDE is set and has a positive value)
mSideStreamPtr->stallSideStream("TLLM_DEBUG_MOE_STALL_SIDE", mDebugStallSide);
}
return 0;
}
@ -735,8 +789,12 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::DataType MixtureOfExpertsPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index == getOutputTensorIndex());
TLLM_CHECK(index == getOutputTensorIndex() || index == getOutputDummyTensorIndex());
TLLM_CHECK(inputTypes[getInputTensorIndex()] == mType);
if (useSideStream() && index == getOutputDummyTensorIndex())
{
return mType;
}
return mOutputType;
}
@ -769,7 +827,15 @@ int MixtureOfExpertsPlugin::initialize() noexcept
return 0;
}
void MixtureOfExpertsPlugin::terminate() noexcept {}
void MixtureOfExpertsPlugin::terminate() noexcept
{
if (mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
getPluginRegistry()->releasePluginResource(resource_name);
mSideStreamPtr = nullptr;
}
}
void MixtureOfExpertsPlugin::destroy() noexcept
{
@ -836,6 +902,7 @@ MixtureOfExpertsPluginCreator::MixtureOfExpertsPluginCreator()
static_cast<int>(MOEExpertScaleNormalizationMode::NONE)));
mPluginAttributes.emplace_back(
nvinfer1::PluginField("sparse_mixer_epsilon", nullptr, PluginFieldType::kFLOAT32, 0));
mPluginAttributes.emplace_back(nvinfer1::PluginField("side_stream_id", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_lora", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(nvinfer1::PluginField("lora_type_id", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(nvinfer1::PluginField("max_low_rank", nullptr, PluginFieldType::kINT32, 0));
@ -865,6 +932,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
int mEPRank{};
int mNormalizationMode{};
int mRequiresDeterminism{0};
int mSideStreamId{0};
int mUseLora{};
int mLoraType{INT_MAX};
int mMaxLowRank{0};
@ -902,6 +970,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
MapPair{"use_bias", std::ref(mUseBias), true},
MapPair{"output_type_id", std::ref(mOutputType), true},
MapPair{"force_determinism", std::ref(mRequiresDeterminism), true},
MapPair{"side_stream_id", std::ref(mSideStreamId), true},
MapPair{"lora_type_id", std::ref(mLoraType), true},
MapPair{"max_low_rank", std::ref(mMaxLowRank), true},
};
@ -962,8 +1031,8 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
static_cast<nvinfer1::DataType>(mWeightType), static_cast<nvinfer1::DataType>(mOutputType),
QuantMode(mQuantMode), mUseFinished != 0, mUseBias != 0, mTPSize, mTPRank, mEPSize, mEPRank,
static_cast<MOEExpertScaleNormalizationMode>(mNormalizationMode), mSparseMixerEpsilon,
mRequiresDeterminism != 0, gemmProfiler, mUseLora != 0, static_cast<nvinfer1::DataType>(mLoraType),
loraProfiler, mMaxLowRank);
mRequiresDeterminism != 0, mSideStreamId, gemmProfiler, mUseLora != 0,
static_cast<nvinfer1::DataType>(mLoraType), loraProfiler, mMaxLowRank);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
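
The side-stream branch in `enqueue` acquires a shared `SideStream` resource from the plugin registry, orders it after the main stream, copies the input into the dummy output so downstream layers keep a data dependency on the main stream, and then launches the expert computation on the side stream (the resource is released again in `terminate`). Below is a stripped-down sketch of that overlap pattern with raw CUDA stream/event calls; the registry, the `SideStream` helper and the MoE kernels themselves are deliberately left out, so this illustrates only the synchronization, not the plugin code.

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Minimal sketch: run heavy work on a side stream that is ordered after the
// producer work on the main stream, while the main stream keeps a dummy copy
// of the input alive as a data dependency for later layers.
void runMoeOnSideStream(cudaStream_t mainStream, cudaStream_t sideStream,
    void const* input, void* dummyOutput, std::size_t inputBytes)
{
    // 1. Order the side stream after everything already enqueued on the main stream.
    cudaEvent_t inputsReady;
    cudaEventCreateWithFlags(&inputsReady, cudaEventDisableTiming);
    cudaEventRecord(inputsReady, mainStream);
    cudaStreamWaitEvent(sideStream, inputsReady, 0);

    // 2. Copy the input to the dummy output on the main stream, so layers that
    //    consume the dummy output stay ordered after the input is produced.
    cudaMemcpyAsync(dummyOutput, input, inputBytes, cudaMemcpyDeviceToDevice, mainStream);

    // 3. The expensive expert GEMMs would now be launched on sideStream instead
    //    of mainStream (kernel launches omitted in this sketch).

    cudaEventDestroy(inputsReady);
}
```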

View File

@ -24,6 +24,7 @@
#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h"
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.h"
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <cassert>
@ -106,8 +107,8 @@ public:
nvinfer1::DataType weight_type, nvinfer1::DataType output_type, tensorrt_llm::common::QuantMode quant_mode,
bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
LoraPluginProfilerPtr lora_profiler, int max_low_rank);
int side_stream_id, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora,
nvinfer1::DataType lora_type, LoraPluginProfilerPtr lora_profiler, int max_low_rank);
MixtureOfExpertsPlugin(void const* data, size_t length, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr,
LoraPluginProfilerPtr lora_profiler);
MixtureOfExpertsPlugin(MixtureOfExpertsPlugin const&);
@ -139,7 +140,7 @@ public:
int getNbOutputs() const noexcept override
{
return 1;
return 1 + useSideStream();
}
int initialize() noexcept override;
@ -170,6 +171,10 @@ private:
GemmDims mDims{};
bool mUseDeterministicKernels = false;
int mSideStreamId = 0;
int mDebugStallMain = 0;
int mDebugStallSide = 0;
GemmIDMoe mGemmId1{};
GemmIDMoe mGemmId2{};
@ -197,6 +202,7 @@ private:
std::vector<int32_t> mLoraExpandGatedRanks{};
cudaEvent_t mMemcpyEvent;
nvinfer1::pluginInternal::SideStream* mSideStreamPtr;
// The below are not serialised
std::string const mLayerName{};
@ -279,6 +285,11 @@ private:
return hasExpertFp8QuantScales() && mOutputType == nvinfer1::DataType::kFP8;
}
bool useSideStream() const
{
return mSideStreamId > 0;
}
bool hasLora() const
{
return mUseLora;
@ -390,6 +401,11 @@ private:
return 0;
}
IndexType getOutputDummyTensorIndex() const
{
return getOutputTensorIndex() + useSideStream();
}
/**
* Get the index of the expert shape tuple that represents the inner dimension
*/
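
As a quick reference for the accessors declared above, this is the output layout they imply: index 0 stays the real MoE result, and the dummy copy of the input is appended behind it only when a side stream is configured. The member names here are illustrative, not the plugin's.

```cpp
// Sketch of the output layout implied by getNbOutputs() and getOutputDummyTensorIndex().
struct MoeOutputLayout
{
    bool sideStreamEnabled = false;

    int getNbOutputs() const
    {
        // 1 real output, plus 1 dummy output when a side stream is configured.
        return 1 + (sideStreamEnabled ? 1 : 0);
    }

    int getOutputTensorIndex() const
    {
        return 0; // the MoE result, always present
    }

    int getOutputDummyTensorIndex() const
    {
        // The dummy copy of the input sits right after the real output.
        return getOutputTensorIndex() + (sideStreamEnabled ? 1 : 0);
    }
};
```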

View File

@ -416,34 +416,33 @@ std::set<int> getLocalGroup(std::set<int> const& group)
ranks.push_back(myRank);
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
LOCAL_COMM_SESSION.recvValue(rank, *it, 0);
COMM_SESSION.recvValue(rank, *it, 0);
ranks.push_back(rank);
}
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
LOCAL_COMM_SESSION.send(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
COMM_SESSION.send(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
}
localRanks.clear();
localRanks.push_back(myLocalRank);
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
LOCAL_COMM_SESSION.recvValue(rank, *it, 0);
COMM_SESSION.recvValue(rank, *it, 0);
localRanks.push_back(rank);
}
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
LOCAL_COMM_SESSION.send(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
COMM_SESSION.send(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
}
}
else
{
LOCAL_COMM_SESSION.sendValue(myRank, *group.begin(), 0);
LOCAL_COMM_SESSION.recv(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
COMM_SESSION.sendValue(myRank, *group.begin(), 0);
COMM_SESSION.recv(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
LOCAL_COMM_SESSION.sendValue(myLocalRank, *group.begin(), 0);
LOCAL_COMM_SESSION.recv(
localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
COMM_SESSION.sendValue(myLocalRank, *group.begin(), 0);
COMM_SESSION.recv(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
}
}
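
The hunk above only switches the rank exchange from the node-local communicator onto the global `COMM_SESSION`; the protocol itself is unchanged: the first rank of the group collects every member's rank, then sends the assembled list back to each member. A hedged sketch of that protocol directly on `MPI_COMM_WORLD` follows (the real code goes through TensorRT-LLM's MPI wrappers and repeats the exchange a second time for the node-local ranks):

```cpp
#include <mpi.h>
#include <cstddef>
#include <vector>

// group holds the world ranks of the participants, sorted ascending;
// the smallest rank acts as the leader, as in getLocalGroup above.
std::vector<int> exchangeGroupRanks(std::vector<int> const& group, int myRank)
{
    std::vector<int> ranks(group.size());
    int const leader = group.front();
    if (myRank == leader)
    {
        ranks[0] = myRank;
        // Collect one rank from every other member, then send the assembled list back.
        for (std::size_t i = 1; i < group.size(); ++i)
        {
            MPI_Recv(&ranks[i], 1, MPI_INT, group[i], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        }
        for (std::size_t i = 1; i < group.size(); ++i)
        {
            MPI_Send(ranks.data(), static_cast<int>(ranks.size()), MPI_INT, group[i], 0, MPI_COMM_WORLD);
        }
    }
    else
    {
        MPI_Send(&myRank, 1, MPI_INT, leader, 0, MPI_COMM_WORLD);
        MPI_Recv(ranks.data(), static_cast<int>(ranks.size()), MPI_INT, leader, 0, MPI_COMM_WORLD,
            MPI_STATUS_IGNORE);
    }
    return ranks;
}
```
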
@ -732,7 +731,6 @@ IPluginV2* AllreducePluginCreator::createPlugin(char const* name, PluginFieldCol
bias = *static_cast<int8_t const*>(fields[i].data);
}
}
try
{
auto* obj = new AllreducePlugin(group, type, strategy, config, fusion_op, counter, eps, affine, bias);

View File

@ -36,6 +36,8 @@ target_compile_definitions(
if(NOT WIN32)
set_target_properties(
${TRTLLM_PYBIND_MODULE}
PROPERTIES LINK_FLAGS
"-Wl,-rpath,'$ORIGIN/libs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
PROPERTIES
LINK_FLAGS
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
)
endif()

View File

@ -237,6 +237,8 @@ void initBindings(pybind11::module_& m)
std::optional<std::vector<tb::LlmRequest::SizeType32>> position_ids,
std::optional<at::Tensor> prompt_embedding_table,
std::optional<tb::LlmRequest::SizeType32> prompt_vocab_size,
std::optional<at::Tensor> mrope_rotary_sin_cos,
std::optional<tb::LlmRequest::SizeType32> mrope_position_deltas,
std::optional<LoraTaskIdType> lora_task_id, std::optional<at::Tensor> lora_weights,
std::optional<at::Tensor> lora_config,
std::optional<executor::LookaheadDecodingConfig> lookahead_config,
@ -269,6 +271,7 @@ void initBindings(pybind11::module_& m)
auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list);
auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table);
auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights);
auto mrope_rotary_sin_cos_tensor_ptr = makeOptionalTensor(mrope_rotary_sin_cos);
auto lora_config_tensor_ptr = makeOptionalTensor(lora_config);
auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits);
auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features);
@ -277,19 +280,20 @@ void initBindings(pybind11::module_& m)
return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
lora_task_id, lora_weights_tensor_ptr, lora_config_tensor_ptr, lookahead_config,
kv_cache_retention_config, return_log_probs, return_context_logits, return_generation_logits,
draft_tokens, draft_logits_tensor_ptr, exclude_input_from_output, logits_post_processor,
apply_logits_post_processor_batched, encoder_input_tokens, return_encoder_output, client_id,
priority, encoder_input_features_tensor_ptr, encoder_output_length,
cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids,
num_return_sequences};
mrope_rotary_sin_cos_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched,
encoder_input_tokens, return_encoder_output, client_id, priority,
encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
llm_request_type, input_token_extra_ids, num_return_sequences};
}),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt,
py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt,
py::arg("mrope_rotary_sin_cos") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt,
py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt,
py::arg("kv_cache_retention_config") = std::nullopt, py::arg("return_log_probs") = false,

View File

@ -73,6 +73,8 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
auto badWordsList = from_torch(mBadWordsList);
auto stopWordsList = from_torch(mStopWordsList);
auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable);
auto mropeRotarySinCos = from_torch(mMropeRotarySinCos);
auto loraWeights = from_torch(mLoraWeights);
auto loraConfig = from_torch(mLoraConfig);
auto draftLogits = from_torch(mDraftLogits);
@ -82,11 +84,11 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId,
loraWeights, loraConfig, mLookaheadConfig, mKvCacheRetentionConfig, returnLogProbs(), mReturnContextLogits,
mReturnGenerationLogits, mDraftTokens, draftLogits, mExcludeInputFromOutput,
callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched, mEncoderTokens, mReturnEncoderOutput,
mClientId, mPriority, encoderInputFeatures, mEncoderOutputLength, crossAttentionMask,
tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, mInputTokenExtraIds, mNumReturnSequences,
std::nullopt, skipCrossAttnBlocks);
embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize,
mropeRotarySinCos, mMropePositionDeltas, mLoraTaskId, loraWeights, loraConfig, mLookaheadConfig,
mKvCacheRetentionConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits, mDraftTokens,
draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched,
mEncoderTokens, mReturnEncoderOutput, mClientId, mPriority, encoderInputFeatures, mEncoderOutputLength,
crossAttentionMask, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, mInputTokenExtraIds,
mNumReturnSequences, std::nullopt, skipCrossAttnBlocks);
}

View File

@ -57,6 +57,8 @@ public:
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@ -76,8 +78,9 @@ public:
samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList,
positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
promptEmbeddingTable, promptVocabSize, loraTaskId, loraWeights, loraConfig, lookaheadConfig,
kvCacheRetentionConfig, returnLogProbs, returnContextLogits, returnGenerationLogits,
promptEmbeddingTable, promptVocabSize, mropeRotarySinCos, mropePositionDeltas, loraTaskId, loraWeights,
loraConfig, lookaheadConfig, kvCacheRetentionConfig, returnLogProbs, returnContextLogits,
returnGenerationLogits,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
: std::make_shared<VecTokens>(),
draftLogits, excludeInputFromOutput, logitsPostProcessor, applyLogitsPostProcessorBatched,

View File

@ -245,6 +245,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_property("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize,
&tr::ModelConfig::setMaxPromptEmbeddingTableSize)
.def_property_readonly("use_prompt_tuning", &tr::ModelConfig::usePromptTuning)
.def_property_readonly("use_mrope", &tr::ModelConfig::useMrope)
.def_property("use_lora_plugin", py::overload_cast<>(&tr::ModelConfig::useLoraPlugin, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::useLoraPlugin))
.def_property("compute_context_logits", py::overload_cast<>(&tr::ModelConfig::computeContextLogits, py::const_),

View File

@ -314,6 +314,11 @@ void InitBindings(pybind11::module_& m)
.def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable)
.def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds);
py::class_<tle::MropeConfig>(m, "MropeConfig")
.def(py::init<Tensor, SizeType32>(), py::arg("mrope_rotary_sin_cos"), py::arg("mrope_position_deltas"))
.def_property_readonly("mrope_rotary_sin_cos", &tle::MropeConfig::getMRopeRotarySinCos)
.def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas);
py::class_<tle::LoraConfig>(m, "LoraConfig")
.def(py::init<uint64_t, std::optional<Tensor>, std::optional<Tensor>>(), py::arg("task_id"),
py::arg("weights") = py::none(), py::arg("config") = py::none())
@ -391,7 +396,8 @@ void InitBindings(pybind11::module_& m)
std::optional<std::list<tle::VecTokens>> badWords,
std::optional<std::list<tle::VecTokens>> stopWords, std::optional<tle::Tensor> embeddingBias,
std::optional<tle::ExternalDraftTokensConfig> externalDraftTokensConfig,
std::optional<tle::PromptTuningConfig> pTuningConfig, std::optional<tle::LoraConfig> loraConfig,
std::optional<tle::PromptTuningConfig> pTuningConfig, std::optional<tle::MropeConfig> mRopeConfig,
std::optional<tle::LoraConfig> loraConfig,
std::optional<tle::LookaheadDecodingConfig> lookaheadConfig,
std::optional<tle::KvCacheRetentionConfig> kvCacheRetentionConfig,
std::optional<std::string> logitsPostProcessorName,
@ -413,10 +419,10 @@ void InitBindings(pybind11::module_& m)
TLLM_CHECK_WITH_INFO(maxTokens.has_value(), "missing required argument max_tokens");
return std::make_unique<tle::Request>(inputTokenIds, maxTokens.value(), streaming, samplingConfig,
outputConfig, endId, padId, positionIds, badWords, stopWords, embeddingBias,
externalDraftTokensConfig, pTuningConfig, loraConfig, lookaheadConfig, kvCacheRetentionConfig,
logitsPostProcessorName, encoderInputTokenIds, clientId, returnAllGeneratedTokens, priority,
type, contextPhaseParams, encoderInputFeatures, encoderOutputLength, crossAttentionMask,
numReturnSequences, eagleConfig, skipCrossAttnBlocks);
externalDraftTokensConfig, pTuningConfig, mRopeConfig, loraConfig, lookaheadConfig,
kvCacheRetentionConfig, logitsPostProcessorName, encoderInputTokenIds, clientId,
returnAllGeneratedTokens, priority, type, contextPhaseParams, encoderInputFeatures,
encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks);
}),
py::arg("input_token_ids"), py::kw_only(), py::arg("max_tokens") = py::none(),
py::arg("max_new_tokens") = py::none(), py::arg("streaming") = false,
@ -425,10 +431,11 @@ void InitBindings(pybind11::module_& m)
py::arg("pad_id") = py::none(), py::arg("position_ids") = py::none(), py::arg("bad_words") = py::none(),
py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(),
py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(),
py::arg("lora_config") = py::none(), py::arg("lookahead_config") = py::none(),
py::arg("kv_cache_retention_config") = py::none(), py::arg("logits_post_processor_name") = py::none(),
py::arg("encoder_input_token_ids") = py::none(), py::arg("client_id") = py::none(),
py::arg("return_all_generated_tokens") = false, py::arg("priority") = tle::Request::kDefaultPriority,
py::arg("mrope_config") = py::none(), py::arg("lora_config") = py::none(),
py::arg("lookahead_config") = py::none(), py::arg("kv_cache_retention_config") = py::none(),
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false,
py::arg("priority") = tle::Request::kDefaultPriority,
py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
"RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"),
py::arg("context_phase_params") = py::none(), py::arg("encoder_input_features") = py::none(),
@ -451,6 +458,7 @@ void InitBindings(pybind11::module_& m)
&tle::Request::setExternalDraftTokensConfig)
.def_property(
"prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig)
.def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig)
.def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig)
.def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig)
.def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig,

View File

@ -69,6 +69,8 @@ void EagleBuffers::Inputs::create(SizeType32 maxNumSequences, TllmRuntime const&
= manager.pinnedPool(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
eagleNetGenPastKeyValueLengthsHost
= manager.pinnedPool(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
inputGenTokensHost
= manager.pinnedPool(ITensor::makeShape({maxNumSequences * maxDecodingTokens}), nvinfer1::DataType::kINT32);
}
EagleBuffers::EagleBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager,
@ -119,6 +121,7 @@ EagleBuffers::EagleBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, run
= manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
engineInputs.eagleNetGenPastKeyValueLengthsHost
= manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
engineInputs.inputGenTokensHost = manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
// output tensors
engineOutputs.nextDraftTokens
@ -187,6 +190,7 @@ void EagleBuffers::reshape(
engineInputs.eagleNetGenRequestTypesHost->reshape(ITensor::makeShape({numSequences}));
engineInputs.eagleNetGenContextLengthsHost->reshape(ITensor::makeShape({numSequences}));
engineInputs.eagleNetGenPastKeyValueLengthsHost->reshape(ITensor::makeShape({numSequences}));
engineInputs.inputGenTokensHost->reshape(ITensor::makeShape({numSequences * maxDecodingTokens}));
cumSumGenerationLengths->reshape(ITensor::makeShape({numSequences + 1}));
@ -260,6 +264,7 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe
// Pack host data.
SizeType32 maxGenerationLengthHostValue{-1};
SizeType32 numGenerationTokens{0};
for (SizeType32 bi = 0; bi < params.batchSize; ++bi)
{
auto const batchSlot = params.batchSlots[bi];
@ -276,8 +281,10 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe
bufferCast<SizeType32>(*engineInputs.eagleNetGenPastKeyValueLengthsHost)[bi]
= bufferCast<SizeType32>(*draftBuffers.eagleNetGenPastKeyValueLengthsHost)[batchSlot];
maxGenerationLengthHostValue = std::max(maxGenerationLengthHostValue,
bufferCast<SizeType32>(*draftBuffers.specDecodingGenerationLengthsHost)[batchSlot]);
auto const generationLength
= bufferCast<SizeType32>(*draftBuffers.specDecodingGenerationLengthsHost)[batchSlot];
maxGenerationLengthHostValue = std::max(maxGenerationLengthHostValue, generationLength);
numGenerationTokens += generationLength;
}
if (maxGenerationLengthHostValue <= 0)
@ -289,6 +296,10 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe
specDecodingPositionOffsetsShape.d[1] = maxGenerationLengthHostValue;
engineInputs.specDecodingPositionOffsets->reshape(specDecodingPositionOffsetsShape);
auto inputGenTokensHostShape = engineInputs.inputGenTokensHost->getShape();
inputGenTokensHostShape.d[0] = numGenerationTokens;
engineInputs.inputGenTokensHost->reshape(inputGenTokensHostShape);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
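
The loop above now accumulates the per-request generation lengths so the new `input_gen_tokens` buffer can be reshaped to exactly the total number of generation tokens, while still tracking the maximum length for the spec-decoding tensors. The bookkeeping in isolation (names are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct GenTokenStats
{
    std::int32_t maxGenerationLength; // widest request; sizes the position-offset / packed-mask tensors
    std::int32_t numGenerationTokens; // running total; sizes the input_gen_tokens buffer
};

GenTokenStats accumulateGenerationLengths(std::vector<std::int32_t> const& generationLengths)
{
    GenTokenStats stats{-1, 0};
    for (auto const len : generationLengths)
    {
        stats.maxGenerationLength = std::max(stats.maxGenerationLength, len);
        stats.numGenerationTokens += len;
    }
    return stats;
}
```
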
@ -350,6 +361,7 @@ void EagleBuffers::insertInputTensors(
inputBuffers.insert_or_assign("host_gen_eagle_net_context_lengths", engineInputs.eagleNetGenContextLengthsHost);
inputBuffers.insert_or_assign(
"host_gen_eagle_net_past_key_value_lengths", engineInputs.eagleNetGenPastKeyValueLengthsHost);
inputBuffers.insert_or_assign("input_gen_tokens", engineInputs.inputGenTokensHost);
// outputs
outputBuffers.insert_or_assign("next_draft_tokens", engineOutputs.nextDraftTokens);

View File

@ -28,15 +28,16 @@ class EagleModule : public SpeculativeDecodingModule
public:
// Number of paths is maxDecodingTokens = maxDecodingDraftTokens + 1 to account for very flat trees with
// depth 1.
explicit EagleModule(
SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens, SizeType32 numTransformersLayer) noexcept
explicit EagleModule(SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens, SizeType32 numTransformersLayer,
SizeType32 maxNonLeafNodesPerLayer) noexcept
: SpeculativeDecodingModule(maxDraftPathLen, maxDecodingDraftTokens, maxDecodingDraftTokens + 1)
, mNumTransformersLayer(numTransformersLayer)
, mMaxNonLeafNodesPerLayer(maxNonLeafNodesPerLayer)
{
}
explicit EagleModule() noexcept
: EagleModule(0, 0, 0)
: EagleModule(0, 0, 0, 0)
{
}
@ -50,8 +51,14 @@ public:
return mNumTransformersLayer;
}
[[nodiscard]] SizeType32 getMaxNonLeafNodesPerLayer() const noexcept
{
return mMaxNonLeafNodesPerLayer;
}
private:
SizeType32 mNumTransformersLayer;
SizeType32 mMaxNonLeafNodesPerLayer;
// We use mc_sim_7b_63 from the official Medusa implementation, i.e. one of the best trees with 63 nodes found for the 7B
// Vicuna model. We use it as default if no other trees are specified per request or on the server level.
