diff --git a/.gitattributes b/.gitattributes
index e72ba0fe7b..7486041ffd 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -7,3 +7,5 @@
triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
*cubin.cpp filter=lfs diff=lfs merge=lfs -text
docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text
+tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text
+tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index f6625a0559..745713e581 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ TensorRT-LLM
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
-[](./tensorrt_llm/version.py)
+[](./tensorrt_llm/version.py)
[](./LICENSE)
[Architecture](./docs/source/torch/arch_overview.md) | [Performance](./docs/source/performance/perf-overview.md) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](./docs/source/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -18,10 +18,9 @@ TensorRT-LLM
## Tech Blogs
-* [08/06] Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM
+* [08/05] Running a High-Performance GPT-OSS-120B Inference Server with TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)
-
* [08/01] Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization)
✨ [➡️ link](./docs/source/blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.md)
@@ -44,6 +43,7 @@ TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md)
## Latest News
+* [08/05] 🌟 TensorRT-LLM delivers Day-0 support for OpenAI's latest open-weights models: GPT-OSS-120B [➡️ link](https://huggingface.co/openai/gpt-oss-120b) and GPT-OSS-20B [➡️ link](https://huggingface.co/openai/gpt-oss-20b)
* [07/15] 🌟 TensorRT-LLM delivers Day-0 support for LG AI Research's latest model, EXAONE 4.0 [➡️ link](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B)
* [06/17] Join NVIDIA and DeepInfra for a developer meetup on June 26 ✨ [➡️ link](https://events.nvidia.com/scaletheunscalablenextgenai)
* [05/22] Blackwell Breaks the 1,000 TPS/User Barrier With Meta’s Llama 4 Maverick
@@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
## Useful Links
- [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
- [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
-- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
+- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
- [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 4f9d5f0f22..198fd22a71 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -69,7 +69,7 @@ add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
add_compile_definitions("TLLM_ENABLE_CUDA")
set(BINDING_TYPE
- "pybind"
+ "nanobind"
CACHE STRING
"Binding type of Python bindings for C++ runtime and batch manager")
diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
index 394f7fb7bf..0978905b5e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
+++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
@@ -24,7 +24,6 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
-#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/worldConfig.h"
namespace tensorrt_llm::runtime
@@ -88,37 +87,6 @@ public:
SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const;
private:
- //! @brief Setups decoder internal tensors for new speculative decoding request
- static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
- DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
- CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
- SizeType32 maxDecodingEngineTokens);
-
- //! @brief Setups decoder internal tensors for new request in Draft model Sps mode
- static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);
-
- //! @brief Setups decoder internal tensors for new Medusa request
- static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);
-
- //! @brief Setups decoder internal tensors for new Lookahead request
- static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
- //! @brief Setups decoder internal tensors for new Explicit draft tokens request
- static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
- //! @brief Setups decoder internal tensors for new Eagle request
- static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
- [[nodiscard]] std::shared_ptr retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
- runtime::WorldConfig const& worldConfig, std::shared_ptr const& tensor,
- runtime::BufferManager const& bufferManager) const;
-
bool mSpeculativeDecodingFastLogits;
bool mIsLeaderInOrchMode;
bool mIsNormalizeLogProbs;
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index e4d13c9e17..f069e3ac7f 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1110,7 +1110,7 @@ public:
[[nodiscard]] SizeType32 getNumDraftTokens() const
{
- return mDraftTokens->size();
+ return hasDraftTokens() ? mDraftTokens->size() : 0;
}
void discardDraftTokens(SizeType32 numTokensToDiscard)
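Note on the llmRequest.h change above: the draft-token container is a shared pointer that may be unset or empty for ordinary (non-speculative) requests, so routing the count through hasDraftTokens() keeps getNumDraftTokens() well defined in both cases. A minimal standalone sketch of the guarded-accessor pattern (illustrative names, not the real LlmRequest members):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

using SizeType32 = std::int32_t;
using VecTokens = std::vector<SizeType32>;

struct RequestSketch
{
    std::shared_ptr<VecTokens> mDraftTokens; // may be unset for non-speculative requests

    bool hasDraftTokens() const { return mDraftTokens && !mDraftTokens->empty(); }

    // Guarded accessor: report 0 draft tokens instead of dereferencing a null pointer.
    SizeType32 getNumDraftTokens() const
    {
        return hasDraftTokens() ? static_cast<SizeType32>(mDraftTokens->size()) : 0;
    }
};
```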
diff --git a/cpp/include/tensorrt_llm/common/logger.h b/cpp/include/tensorrt_llm/common/logger.h
index df84e22638..c8164b10e5 100644
--- a/cpp/include/tensorrt_llm/common/logger.h
+++ b/cpp/include/tensorrt_llm/common/logger.h
@@ -54,20 +54,21 @@ public:
#if defined(_MSC_VER)
template <typename... Args>
- void log(Level level, char const* format, Args const&... args);
+ void log(Level const level, char const* format, Args const&... args);
template <typename... Args>
- void log(Level level, int rank, char const* format, Args const&... args);
+ void log(Level const level, int const rank, char const* format, Args const&... args);
#else
template <typename... Args>
- void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
+ void log(Level const level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
template <typename... Args>
- void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
+ void log(Level const level, int const rank, char const* format, Args const&... args)
+ __attribute__((format(printf, 4, 0)));
#endif
template <typename... Args>
- void log(Level level, std::string const& format, Args const&... args)
+ void log(Level const level, std::string const& format, Args const&... args)
{
return log(level, format.c_str(), args...);
}
@@ -134,7 +135,7 @@ private:
};
template <typename... Args>
-void Logger::log(Logger::Level level, char const* format, Args const&... args)
+void Logger::log(Logger::Level const level, char const* format, Args const&... args)
{
if (isEnabled(level))
{
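A side note on the format-attribute indices touched above: for a non-static member function the implicit this pointer counts as argument 1, so the format string of log(Level, char const*, Args...) sits at position 3 and the rank overload's at position 4, while the trailing 0 tells GCC/Clang not to type-check the forwarded arguments (they are a template parameter pack rather than C varargs). A minimal sketch of the same idiom, assuming a GCC/Clang build:

```cpp
#include <cstdio>

class MiniLogger
{
public:
    // 'this' is argument 1, 'level' is argument 2, 'fmt' is argument 3;
    // the trailing 0 disables checking of the forwarded arguments.
    template <typename... Args>
    void log(int level, char const* fmt, Args const&... args) __attribute__((format(printf, 3, 0)));
};

template <typename... Args>
void MiniLogger::log(int level, char const* fmt, Args const&... args)
{
    if (level >= 0)
    {
        std::printf(fmt, args...); // forward printf-style arguments
    }
}

int main()
{
    MiniLogger{}.log(1, "decoded %d tokens on rank %d\n", 128, 0);
    return 0;
}
```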
diff --git a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
index 2a2f1f8369..98b26a276c 100644
--- a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
+++ b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
@@ -52,29 +52,30 @@ public:
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
: mModelConfig(std::move(modelConfig))
, mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
- worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
+ worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
+ worldConfig.getTensorParallelism()}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}
CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
- SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
- AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
- int DPrank = 0, int DPsize = 0)
+ SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+ nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+ bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
- , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+ , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}
CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
- SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
- AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
- int DPrank = 0, int DPsize = 0)
+ SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+ nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+ bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::vector<SizeType32>(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
- , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+ , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
@@ -83,7 +84,7 @@ public:
[[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
{
return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
- && mDataType == other.mDataType;
+ && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
}
struct ModelConfig
@@ -103,6 +104,7 @@ public:
{
SizeType32 mTensorParallelism;
SizeType32 mPipelineParallelism;
+ SizeType32 mContextParallelism;
bool mEnableAttentionDP;
SizeType32 mDPrank;
SizeType32 mDPsize;
@@ -110,8 +112,8 @@ public:
[[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
{
return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
- && mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
- && mDPsize == other.mDPsize;
+ && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
+ && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
}
};
@@ -125,6 +127,11 @@ public:
{
}
+ [[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
+ {
+ return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
+ }
+
// attentionType ;
AttentionType mAttentionType;
int mKvFactor;
@@ -162,6 +169,7 @@ public:
sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
+ sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
sstring << "datatype:" << static_cast(mDataType) << "\n";
sstring << "attentionType:" << static_cast(mAttentionConfig.mAttentionType) << "\n";
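The point of threading mContextParallelism and the new AttentionConfig comparison through operator== above is that CacheState equality is used when deciding whether two endpoints' KV-cache layouts are compatible; a field that is stored but not compared would let configurations that differ only in context parallelism or attention type pass as interchangeable. A stripped-down sketch of the rule (equality must cover every field that affects compatibility), using illustrative types rather than the real CacheState:

```cpp
#include <cassert>
#include <cstdint>

using SizeType32 = std::int32_t;

struct ParallelConfigSketch
{
    SizeType32 tp;
    SizeType32 pp;
    SizeType32 cp; // newly tracked context parallelism
    bool enableAttentionDP;
    SizeType32 dpRank;
    SizeType32 dpSize;

    bool operator==(ParallelConfigSketch const& other) const noexcept
    {
        // Every field participates; omitting cp would make configs that differ
        // only in context parallelism compare equal.
        return tp == other.tp && pp == other.pp && cp == other.cp
            && enableAttentionDP == other.enableAttentionDP && dpRank == other.dpRank && dpSize == other.dpSize;
    }
};

int main()
{
    ParallelConfigSketch a{8, 1, 1, false, 0, 0};
    ParallelConfigSketch b{8, 1, 2, false, 0, 0}; // differs only in cp
    assert(!(a == b));
    return 0;
}
```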
diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h
index deeb0fa0af..4344f423ac 100644
--- a/cpp/include/tensorrt_llm/runtime/decodingInput.h
+++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -102,11 +102,13 @@ public:
{
public:
TensorPtr draftLogits;
+ TensorPtr draftLogitsHost;
TensorPtr draftProbs;
TensorPtr targetProbs;
TensorPtr numDraftTokens;
TensorPtr numDraftTokensHost;
TensorPtr draftTokenIds;
+ TensorPtr draftTokenIdsHost;
TensorPtr useDraftLogits;
TensorPtr useDraftLogitsHost;
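The two new ...Host mirrors added above exist so that externally supplied draft logits and token ids can be staged in pinned host memory and then copied onto the device buffers on the decoder stream (see the newRequestDraftTokensExternal changes later in this diff). A minimal CUDA sketch of that stage-through-pinned-host pattern, independent of the ITensor/BufferManager machinery:

```cpp
#include <algorithm>
#include <cuda_runtime.h>
#include <vector>

int main()
{
    std::vector<int> draftTokens{11, 22, 33, 44}; // pageable source, e.g. taken from the request
    int const numDraftTokens = static_cast<int>(draftTokens.size());

    // Stage in pinned host memory so the host-to-device copy can run asynchronously on a stream.
    int* draftTokenIdsHost = nullptr;
    cudaMallocHost(&draftTokenIdsHost, numDraftTokens * sizeof(int));
    std::copy(draftTokens.begin(), draftTokens.end(), draftTokenIdsHost);

    int* draftTokenIdsDevice = nullptr;
    cudaMalloc(&draftTokenIdsDevice, numDraftTokens * sizeof(int));

    cudaStream_t decoderStream = nullptr;
    cudaStreamCreate(&decoderStream);
    cudaMemcpyAsync(draftTokenIdsDevice, draftTokenIdsHost, numDraftTokens * sizeof(int),
        cudaMemcpyHostToDevice, decoderStream);
    cudaStreamSynchronize(decoderStream);

    cudaFree(draftTokenIdsDevice);
    cudaFreeHost(draftTokenIdsHost);
    cudaStreamDestroy(decoderStream);
    return 0;
}
```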
diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h
deleted file mode 100644
index e8f851b7d7..0000000000
--- a/cpp/include/tensorrt_llm/runtime/request.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include "tensorrt_llm/executor/executor.h"
-#include "tensorrt_llm/runtime/iTensor.h"
-
-#include
-
-namespace tensorrt_llm::runtime::decoder_batch
-{
-
-class Request
-{
-public:
- using TensorConstPtr = ITensor::SharedConstPtr;
- using TensorPtr = ITensor::SharedPtr;
- using BufferPtr = IBuffer::SharedPtr;
-
- explicit Request(SizeType32 inputLen)
- : inputLen(inputLen)
- {
- }
-
- //! Mandatory parameters
- SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps
-
- // optional parameters
- SizeType32 generatedTokensPerEngineStep{1}; //
-
- //! Optional parameters for speculative decoding
- BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu
- std::optional draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu
- TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu
- TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu
- std::optional lookaheadRuntimeConfig;
- std::optional eagleConfig;
-};
-
-} // namespace tensorrt_llm::runtime::decoder_batch
diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu
index 9a438df9a2..da44fba60c 100644
--- a/cpp/kernels/xqa/mha_sm90.cu
+++ b/cpp/kernels/xqa/mha_sm90.cu
@@ -1012,7 +1012,7 @@ CUBIN_EXPORT __global__
if (threadIdx.x < smem.gemm1AccColMax.size)
{
auto const idx = threadIdx.x;
- smem.gemm1AccColMax[idx] = mha::numeric_limits<float>::lowest();
+ smem.gemm1AccColMax[idx] = safeInitRowMax;
smem.gemm1AccColSum[idx] = 0;
}
smem.gemm1WarpGrpBar.arrive_and_wait();
@@ -1949,7 +1949,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
uint32_t const globalRow = tileStartRow + row;
if (globalRow >= cacheSeqLen)
{
- acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+ acc(m, n)(i, j) = safeInitRowMax;
continue;
}
if (globalRow >= maskStartRow)
@@ -1957,7 +1957,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
uint32_t const maskRow = globalRow - maskStartRow;
if ((bit_mask >> maskRow) == 0)
{
- acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+ acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@@ -2087,7 +2087,7 @@ __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
- acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+ acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@@ -2380,9 +2380,9 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
{
uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j;
assert((col < nbValidCols) == bool(endMask & (1ULL << col)));
- if (((mask >> col) & 1) == 0)
+ if ((mask & (1ULL << col)) == 0)
{
- acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+ acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@@ -2410,7 +2410,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uin
#pragma unroll
for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
{
- acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+ acc(m, n)(i, j) = safeInitRowMax;
}
}
}
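Two details of the mha_sm90.cu hunks are worth spelling out. The mask test now forms the probe bit as 1ULL << col, so the comparison is carried out at 64-bit width even for column indices of 32 and above; whether the previous (mask >> col) & 1 form was actually unsafe depends on the declared width of mask, which these hunks do not show. The masked and initial row-max entries also switch from numeric_limits lowest() to a safeInitRowMax sentinel; one plausible motivation (an assumption, not stated in the diff) is that the most negative finite float overflows to -inf as soon as it is scaled, and -inf values can turn later subtract/exp steps into NaNs. A small host-side sketch of both effects:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    // Bit test: build the probe bit at 64-bit width so columns >= 32 stay well defined.
    std::uint64_t const mask = 0xF0F0F0F0F0F0F0F0ull;
    for (std::uint32_t col : {0u, 7u, 36u, 63u})
    {
        bool const maskedOut = (mask & (1ULL << col)) == 0;
        std::printf("col %2u -> %s\n", col, maskedOut ? "masked" : "kept");
    }

    // Finite sentinel vs. numeric_limits::lowest(): scaling the most negative
    // finite float overflows to -inf, and -inf can propagate NaN downstream.
    float const lowest = std::numeric_limits<float>::lowest();
    float const scaled = lowest * 1.25f;  // overflows to -inf
    float const diff = scaled - scaled;   // -inf - (-inf) == NaN
    float const safeInitRowMax = -1e30f;  // illustrative finite sentinel, not the kernel's constant
    std::printf("scaled lowest isinf: %d, diff isnan: %d, exp(safe-safe)=%f\n",
        std::isinf(scaled), std::isnan(diff), std::exp(safeInitRowMax - safeInitRowMax));
    return 0;
}
```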
diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h
index 2559ae5484..36cbe76544 100644
--- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h
+++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h
@@ -833,7 +833,7 @@ public:
// Runs for 3 iterations or 1 second and picks the best option
int pickBestTactic(MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
{
- auto tactics = mMoERunner.getTactics();
+ auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile));
::nvtx3::scoped_range nvtx(tensorrt_llm::common::nvtx::nextColor(),
"Tactic Profiling GEMM " + std::to_string(static_cast(gemm_to_profile)));
// We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we
@@ -925,12 +925,14 @@ public:
std::pair setTactic(
int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
{
- auto tactics = mMoERunner.getTactics();
+ auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
+ auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
std::vector, GemmToProfile>> tactics_to_profile{
{tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}};
for (auto& combo : tactics_to_profile)
{
auto& t = combo.first.get();
+ auto& tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2;
if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER)
{
t = 0; // Unneeded tactic, set to 0
@@ -947,7 +949,7 @@ public:
}
}
- mMoERunner.setTactic(tactics[tactic_idx1], tactics[tactic_idx2]);
+ mMoERunner.setTactic(tactics1[tactic_idx1], tactics2[tactic_idx2]);
mBestTacticGemm1 = tactic_idx1;
mBestTacticGemm2 = tactic_idx2;
return {tactic_idx1, tactic_idx2};
@@ -965,7 +967,7 @@ public:
auto expert_weights_size
= gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size;
- auto tactics = mMoERunner.getTactics()[tactic_idx];
+ auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile))[tactic_idx];
if (static_cast(gemm_to_profile) != static_cast(mGemmProfilerBackend.mGemmToProfile))
{
throw std::runtime_error("Configuration mismatch between mGemmProfilerBackend and runMoEPermute");
@@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark::runBenchmark(benchmark::State& state
}
if (LOG_LEVEL >= INFO)
{
- auto tactics = mMoERunner.getTactics();
- std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics.size() << "\n"
- << tactics[tactic_idx1].toString() << std::endl;
- std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics.size() << "\n"
- << tactics[tactic_idx2].toString() << std::endl;
+ auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
+ auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
+ std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n"
+ << tactics1[tactic_idx1].toString() << std::endl;
+ std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n"
+ << tactics2[tactic_idx2].toString() << std::endl;
}
state.counters["tactic_idx1"] = tactic_idx1;
state.counters["tactic_idx2"] = tactic_idx2;
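All of the fixture edits above follow from one API change: getTactics is now queried per GEMM, and the GEMM_1 and GEMM_2 lists can differ in length, so an index that is valid for one list may be out of range, or mean a different config, in the other. A small sketch of the resulting pattern of keeping the two lists and their indices separate (plain std::vector and strings stand in for the real tactic types):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Stand-ins for the real per-GEMM tactic lists, which may differ in size.
std::vector<std::string> getTacticsSketch(int gemmId)
{
    return gemmId == 1 ? std::vector<std::string>{"cfgA", "cfgB", "cfgC"}
                       : std::vector<std::string>{"cfgX", "cfgY"};
}

int main()
{
    auto const tactics1 = getTacticsSketch(1);
    auto const tactics2 = getTacticsSketch(2);

    std::size_t const idx1 = 2; // valid for GEMM_1 only
    std::size_t const idx2 = 1;

    // Each index is validated against, and resolved from, its own list.
    assert(idx1 < tactics1.size() && idx2 < tactics2.size());
    auto const& chosen1 = tactics1[idx1];
    auto const& chosen2 = tactics2[idx2];
    (void) chosen1;
    (void) chosen2;
    return 0;
}
```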
diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu
index b784c6d0bc..8e18694ad7 100644
--- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu
+++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu
@@ -42,148 +42,15 @@ struct WeightParams
->Apply(argGen>>)
template
-auto listAllTactics()
+auto listAllTactics(MoeGemmId gemm_id)
{
int const sm = getSMVersion();
using RunnerType = decltype(BenchClass::mMoERunner);
- return RunnerType::getTactics(sm);
+ return RunnerType::getTactics(sm, gemm_id);
}
template
-int parseTacticToId(nlohmann::json tactic_config)
-{
- bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get();
- int tile_shape_id = -1;
- std::array tile_shape;
- if (tactic_config.at("tile_shape").is_array())
- tactic_config.at("tile_shape").get_to(tile_shape);
- else
- tile_shape_id = tactic_config.at("tile_shape").get();
-
- std::vector confs = listAllTactics();
-
- try
- {
- for (int i = 0; i < confs.size(); i++)
- {
- auto const& c = confs[i];
- if (c.is_tma_warp_specialized != is_tma_warp_specialized)
- continue;
-
- if (!is_tma_warp_specialized)
- {
- int stages = tactic_config.at("stages").get();
- if (c.stages != stages)
- continue;
- }
-
- if (tile_shape_id != -1)
- {
- int comp = c.getTileConfigAsInt();
- if (tile_shape_id != comp)
- continue;
- if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get())
- continue;
-
- // Found matching config
- return i;
- }
-
- // Handle if the user provided a shape instead of the enum value
- if (is_tma_warp_specialized)
- {
- // TODO Add cases for blackwell shapes
- using Kv = uint64_t;
- constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
- static std::unordered_map const tile_map{
- {K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
- {K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
- {K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
- {K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
- {K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
-
- {K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
- {K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
- {K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
- {K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
- {K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
- {K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
- };
-
- if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1])))
- continue;
-
- static std::unordered_map const cluster_map{
- // CTA configs for M=64
- {K(1, 1), ClusterShape::ClusterShape_1x1x1},
- {K(2, 1), ClusterShape::ClusterShape_2x1x1},
- {K(1, 2), ClusterShape::ClusterShape_1x2x1},
- {K(2, 2), ClusterShape::ClusterShape_2x2x1},
- };
-
- std::array cluster_shape;
- tactic_config.at("cluster_shape").get_to(cluster_shape);
-
- if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
- continue;
-
- // Found matching config
- return i;
- }
- else
- {
- std::array warp_shape;
- tactic_config.at("warp_shape").get_to(warp_shape);
-
- using Kv = uint64_t;
- constexpr static auto K = [](std::array a, std::array b)
- {
- uint64_t sum = 0;
- for (auto v : a)
- sum = sum * 512 + v;
- for (auto v : b)
- sum = sum * 256 + v;
- return sum;
- };
- static std::unordered_map tile_map{
- {K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
-
- {K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
- {K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
-
- {K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
- {K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
- {K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
-
- {K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
- {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
- {K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
- {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
- {K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
-
- {K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
-
- {K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64}
-
- };
- if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape)))
- continue;
-
- // Found matching config
- return i;
- }
- }
- }
- catch (std::out_of_range const& e)
- {
- std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
- }
-
- return -1;
-}
-
-template
-void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids)
+void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids, MoeGemmId gemm_id)
{
if (tactic.is_number_integer())
{
@@ -193,20 +60,16 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids)
{
for (auto c : tactic)
{
- parseTacticToVectorID(c, tactic_ids);
+ parseTacticToVectorID(c, tactic_ids, gemm_id);
}
}
- else if (tactic.is_object())
- {
- tactic_ids.push_back(parseTacticToId(tactic));
- }
else if (tactic.is_string())
{
assert(tactic.is_string());
auto tactic_name = tactic.get();
if (tactic_name == "all")
{
- auto all_tactics = listAllTactics();
+ auto all_tactics = listAllTactics(gemm_id);
tactic_ids.resize(all_tactics.size());
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
}
@@ -410,39 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
}
// Do this after filtering datatypes as tactics only make sense if we know the data type
- bool has_tactic_ids2 = false;
std::vector tactic_ids1{};
std::vector tactic_ids2{};
- if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2"))
+ if (run_config.contains("tactic_id1"))
{
- if (run_config.contains("tactic_id"))
- {
- throw std::invalid_argument("Cannot use tactic_id and tactic_idX");
- }
- has_tactic_ids2 = true;
- parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1);
- parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2);
+ parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1);
}
- else
+ if (run_config.contains("tactic_id2"))
{
- parseTacticToVectorID(run_config["tactic_id"], tactic_ids1);
- has_tactic_ids2 = false;
- tactic_ids2.resize(1); // Dummy value so we loop exactly once below
- }
- if (tactic_ids1.empty() || tactic_ids2.empty())
- {
- std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
- static bool printed = false;
- if (!printed)
- {
- printed = true;
- std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
- auto confs = listAllTactics();
- for (auto c : confs)
- std::cerr << c.toString();
- }
-
- continue;
+ parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2);
}
auto get_or = [&](auto name, auto def)
@@ -478,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
}
else if (gemm_to_profile == (int) GemmToProfile::GEMM_2)
{
- if (!has_tactic_ids2)
- tactic_ids2 = std::move(tactic_ids1);
tactic_ids1 = {-1};
}
}
@@ -494,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
return val;
};
+ if (tactic_ids1.empty() || tactic_ids2.empty())
+ {
+ std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
+ static bool printed = false;
+ if (!printed)
+ {
+ printed = true;
+ std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
+ for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
+ {
+ std::cerr << "GEMM " << (int) gemm_id << ":\n";
+ auto confs = listAllTactics(gemm_id);
+ for (auto c : confs)
+ std::cerr << c.toString();
+ std::cerr << std::endl;
+ }
+ }
+
+ continue;
+ }
+
for (auto t1 : tactic_ids1)
{
- // tactic_ids2 will have one dummy value if has_tactic_ids2 = false
for (auto t2 : tactic_ids2)
{
- if (!has_tactic_ids2)
- t2 = t1;
-
benchmark->Args({num_experts, //
get_range("k"), //
get_range("hidden_size"), //
@@ -531,7 +385,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
// {ActivationType::Relu, ActivationType::Gelu,
// ActivationType::Silu, ActivationType::Geglu,
// ActivationType::Swiglu};
- auto cutlass_tactic = {-1}; // {0,..., listAllTactics().size()};
+ auto cutlass_tactic = {-1}; // {0,..., listAllTactics(MoeGemmId).size()};
auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};
for (auto num_expert : num_experts)
@@ -558,14 +412,18 @@ void argGen(benchmark::internal::Benchmark* benchmark)
{
if (LOG_LEVEL >= VERBOSE)
{
- std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
- int i = 0;
- for (auto& t : listAllTactics())
+ std::cout << "== List of all tactics for dtype " << (int) BenchClass::toDTypeID() << " ==\n";
+ for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
{
- std::cout << "Tactic " << i << ":\n";
- std::cout << t.toString() << std::endl;
+ int i = 0;
+ std::cout << "=== GEMM " << (int) gemm_id << " ===\n";
+ for (auto& t : listAllTactics(gemm_id))
+ {
+ std::cout << "==== Tactic " << i << " ====\n";
+ std::cout << t.toString() << std::endl;
- i++;
+ i++;
+ }
}
}
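The help() changes in the next hunk reduce the tactic options to the per-GEMM fields tactic_id1 and tactic_id2. To make that concrete, the following sketch builds one hypothetical run-config entry; only "act_fn", "tactic_id1", "tactic_id2" and "dtypes" come from the help text in this diff, the remaining keys and all values are assumptions about the benchmark's JSON schema rather than a documented example:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    // "tactic_id1": "all" sweeps every tactic in GEMM_1's own list;
    // "tactic_id2": [0, 1, 2] sweeps three indices of GEMM_2's separate list.
    auto const runConfig = nlohmann::json::parse(R"({
        "num_experts": 8,
        "k": 2,
        "hidden_size": 4096,
        "act_fn": 3,
        "tactic_id1": "all",
        "tactic_id2": [0, 1, 2],
        "dtypes": ["fp8"]
    })");
    std::cout << runConfig.dump(2) << std::endl;
    return 0;
}
```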
@@ -652,7 +510,6 @@ void help()
" \"bias\": int, (optional)\n"
" \"do_final_scale\": int, (optional)\n"
" \"act_fn\": int,\n"
- " \"tactic_id\": tactic, (see below)\n"
" \"tactic_id1\": tactic, (see below)\n"
" \"tactic_id2\": tactic, (see below)\n"
" \"dtypes\": [string, ...], (optional)\n"
@@ -676,27 +533,14 @@ void help()
"- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
"swiglu\n"
- "- \"tactic_id, tactic_id1, tactic_id2\"\n"
- "The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in "
- "auto mode)\n"
- "Use tactic_idX to set the tactic for the corresponding GEMM"
+ "- \"tactic_id1, tactic_id2\"\n"
+ "The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
"Valid tactics are:\n"
- " - An object:\n"
- " {\n"
- " \"is_tma_warp_specialized\": bool,\n"
- " \"tile_shape\": [int, int, int] or int,\n"
- " \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape "
- "is "
- "an int)\n"
- " \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
- " \"stages\": int, (required for non-sm90)\n"
- " },\n"
- " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test "
- "configurations\n"
- " - An array: of integers or objects, forms a list of tactics to sweep\n"
+ " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types "
+ "or GPU architectures\n"
+ " - An array: of integers, forms a list of tactics to sweep\n"
" - The string \"all\": This will sweep through all possible tactics\n"
- " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark "
- "case. "
+ " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. "
"Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate "
"results"
"- dtypes - A list of dtypes to run this config through.\n"
diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt
index c4814c1d4e..2e625f4687 100644
--- a/cpp/tensorrt_llm/CMakeLists.txt
+++ b/cpp/tensorrt_llm/CMakeLists.txt
@@ -294,8 +294,7 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
endif()
if(NOT WIN32)
- set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS
- "-Wl,-rpath='$ORIGIN'")
+ set_target_properties(${SHARED_TARGET} PROPERTIES BUILD_RPATH "$ORIGIN")
endif()
if(BUILD_PYT)
diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
index 503c2e6c5d..e73e0f1541 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA");
return false;
}
+ if (selfConfig.getParallelConfig().mContextParallelism != 1
+ || destConfig.getParallelConfig().mContextParallelism != 1)
+ {
+ TLLM_LOG_WARNING(
+ "CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
+ selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
+ return false;
+ }
std::unordered_set setVecDest{
destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
index 16771709bb..3335d69a01 100644
--- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
+++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -20,11 +20,14 @@
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/utils/logitsThread.h"
+#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
+#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
+#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager
using SizeType32 = CreateNewDecoderRequests::SizeType32;
using TensorPtr = CreateNewDecoderRequests::TensorPtr;
using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr;
+template <typename T>
+using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
namespace
{
@@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-} // namespace
-
-void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx,
- runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
- runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
- CudaStream const& runtimeStream, CudaStream const& decoderStream,
- SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens)
+void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr const& reqDraftLogits,
+ ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits,
+ bool isLeaderInOrchMode, BufferManager const& bufferManager)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- if (speculativeDecodingMode.predictsDraftTokens())
+ if (!speculativeDecodingFastLogits)
{
- auto const& stream = decoderStream;
- BufferManager manager{std::make_shared(stream.get())};
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+ bufferManager.copy(*reqDraftLogits, *draftLogitsHost);
+ return;
+ }
- auto& dJointOutput = jointDecodingOutput;
+ if (isLeaderInOrchMode)
+ {
+ // reqDraftLogits contains metadata for fast-logits path; validate size.
+ auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo);
+ TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize,
+ "Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo.");
+ te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{};
+ std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize);
+ utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype());
- TensorPtr nextDraftTokens
- = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1);
- // FIXME: can we skip this?
- manager.setZero(*nextDraftTokens);
- if (speculativeDecodingMode.variableDraftLength())
+ // Broadcast to other ranks if needed
+ if (worldConfig.isTensorParallel())
{
- TensorPtr nextDraftTokensLen
- = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1);
- manager.setZero(*nextDraftTokensLen);
+ auto const& commSession = COMM_SESSION;
+ auto shape = draftLogitsHost->getShape();
+ commSession.bcastValue(shape.d[0], 0);
+ commSession.bcastValue(shape.d[1], 0);
+ commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
}
+ else
+ {
+ TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(),
+ "Fast logits path requires tensor-parallel broadcast for non-leader ranks.");
- if (speculativeDecodingMode.isDraftTokensExternal())
- {
- newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream);
- }
- else if (speculativeDecodingMode.isMedusa())
- {
- newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens);
- }
- else if (speculativeDecodingMode.isLookaheadDecoding())
- {
- newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream);
- }
- else if (speculativeDecodingMode.isExplicitDraftTokens())
- {
- newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream);
- }
- else if (speculativeDecodingMode.isEagle())
- {
- newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream);
+ // Get logits from leader rank
+ auto const& commSession = COMM_SESSION;
+ int64_t dims[2];
+ commSession.bcastValue(dims[0], 0);
+ commSession.bcastValue(dims[1], 0);
+ draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]}));
+ commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
+
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-}
+};
-void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx,
- runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
- DecodingInput& jointDecodingInput, CudaStream const& decoderStream)
+//! @brief Sets up decoder internal tensors for new request in Draft model Sps mode
+void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq,
+ SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
+ bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- BufferManager manager{std::make_shared(decoderStream.get())};
+ BufferManager decoderBufferManager{std::make_shared(decoderStream.get())};
- auto& dJointInput = jointDecodingInput;
+ TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs);
+ auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs;
- auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
+ auto const& draftTokens = llmReq.getDraftTokens();
+ auto const numDraftTokens = numDecodingEngineTokens - 1;
- auto const useDraftLogits = request.draftLogits.has_value();
- if (useDraftLogits)
- {
- TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
-
- TensorPtr draftLogitsReqBatchSlice
- = ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
- draftLogitsReqBatchSlice->squeeze(0);
- TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
- manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
- }
- auto* useDraftLogitsHostPtr = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost);
- useDraftLogitsHostPtr[batchIdx] = useDraftLogits;
- auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
- runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
+ auto numDraftTokensHostRange = runtime::BufferRange(*externalDraftTokensInputs->numDraftTokensHost);
+ numDraftTokensHostRange[batchIdx] = numDraftTokens;
+ auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
+ runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
if (numDraftTokens > 0)
{
- TensorPtr draftTokensReqBatchSlice
- = ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
- draftTokensReqBatchSlice->squeeze(0);
- TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
- TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
- manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
+ TensorPtr draftTokenIdsHostSlice
+ = ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens);
+ // Copy to pinned host memory (don't care about stream of bufferManager)
+ decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice);
+
+ TensorPtr draftTokenIdsSlice
+ = ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens);
+ decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice);
}
- auto* numDraftTokensHostPtr
- = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->numDraftTokensHost);
- numDraftTokensHostPtr[batchIdx] = numDraftTokens;
- auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
- runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
+ auto const& draftLogits = llmReq.getDraftLogits();
+ auto const useDraftLogits = draftLogits.has_value();
+ auto useDraftLogitsHostRange = runtime::BufferRange(*externalDraftTokensInputs->useDraftLogitsHost);
+ useDraftLogitsHostRange[batchIdx] = useDraftLogits;
+ auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
+ runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
+
+ if (useDraftLogits)
+ {
+ TensorPtr draftLogitsHostSlice
+ = ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens);
+ retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig,
+ speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager);
+
+ TensorPtr draftLogitsSlice
+ = ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens);
+ decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice);
+ }
+
+ auto const& samplingConfig = llmReq.mSamplingConfig;
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
float const constantThreshold
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
- dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
- dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold;
+ externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
+ externalDraftTokensInputs->constantThreshold = constantThreshold;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens)
+//! @brief Sets up decoder internal tensors for new Medusa request
+void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq,
+ SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers,
+ CudaStream const& decoderStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+ llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs};
+ // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
+ // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
+ auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1);
+ auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1);
+
BufferManager manager{std::make_shared(decoderStream.get())};
- auto& dJointInput = jointDecodingInput;
+ auto& medusaInputs = jointDecodingInput.medusaInputs;
TensorPtr curTokensPerStepSlice
- = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
+ = ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
// Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end
// of first decoder
runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream);
TensorPtr targetTokensPerStepSlice
- = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
- auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep;
- TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens,
- "Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep,
+ = ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
+ TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens,
+ "Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens,
maxDecodingEngineTokens);
- runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream);
+ runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream);
- TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1);
- manager.copy(*request.medusaPaths, *pathsSlice);
+ TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1);
+ manager.copy(*medusaPaths, *pathsSlice);
- TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1);
- manager.copy(*request.medusaTreeIds, *treeIdsSlice);
+ TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1);
+ manager.copy(*medusaTreeIds, *treeIdsSlice);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
+//! @brief Sets up decoder internal tensors for new Lookahead request
+void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx,
+ CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.lookaheadOutputs);
+ TLLM_CHECK(jointDecodingInput.lookaheadInputs);
// The first generation step only generate 1 token.
TensorPtr curTokensPerStepSlice
@@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime:
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx,
- runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput,
- CudaStream const& runtimeStream)
+//! @brief Sets up decoder internal tensors for new Explicit draft tokens request
+void newRequestExplicitDraftTokens(
+ DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers);
+ auto const inputLen = llmReq.getPromptLen();
+
TensorPtr positionIdsBaseSlice
= ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1);
- runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream);
+ runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
- runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
+//! @brief Sets up decoder internal tensors for new Eagle request
+void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq,
+ runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig,
+ CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.eagleBuffers);
+ auto& eagleBuffers = *jointDecodingOutput.eagleBuffers;
+
+ auto const inputLen = llmReq.getPromptLen();
BufferManager manager{std::make_shared(runtimeStream.get())};
- TensorPtr eagleNetCtxRequestTypesHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxRequestTypesHost, batchIdx, 1);
+ TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetCtxContextLengthsHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxContextLengthsHost, batchIdx, 1);
+ = ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1);
TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);
+ = ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);
runtime::bufferCast(*eagleNetCtxRequestTypesHostSlice)[0] = 0;
- runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen;
- runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen;
+ runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen;
+ runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen;
- TensorPtr eagleNetGenRequestTypesHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1);
+ TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetGenContextLengthsHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenContextLengthsHost, batchIdx, 1);
+ = ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1);
TensorPtr eagleNetGenPastKeyValueLengthsHostSlice
- = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);
+ = ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);
runtime::bufferCast(*eagleNetGenRequestTypesHostSlice)[0] = 1;
- runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen;
- runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen;
+ runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = inputLen;
+ runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen;
auto const eagleModule = std::dynamic_pointer_cast(
modelConfig.getSpeculativeDecodingModulePtr());
std::optional eagleChoicesOpt;
- if (request.eagleConfig)
+ auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig();
+
+ if (eagleConfig)
{
- eagleChoicesOpt = request.eagleConfig->getEagleChoices();
+ eagleChoicesOpt = eagleConfig->getEagleChoices();
}
- if (!request.eagleConfig || !request.eagleConfig->useDynamicTree())
+ if (!eagleConfig || !eagleConfig->useDynamicTree())
{
- TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1);
- TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1);
+ TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1);
+ TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1);
// eagleConfig is nullptr or Eagle-1
std::vector topKs;
@@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
+//! @brief Sets up decoder internal tensors for new speculative decoding request
+void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
+ SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode,
+ SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens,
+ OptionalRef medusaBuffers, runtime::ModelConfig const& modelConfig,
+ WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits,
+ bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream)
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ if (speculativeDecodingMode.predictsDraftTokens())
+ {
+ BufferManager manager{std::make_shared(decoderStream.get())};
+
+ TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs);
+ auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs;
+
+ TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1);
+ // FIXME: can we skip this?
+ manager.setZero(*nextDraftTokens);
+ if (speculativeDecodingMode.variableDraftLength())
+ {
+ TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1);
+ manager.setZero(*nextDraftTokensLen);
+ }
+ }
+
+ if (speculativeDecodingMode.isDraftTokensExternal())
+ {
+ newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig,
+ worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream);
+ }
+ else if (speculativeDecodingMode.isMedusa())
+ {
+ TLLM_CHECK(medusaBuffers);
+ newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens,
+ medusaBuffers.value(), decoderStream);
+ }
+ else if (speculativeDecodingMode.isLookaheadDecoding())
+ {
+ newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream);
+ }
+ else if (speculativeDecodingMode.isExplicitDraftTokens())
+ {
+ newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream);
+ }
+ else if (speculativeDecodingMode.isEagle())
+ {
+ newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream);
+ }
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+} // namespace
+
std::tuple, std::vector>
CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
@@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
}
inputIds->resize(decoderInputSize);
- std::vector decoderRequests;
- decoderRequests.reserve(finishedContextRequests.size());
-
std::vector lookaheadPrompt;
std::vector lookaheadAlgoConfigs;
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
@@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
auto const promptLen = llmReq->getPromptLen();
- auto decoderRequest = decoder_batch::Request{promptLen};
-
+ SizeType32 numDecodingEngineTokens{1};
if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
{
- if (llmReq->hasDraftTokens())
- {
- auto const& draftTokens = llmReq->getDraftTokens();
- // Copy to pinned host memory (don't care about stream of bufferManager)
- decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
- auto const& draftLogits = llmReq->getDraftLogits();
- if (draftLogits.has_value())
- {
- decoderRequest.draftLogits
- = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager);
- }
- decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
- }
- else
- {
- decoderRequest.generatedTokensPerEngineStep = 1;
- }
+ numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1;
}
else if (!modelConfig.getSpeculativeDecodingMode().isNone())
{
- decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
+ numDecodingEngineTokens = modelConfig.getMaxDecodingTokens();
}
auto& dJointInput = decoderState.getJointDecodingInput();
- auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep;
initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens,
maxSequenceLength, decoderBufferManager);
decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens);
@@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
{
TLLM_CHECK(beamWidth == 1);
- if (modelConfig.getSpeculativeDecodingMode().isMedusa())
- {
- TLLM_CHECK(medusaBuffers);
- llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
- // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
- // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
- decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
- decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
- }
- else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
+ if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
lookaheadPrompt.emplace_back(requestIds);
@@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
= llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
}
- else if (modelConfig.getSpeculativeDecodingMode().isEagle())
- {
- decoderRequest.eagleConfig
- = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
- }
- newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig,
- decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream,
- decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
+ newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(),
+ batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens,
+ decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig,
+ mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream);
}
- decoderRequests.push_back(decoderRequest);
-
inputOffset += promptLen;
}
return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
}
-std::shared_ptr<runtime::ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig,
-    WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
- BufferManager const& bufferManager) const
-{
- TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-
- if (!mSpeculativeDecodingFastLogits)
- {
- TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
- return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL);
- }
-
- if (mIsLeaderInOrchMode)
- {
- te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo;
- std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo));
- auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value();
-
- // Broadcast to other ranks if needed
- if (worldConfig.isTensorParallel())
- {
- auto const& commSession = COMM_SESSION;
- auto shape = logits->getShape();
- commSession.bcastValue(shape.d[0], 0);
- commSession.bcastValue(shape.d[1], 0);
- commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
- }
- TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
- return logits;
- }
-
- // Get logits from leader rank
- auto const& commSession = COMM_SESSION;
- int64_t dims[2];
- commSession.bcastValue(dims[0], 0);
- commSession.bcastValue(dims[1], 0);
- auto const logitsDtype = modelConfig.getLogitsDtype();
- auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
- commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
-
- TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
- return logits;
-};
-
} // namespace tensorrt_llm::batch_manager
diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index 22756f2552..eaa2e957e8 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
return false;
}
-
- if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
- {
- TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
- return false;
- }
if (selfConfig.getParallelConfig().mEnableAttentionDP
&& (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
{
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
+ if (selfConfig.getParallelConfig().mContextParallelism != 1
+ || destConfig.getParallelConfig().mContextParallelism != 1)
+ {
+ TLLM_LOG_WARNING(
+ "MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
+ selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
+ return false;
+ }
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{
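The new guard above makes MLACacheFormatter::inquireSupport reject a transfer whenever either endpoint was built with context parallelism. A minimal, hedged sketch of the caller-side probe this implies; the formatter instance and the two CacheState configs are placeholders, not names from this patch:

    // Hypothetical support check before attempting an MLA KV-cache transfer.
    // `formatter`, `selfConfig`, and `destConfig` are assumed to exist in the caller's context.
    if (!formatter.inquireSupport(selfConfig, destConfig))
    {
        TLLM_LOG_WARNING("Skipping MLA KV-cache transfer: unsupported parallel configuration.");
        return;
    }
    // ... proceed with format()/unformat() on the TransferSession ...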
diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp
index 484cd7c3c7..7234ca9ba5 100644
--- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp
+++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp
@@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS
#endif // ENABLE_MULTI_DEVICE
}
-std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
- executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig)
+void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
+ executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype)
{
#if ENABLE_MULTI_DEVICE
auto const& worldComm = tensorrt_llm::mpi::MpiComm::world();
@@ -151,10 +151,7 @@ std::optional targetModelReceiveLogits(
int64_t dims[2];
MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status));
- auto const logitsDtype = modelConfig.getLogitsDtype();
-
- auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool(
- runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
+ draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]}));
worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status);
@@ -163,11 +160,7 @@ std::optional targetModelReceiveLogits(
    uint64_t const expectedSize = static_cast<uint64_t>(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype);
TLLM_CHECK((uint64_t) count == expectedSize);
- MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status));
-
- return tensor;
-#else
- return std::nullopt;
+ MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status));
#endif // ENABLE_MULTI_DEVICE
}
diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h
index 6d87ebee16..f19d5f5ef3 100644
--- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h
+++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h
@@ -21,10 +21,8 @@
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
-#include "tensorrt_llm/runtime/modelConfig.h"
 #include <atomic>
-#include <optional>
namespace tensorrt_llm::batch_manager
{
@@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS
std::shared_ptr const& crossKvCacheManager,
std::shared_ptr const& peftCacheManager);
-std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
- executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig);
+void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
+ executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype);
} // namespace tensorrt_llm::batch_manager::utils
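The signature change above turns targetModelReceiveLogits into a fill-in-place call: the caller owns a pinned host tensor and the function reshapes it to the received [rows, cols] before the MPI receive. A hedged sketch of the new call pattern; the initial shape is an arbitrary placeholder and the surrounding variables (modelConfig, fastLogitsInfo) are assumed from the caller's context:

    // Caller-owned, reusable pinned host buffer; targetModelReceiveLogits() reshapes it as needed.
    auto const logitsDtype = modelConfig.getLogitsDtype();
    runtime::ITensor::SharedPtr draftLogitsHost
        = tensorrt_llm::runtime::BufferManager::pinnedPool(runtime::ITensor::makeShape({1, 1}), logitsDtype);

    utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, logitsDtype);
    // draftLogitsHost now holds the draft model's logits; there is no std::optional to unwrap anymore.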
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
index 3fab43a3b4..e226c0d812 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
@@ -435,6 +435,14 @@ struct CutlassGemmConfig
int sm_version = 80; // Use 80 as a catch all for <90
bool is_tma_warp_specialized = false;
+ enum class EpilogueFusionType : int
+ {
+ NONE,
+ FINALIZE
+ };
+
+ EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE;
+
CutlassGemmConfig() = default;
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
@@ -505,7 +513,8 @@ struct CutlassGemmConfig
<< "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt()
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
- << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
+ << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false")
+ << "\n\tepilogue fusion type: " << (int) epilogue_fusion_type;
}
else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
@@ -537,7 +546,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
<< ", cluster_shape_enum: " << int(config.cluster_shape)
- << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
+ << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false")
+ << ", epilogue_fusion_type: " << int(config.epilogue_fusion_type);
}
else
{
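Because a FINALIZE-fused config is otherwise identical to its unfused twin, code that enumerates tactics can now partition or filter on the new field. A small hedged sketch, assuming a `candidates` vector produced by the usual config-generation helpers:

    #include <algorithm>
    #include <vector>

    using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;

    // Keep only the tactics whose GEMM2 epilogue fuses the finalize step (illustrative filter).
    std::vector<CutlassGemmConfig> finalizeOnly;
    std::copy_if(candidates.begin(), candidates.end(), std::back_inserter(finalizeOnly),
        [](CutlassGemmConfig const& c)
        { return c.epilogue_fusion_type == CutlassGemmConfig::EpilogueFusionType::FINALIZE; });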
diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp
index 738a095eef..bba8d19e2f 100644
--- a/cpp/tensorrt_llm/executor/serialization.cpp
+++ b/cpp/tensorrt_llm/executor/serialization.cpp
@@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
auto tokensPerBlock = su::deserialize(is);
auto tensorParallelism = su::deserialize(is);
auto pipelineParallelism = su::deserialize(is);
+ auto contextParallelism = su::deserialize(is);
auto enableAttentionDP = su::deserialize(is);
auto DPrank = su::deserialize(is);
auto DPsize = su::deserialize(is);
auto dataType = su::deserialize(is);
auto attentionType = su::deserialize(is);
auto kvFactor = su::deserialize(is);
- return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType,
- attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
+ return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
+ contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
}
void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
@@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o
su::serialize(state.mModelConfig.mTokensPerBlock, os);
su::serialize(state.mParallelConfig.mTensorParallelism, os);
su::serialize(state.mParallelConfig.mPipelineParallelism, os);
+ su::serialize(state.mParallelConfig.mContextParallelism, os);
su::serialize(state.mParallelConfig.mEnableAttentionDP, os);
su::serialize(state.mParallelConfig.mDPrank, os);
su::serialize(state.mParallelConfig.mDPsize, os);
@@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock);
totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism);
+ totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP);
totalSize += su::serializedSize(state.mParallelConfig.mDPrank);
totalSize += su::serializedSize(state.mParallelConfig.mDPsize);
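serialize, deserializeCacheState, and serializedSize must agree on field order, or every field after mContextParallelism deserializes shifted. A hedged round-trip sketch of that invariant; it assumes a populated kv_cache::CacheState named cacheState and the same member access used inside Serialization:

    #include <sstream>

    // Serialized byte count must match serializedSize(), and the new field must survive a round trip.
    std::stringstream ss;
    Serialization::serialize(cacheState, ss);
    TLLM_CHECK(ss.str().size() == Serialization::serializedSize(cacheState));

    auto const restored = Serialization::deserializeCacheState(ss);
    TLLM_CHECK(restored.mParallelConfig.mContextParallelism == cacheState.mParallelConfig.mContextParallelism);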
diff --git a/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu
new file mode 100644
index 0000000000..eb3b958eb2
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moeTopKFuncs.cuh"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
+#include "tensorrt_llm/common/envUtils.h"
+#include "tensorrt_llm/kernels/archCondition.h"
+#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
+#include <climits> // For INT_MAX
+#include
+#include
+#include
+#include <limits> // For numeric_limits
+#include
+
+namespace cg = cooperative_groups;
+using namespace tensorrt_llm::common;
+
+namespace tensorrt_llm::kernels
+{
+
+static constexpr int BLOCK_SIZE = 1024;
+static constexpr int WARP_SIZE = 32;
+static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
+{
+ T maxScore = T{-INFINITY};
+ if (laneIdx < NumTopExperts)
+ {
+ maxScore = score >= maxScore ? score : maxScore;
+ }
+    maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
+
+ float sumScore{0.f};
+ float newScore;
+ // Get the summation of scores for each token
+ if (laneIdx < NumTopExperts)
+ {
+        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
+        newScore = static_cast<float>(exp(newScore));
+ sumScore += newScore;
+ }
+    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
+
+ if (laneIdx < NumTopExperts)
+ {
+        score = static_cast<T>(newScore / sumScore);
+ }
+
+ return score;
+}
+
+template <typename DataType, int VecSize>
+__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
+{
+ DataType maxScore = DataType{-INFINITY};
+ DataType sumScore = DataType{0.f};
+
+ // Get the max score for each token
+#pragma unroll
+ for (int i = 0; i < VecSize; ++i)
+ {
+ maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
+ }
+    maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
+
+ // Get the summation of scores for each token
+#pragma unroll
+ for (int i = 0; i < VecSize; ++i)
+ {
+        scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
+ sumScore += scores[i];
+ }
+    sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
+
+ // Normalize the scores
+#pragma unroll
+ for (int i = 0; i < VecSize; ++i)
+ {
+        scores[i] = static_cast<DataType>(scores[i] / sumScore);
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts,
+    bool DoSoftmaxBeforeTopK>
+__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
+ int32_t const numTokens, int32_t const numExperts, int32_t const topK)
+{
+    using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>;
+ uint32_t const blockRank = blockIdx.x;
+ uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
+ uint32_t const warpIdx = tIdx / WARP_SIZE;
+ uint32_t const laneIdx = tIdx % WARP_SIZE;
+ uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
+ auto block = cg::this_thread_block();
+    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+ BaseType minScore = BaseType{-INFINITY};
+ for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
+ {
+ auto scoreOffset = tokenId * numExperts;
+ auto outputOffset = tokenId * topK;
+
+ BaseType inputScore[MaxNumExperts / WARP_SIZE];
+ IdxT inputIndex[MaxNumExperts / WARP_SIZE];
+
+ BaseType warpTopKScore[MaxNumTopExperts];
+ IdxT warpTopKExpertIdx[MaxNumTopExperts];
+
+ // Load scores and indices for this warp
+ for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
+ {
+ auto expertIdx = i * WARP_SIZE + laneIdx;
+ inputScore[i]
+                = expertIdx < numExperts ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx]) : minScore;
+ inputIndex[i] = expertIdx;
+ }
+
+ if constexpr (DoSoftmaxBeforeTopK)
+ {
+ calcSoftmax(warp, inputScore);
+ }
+ // Reduce topK scores and indices for this warp
+ reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
+
+ // Normalize the scores
+ if constexpr (DoSoftmaxBeforeTopK)
+ {
+ if (laneIdx < topK)
+ {
+                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
+ topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
+ }
+ }
+ else
+ {
+ auto softmaxScore = calcSoftmax(warp,
+                laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx,
+ topK);
+ if (laneIdx < topK)
+ {
+                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore);
+ topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
+ }
+ }
+ } // end for tokenId
+}
+
+int nextPowerOfTwo(int num)
+{
+ if (num <= 0)
+ {
+ return 1; // Handle invalid input
+ }
+ int power = 1;
+ while (power < num)
+ {
+ // Check for overflow before shifting
+ if (power > INT_MAX / 2)
+ {
+ return power;
+ }
+ power <<= 1;
+ }
+ return power;
+}
+
+#define CASE(MAX_NUM_EXPERTS) \
+ case MAX_NUM_EXPERTS: \
+ switch (maxNumTopExperts) \
+ { \
+ case 1: \
+        kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>; \
+ break; \
+ case 2: \
+        kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>; \
+ break; \
+ case 4: \
+        kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>; \
+ break; \
+ case 8: \
+        kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>; \
+ break; \
+ default: kernelInstance = nullptr; break; \
+ } \
+ break;
+
+template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
+void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
+ int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
+{
+
+ const uint32_t maxNumBlocks = 1024;
+    const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);
+
+ uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
+ uint32_t maxNumTopExperts = nextPowerOfTwo(topK);
+
+ auto* kernelInstance = &customMoeRoutingKernel;
+
+ switch (maxNumExperts)
+ {
+ CASE(32)
+ CASE(64)
+ CASE(96)
+ CASE(128)
+ default: kernelInstance = nullptr; break;
+ }
+
+ if (kernelInstance == nullptr)
+ {
+ TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance.");
+ }
+
+ dim3 renormMoeRoutingGridDim(numBlocks);
+ dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
+ cudaLaunchConfig_t config;
+ config.gridDim = renormMoeRoutingGridDim;
+ config.blockDim = renormMoeRoutingBlockDim;
+ config.dynamicSmemBytes = 0;
+ config.stream = stream;
+ cudaLaunchAttribute attrs[1];
+ attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+ attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+ config.numAttrs = 1;
+ config.attrs = attrs;
+    cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
+        static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
+ sync_check_cuda_error(stream);
+}
+
+#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \
+    template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits, \
+ OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \
+ int64_t const topK, cudaStream_t const stream);
+
+INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
+INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
+#ifdef ENABLE_BF16
+INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
+#endif
+
+INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
+INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);
+#ifdef ENABLE_BF16
+INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
+#endif
+
+} // namespace tensorrt_llm::kernels
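A hedged host-side sketch of driving the routing entry point defined above. The problem sizes and raw cudaMalloc allocations are illustrative, and the explicit template arguments assume the parameter order (InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) matching the instantiation list at the bottom of the file:

    #include <cuda_runtime.h>
    #include "tensorrt_llm/kernels/customMoeRoutingKernels.h"

    // Illustrative driver for invokeRenormMoeRouting; CUDA error checking is omitted.
    int64_t const numTokens = 256, numExperts = 128, topK = 8;

    float* routerLogits;  // [numTokens, numExperts] (fill with router outputs before launching)
    float* topkValues;    // [numTokens, topK]
    int32_t* topkIndices; // [numTokens, topK]
    cudaMalloc(&routerLogits, numTokens * numExperts * sizeof(float));
    cudaMalloc(&topkValues, numTokens * topK * sizeof(float));
    cudaMalloc(&topkIndices, numTokens * topK * sizeof(int32_t));

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    tensorrt_llm::kernels::invokeRenormMoeRouting<float, float, int32_t, /*DoSoftmaxBeforeTopK=*/true>(
        routerLogits, topkValues, topkIndices, numTokens, numExperts, topK, stream);
    cudaStreamSynchronize(stream);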
diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h
similarity index 86%
rename from cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h
rename to cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h
index 1e9b001f65..cfe0ae8f15 100644
--- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h
+++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
namespace tensorrt_llm::kernels
{
-template <typename InputT, typename OutputT, typename IdxT>
+template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
int64_t const numExperts, int64_t const topK, cudaStream_t const stream);
} // namespace tensorrt_llm::kernels
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h
index ba755ca669..2c0894ab57 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h
@@ -288,15 +288,20 @@ public:
void moeGemm(GroupedGemmInput inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs);
-    std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
-    static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
-    static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(int sm);
-    static std::vector<cutlass_extensions::CutlassGemmConfig> getBlackwellConfigs(int sm);
-    static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
+    std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(bool supports_finalize_fusion) const;
+    static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm, bool supports_finalize_fusion);
+    static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(
+        int sm, bool supports_finalize_fusion);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
[[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;
- [[nodiscard]] bool supportsTmaWarpSpecialized() const;
+
+ [[nodiscard]] bool supportsTmaWarpSpecialized() const
+ {
+ return supportsTmaWarpSpecialized(sm_);
+ }
+
+ [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);
[[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
ActivationType activation_type, int gemm_n, int gemm_k) const;
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
index 7d592bed0e..389591e7fe 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
@@ -228,6 +228,13 @@ struct MOEParallelismConfig
}
};
+enum class MoeGemmId : int
+{
+ Undefined = 0,
+ GEMM_1,
+ GEMM_2
+};
+
struct QuantParams
{
// Int weight only quantization params
@@ -446,7 +453,7 @@ public:
    virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config,
        std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config)
        = 0;
-    virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;
+    virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) = 0;
virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
@@ -593,15 +600,15 @@ public:
gemm2_config_ = std::move(gemm2_config);
}
-    std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() override
+    std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) override
{
- return moe_gemm_runner_.getConfigs();
+ return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused());
}
-    static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm)
+    static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm, MoeGemmId gemm_id)
{
using RunnerType = decltype(moe_gemm_runner_);
- return RunnerType::getConfigs(sm);
+ return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm));
}
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
@@ -798,6 +805,12 @@ private:
&& !use_w4_groupwise;
}
+ static bool mayHaveFinalizeFused(int sm)
+ {
+ using RunnerType = decltype(moe_gemm_runner_);
+ return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
+ }
+
// TODO: This should eventually take the quant params to give more flexibility
static auto getScalingType()
{
@@ -895,12 +908,7 @@ struct GemmProfilerBackend
{
public:
using Config = cutlass_extensions::CutlassGemmConfig;
- enum class GemmToProfile
- {
- Undefined = 0,
- GEMM_1,
- GEMM_2
- };
+ using GemmToProfile = MoeGemmId;
void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype,
nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size,
@@ -951,7 +959,6 @@ public:
CutlassMoeFCRunnerInterface* mInterface;
GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
- std::vector mAllTacticsSaved;
int mSM{};
int64_t mNumExperts{};
int64_t mNumExpertsPerNode{};
@@ -972,7 +979,7 @@ public:
// This will be a unique value for every iteration of warmup and actual bench
constexpr static int64_t NUM_ROUTING_SAMPLES = 16;
-    std::array<TmaWarpSpecializedGroupedGemmInput, NUM_ROUTING_SAMPLES> mTmaInputCache;
+    std::array<std::array<TmaWarpSpecializedGroupedGemmInput, 2>, NUM_ROUTING_SAMPLES> mTmaInputCache;
QuantParams mQuantParams;
bool mBias{};
@@ -985,7 +992,8 @@ public:
private:
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
- void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
+ void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
+ TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream);
};
// Populates a buffer with random values for use with MOE benchmarking
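Since getTactics is now keyed by MoeGemmId, and only GEMM_2 may include FINALIZE-fused candidates, tuners are expected to sweep the two GEMMs independently. A hedged sketch against the interface above; `runner` is an existing CutlassMoeFCRunner instance and profileOneTactic is a hypothetical callback, not part of this patch:

    // Illustrative per-GEMM tactic enumeration.
    for (auto const gemmId : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
    {
        auto const tactics = runner.getTactics(gemmId); // std::vector<cutlass_extensions::CutlassGemmConfig>
        for (auto const& tactic : tactics)
        {
            profileOneTactic(gemmId, tactic); // hypothetical: time the tactic and keep the best per GEMM
        }
    }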
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h
index c44caae0fa..ef06abceee 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h
@@ -57,7 +57,6 @@ namespace kernels
{
namespace cutlass_kernels
{
-
template
void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace,
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
index 5a07062f06..899e0787cd 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
@@ -475,17 +475,18 @@ void dispatchMoeGemmToCutlass(GroupedGemmInput
-std::vector
-MoeGemmRunner::getConfigs() const
+std::vector MoeGemmRunner::getConfigs(
+ bool supports_finalize_fusion) const
{
- return getConfigs(sm_);
+ return getConfigs(sm_, supports_finalize_fusion);
}
template
std::vector MoeGemmRunner::getConfigs(
- int sm)
+ int sm, bool supports_finalize_fusion)
{
- std::vector candidate_configs = getTmaWarpSpecializedConfigs(sm);
+ std::vector candidate_configs
+ = getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion);
std::vector ampere_configs = getAmpereConfigs(sm);
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
return candidate_configs;
@@ -521,7 +522,8 @@ MoeGemmRunner::getAmpereConfigs(int sm
template
std::vector
-MoeGemmRunner::getTmaWarpSpecializedConfigs(int sm)
+MoeGemmRunner::getTmaWarpSpecializedConfigs(
+ int sm, bool supports_finalize_fusion)
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
@@ -568,6 +570,17 @@ MoeGemmRunner::getTmaWarpSpecializedCo
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param);
std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs));
}
+ if (supports_finalize_fusion)
+ {
+ // Duplicate the configs and set the epilogue fusion type to FINALIZE
+ auto finalize_configs = tma_ws_configs;
+ std::transform(finalize_configs.begin(), finalize_configs.end(), std::back_inserter(tma_ws_configs),
+ [](auto& config)
+ {
+ config.epilogue_fusion_type = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
+ return config;
+ });
+ }
return tma_ws_configs;
}
@@ -580,13 +593,11 @@ bool MoeGemmRunner::isTmaWarpSpecializ
}
template
-bool MoeGemmRunner::supportsTmaWarpSpecialized() const
+bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm)
{
- return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation())
- || (sm_ >= 100 && sm_ < 120
- && tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
- || ((sm_ == 120 || sm_ == 121)
- && tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation());
+ return (sm == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation())
+ || (sm >= 100 && sm < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation())
+ || ((sm == 120 || sm == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation());
}
template
@@ -833,7 +844,9 @@ size_t MoeGemmRunner::calcMaxWorkspace
if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8
&& !use_wfp4a16)
{
- auto configs = getTmaWarpSpecializedConfigs(sm_);
+    // Finalize fusion may not actually be supported by the kernel;
+    // if it is not, we will catch the error and skip those configs
+ auto configs = getTmaWarpSpecializedConfigs(sm_, true);
auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
if constexpr (use_wfp4afp8)
{
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
index 730840717c..ef70b9d45e 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -2847,9 +2847,10 @@ void CutlassMoeFCRunner
+    bool const gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
+        == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
permuted_token_final_scales_
- = (gemm2_using_tma_ws && mayHaveFinalizeFused()) ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
+ = gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
bool const is_gated_activation = isGatedActivation(activation_type);
bool const gemm1_using_fused_moe
@@ -4006,8 +4007,12 @@ CutlassMoeFCRunner::
bool apply_bias = parallelism_config.tp_rank == 0;
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
+ bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
+ == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool using_fused_finalize
- = use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora;
+ = use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
+ TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
+ "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
if (using_fused_finalize)
{
assert(min_latency_mode == false);
@@ -4550,14 +4555,26 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr
}
}
-void GemmProfilerBackend::prepareTmaWsInputs(
- int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
+void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
+ TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream)
{
if (mSM < 90)
{
return;
}
+ bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
+ bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
+ && mWType == nvinfer1::DataType::kUINT8);
+ bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
+ bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
+ bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
+ || mGemmToProfile != GemmToProfile::GEMM_2;
+ if (use_finalize_fusion && finalize_fusion_not_supported)
+ {
+ return;
+ }
+
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR(type, name) \
@@ -4596,28 +4613,24 @@ void GemmProfilerBackend::prepareTmaWsInputs(
size_t num_expanded_tokens = num_tokens * mK;
for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
{
- mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
+ // Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same pointers to
+ // save space.
+ auto& cache_element = mTmaInputCache[i][use_finalize_fusion];
+ cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
workspaces.at("gemm_workspace").first, mScalingType);
tma_ws_input_workspace += tma_ws_size;
int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1);
int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens;
- auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input;
- auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input;
+ auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input;
+ auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input;
if (mSM >= 90)
{
/* GEMM1 */
gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
-
- bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
- bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
- && mWType == nvinfer1::DataType::kUINT8);
- bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
- bool using_fused_finalize
- = mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise;
- if (using_fused_finalize)
+ if (use_finalize_fusion)
{
assert(!mMinLatencyMode);
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
@@ -4652,7 +4665,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(
void GemmProfilerBackend::prepare(
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
{
- mAllTacticsSaved = mInterface->getTactics();
mSampleIndex = 0;
auto workspace_size = getWorkspaceSize(num_tokens);
@@ -4660,7 +4672,10 @@ void GemmProfilerBackend::prepare(
prepareRouting(num_tokens, workspace_ptr_char, stream);
prepareQuantParams(num_tokens, workspace_ptr_char, stream);
- prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream);
+ prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
+ TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, stream);
+ prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
+ TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, stream);
}
size_t GemmProfilerBackend::getWorkspaceSize(int maxM)
@@ -4724,7 +4739,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
TmaWarpSpecializedGroupedGemmInput tma_ws_input_template;
if (tactic.is_tma_warp_specialized)
{
- tma_ws_input_template = mTmaInputCache[mSampleIndex];
+ tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type
+ == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE];
+ TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), "TMA WS input template is not initialized");
}
mInterface->is_profiler = true;
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
index 08cd9b6f66..5ebd5f7ebe 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf
-size 67051604
+oid sha256:d6a3f6adef11003f794a6cec1235d0c622ead71b4e801a89866e91dfd91bb30c
+size 67053244
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
index 8b500f5c97..b93f46ea6d 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
-568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a
-commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
+317a25037093a6f3d156ffa58a68bce53071ef68dacdcb04cc0aaeea80b64e76 libtensorrt_llm_internal_cutlass_kernels_static.a
+commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
index f1a6b9dc88..bd07528460 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d
-size 66872936
+oid sha256:489fb557b78062efedd1514f2995fafb9216bb0e0068a550e86763efb9d5eee9
+size 66874608
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
index 4af58b0800..3c053c1a91 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
-813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a
-commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
+5a31acd0fb1415196bff71fa4a8d1dded147e15ea10821cc46c85684c66986ee libtensorrt_llm_internal_cutlass_kernels_static.a
+commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b
diff --git a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
new file mode 100644
index 0000000000..933b599dbd
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
@@ -0,0 +1,205 @@
+
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
+#define TRTLLM_MOETOPKFUNCS_CUH_H
+
+#include
+#include
+#include
+
+#include "tensorrt_llm/kernels/archCondition.h"
+
+namespace tensorrt_llm::kernels
+{
+
+namespace reduce_topk
+{
+namespace cg = cooperative_groups;
+static constexpr int kWARP_SIZE = 32;
+static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
+
+template <typename T_>
+struct TopKRedType
+{
+ using T = T_;
+    static_assert(std::is_same_v<T, int> || std::is_same_v<T, float> || std::is_same_v<T, half>
+            || std::is_same_v<T, __nv_bfloat16>,
+ "Top K reduction only implemented for int, float, float16 and bfloat16");
+
+    using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
+ using IdxT = std::conditional_t;
+
+ static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
+ static constexpr int kMaxIdx = 65535;
+ TypeCmp compValIdx;
+
+ static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
+ {
+        auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
+ TypeCmp compactTmp = reinterpret_cast(valueBits);
+ compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
+ // Use 65535 minus idx to give higher priority to elements with smaller indices.
+ return compactTmp;
+ }
+
+ static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
+ {
+ // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits
+        index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
+
+ auto compactTmp = cmp >> kMoveBits;
+ auto valueBits
+            = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
+        value = reinterpret_cast<T&>(valueBits);
+ }
+
+ __host__ __device__ TopKRedType() = default;
+
+ __host__ __device__ TopKRedType(T val, int32_t idx)
+ : compValIdx(makeCmpVal(val, idx))
+ {
+ }
+
+ __host__ __device__ operator TypeCmp() const noexcept
+ {
+ return compValIdx;
+ }
+
+    __device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
+ {
+ if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
+ {
+            return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
+ }
+ else
+ {
+ TypeCmp result;
+ asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
+ return result;
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct TopKIdx
+{
+ // by default, empty
+};
+
+template
+struct TopKIdx
+{
+ static constexpr int K = K_;
+ int32_t val[K];
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define TOPK_SWAP(I, J) \
+ { \
+ auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
+ auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
+ topK[I].compValIdx = pairMax; \
+ topK[J].compValIdx = pairMin; \
+ }
+
+template <int N, typename RedType>
+struct Sort;
+
+template <typename RedType>
+struct Sort<1, RedType>
+{
+ static __device__ void run(RedType* topK) {}
+};
+
+template <typename RedType>
+struct Sort<2, RedType>
+{
+ static __device__ void run(RedType* topK)
+ {
+ TOPK_SWAP(0, 1);
+ }
+};
+
+template <typename RedType>
+struct Sort<3, RedType>
+{
+ static __device__ void run(RedType* topK)
+ {
+ TOPK_SWAP(0, 1);
+ TOPK_SWAP(1, 2);
+ TOPK_SWAP(0, 1);
+ }
+};
+
+template <typename RedType>
+struct Sort<4, RedType>
+{
+ static __device__ void run(RedType* topK)
+ {
+ TOPK_SWAP(0, 2);
+ TOPK_SWAP(1, 3);
+ TOPK_SWAP(0, 1);
+ TOPK_SWAP(2, 3);
+ TOPK_SWAP(1, 2);
+ }
+};
+
+template <int K, typename Type, int N, bool IsSorted = false>
+__device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
+ Type (&value)[N], int32_t (&idx)[N], Type minValue)
+{
+ static_assert(K > 0, "Top K must have K > 0");
+ static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
+ static_assert(N > 0, "Top K must have N > 0");
+ static_assert(N < 5, "Only support candidates number less than or equal to 128");
+    using RedType = TopKRedType<Type>;
+ RedType topK[N];
+#pragma unroll
+ for (int nn = 0; nn < N; ++nn)
+ {
+ topK[nn] = RedType{value[nn], idx[nn]};
+ }
+
+ if constexpr (!IsSorted)
+ {
+        Sort<N, RedType>::run(topK);
+ }
+ typename RedType::TypeCmp packedMax{};
+#pragma unroll
+ for (int kk = 0; kk < K; ++kk)
+ {
+ bool update = kk > 0 && packedMax == topK[0].compValIdx;
+#pragma unroll
+ for (int nn = 0; nn < N; ++nn)
+ {
+ topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
+ }
+ // get the next largest value
+ packedMax = topK[0].reduce(warp);
+ RedType::unpack(out[kk], outIdx[kk], packedMax);
+ }
+};
+
+#undef TOPK_SWAP
+
+} // namespace reduce_topk
+} // namespace tensorrt_llm::kernels
+#endif // TRTLLM_MOETOPKFUNCS_CUH_H
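A hedged device-side sketch of the reduceTopK helper defined above: each of the 32 lanes contributes 4 candidates, and the warp cooperatively produces the global top 4 values with their indices. The kernel, its launch shape, and the input layout are illustrative only:

    // Hedged sketch: warp-cooperative top-4 over 128 scores (4 candidates per lane).
    #include <cmath>
    #include <cooperative_groups.h>
    #include "moeTopKFuncs.cuh"

    namespace cg = cooperative_groups;

    __global__ void warpTop4Example(float const* scores, float* outVal, int32_t* outIdx)
    {
        auto warp = cg::tiled_partition<32>(cg::this_thread_block());
        int const lane = warp.thread_rank();

        // Each lane loads 4 strided candidates and remembers their global indices.
        float value[4];
        int32_t idx[4];
        for (int i = 0; i < 4; ++i)
        {
            idx[i] = i * 32 + lane;
            value[i] = scores[idx[i]];
        }

        float topVal[4];
        int32_t topIdx[4];
        tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topVal, topIdx, value, idx, -INFINITY);

        // After the reduction every lane holds the same top-4; lanes 0..3 write them out.
        if (lane < 4)
        {
            outVal[lane] = topVal[lane];
            outIdx[lane] = topIdx[lane];
        }
    }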
diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu
deleted file mode 100644
index 1b4239e48c..0000000000
--- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu
+++ /dev/null
@@ -1,376 +0,0 @@
-/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tensorrt_llm/common/cudaTypeUtils.cuh"
-#include "tensorrt_llm/common/envUtils.h"
-#include "tensorrt_llm/kernels/archCondition.h"
-#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
-#include